Plotting feature importance

A simple example showing how to compute and display feature importances. The results are also compared with the feature importances obtained using random forests.

Feature importance is a measure of the effect of the features on the outputs. For each feature, the value ranges from 0 to 1, where a higher value means that the feature has a larger effect on the outputs.
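To make the 0-to-1 scale concrete, the sketch below rescales a list of raw, non-negative scores so that they sum to 1. This is only an assumption about how such a scale can arise: it is suggested by the columns of the importance table printed further down, each of which sums to roughly 1, and it is not taken from the py-earth source. The helper name `normalize_importances` and the raw scores are hypothetical.

```python
def normalize_importances(raw_scores):
    """Rescale non-negative raw scores so that they sum to 1.

    Hypothetical helper for illustration only; not part of py-earth.
    """
    total = sum(raw_scores)
    if total == 0:
        # No feature has any effect: report zero importance everywhere.
        return [0.0 for _ in raw_scores]
    return [score / total for score in raw_scores]


# Made-up raw scores for four features (not produced by any model).
raw = [4.0, 2.0, 2.0, 0.0]
normalized = normalize_importances(raw)
print(normalized)
```

Under this assumption, a feature that never contributes gets 0, and a value close to 1 means that a single feature accounts for almost all of the measured effect.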

Currently three criteria are supported: ‘gcv’, ‘rss’ and ‘nb_subsets’. See [1], section 12.3, for more information about the criteria.

[1] http://www.milbo.org/doc/earth-notes.pdf
[Figure: feature importances per criterion (rss, gcv, nb_subsets) and for the random forest, one bar chart each]

Out:

Beginning forward pass
-----------------------------------------------------------------
iter  parent  var  knot  mse        terms  gcv     rsq    grsq
-----------------------------------------------------------------
0     -       -    -     24.101072  1      24.106  0.000  0.000
1     0       3    -1    15.626212  2      15.637  0.352  0.351
2     0       1    2446  10.884011  4      10.903  0.548  0.548
3     0       0    43    6.386236   6      6.404   0.735  0.734
4     0       2    3272  4.226153   8      4.242   0.825  0.824
5     0       4    -1    2.136919   9      2.146   0.911  0.911
6     4       1    9815  1.113502   11     1.119   0.954  0.954
-----------------------------------------------------------------
Stopping Condition 0: Reached maximum number of terms
Beginning pruning pass
------------------------------------------------
iter  bf  terms  mse    gcv     rsq     grsq
------------------------------------------------
0     -   11     1.11   1.119   0.954   0.954
1     2   10     1.40   1.405   0.942   0.942
2     10  9      1.63   1.635   0.932   0.932
3     4   8      1.68   1.684   0.930   0.930
4     9   7      2.16   2.168   0.910   0.910
5     6   6      3.84   3.851   0.841   0.840
6     7   5      4.29   4.297   0.822   0.822
7     8   4      6.41   6.425   0.734   0.733
8     5   3      10.89  10.907  0.548   0.548
9     3   2      15.63  15.637  0.352   0.351
10    1   1      24.10  24.106  -0.000  -0.000
------------------------------------------------
Selected iteration: 0
Forward Pass
-----------------------------------------------------------------
iter  parent  var  knot  mse        terms  gcv     rsq    grsq
-----------------------------------------------------------------
0     -       -    -     24.101072  1      24.106  0.000  0.000
1     0       3    -1    15.626212  2      15.637  0.352  0.351
2     0       1    2446  10.884011  4      10.903  0.548  0.548
3     0       0    43    6.386236   6      6.404   0.735  0.734
4     0       2    3272  4.226153   8      4.242   0.825  0.824
5     0       4    -1    2.136919   9      2.146   0.911  0.911
6     4       1    9815  1.113502   11     1.119   0.954  0.954
-----------------------------------------------------------------
Stopping Condition 0: Reached maximum number of terms

Pruning Pass
------------------------------------------------
iter  bf  terms  mse    gcv     rsq     grsq
------------------------------------------------
0     -   11     1.11   1.119   0.954   0.954
1     2   10     1.40   1.405   0.942   0.942
2     10  9      1.63   1.635   0.932   0.932
3     4   8      1.68   1.684   0.930   0.930
4     9   7      2.16   2.168   0.910   0.910
5     6   6      3.84   3.851   0.841   0.840
6     7   5      4.29   4.297   0.822   0.822
7     8   4      6.41   6.425   0.734   0.733
8     5   3      10.89  10.907  0.548   0.548
9     3   2      15.63  15.637  0.352   0.351
10    1   1      24.10  24.106  -0.000  -0.000
------------------------------------------------
Selected iteration: 0

Earth Model
----------------------------------------------------
Basis Function                 Pruned  Coefficient
----------------------------------------------------
(Intercept)                    No      7.68887
x3                             No      10.0005
h(x1-0.525027)                 No      5.34483
h(0.525027-x1)                 No      -9.90971
h(x0-0.563426)                 No      10.4923
h(0.563426-x0)                 No      -12.0495
h(x2-0.496723)                 No      9.90482
h(0.496723-x2)                 No      10.1095
x4                             No      5.00753
h(x1-0.496528)*h(x0-0.563426)  No      -56.5856
h(0.496528-x1)*h(x0-0.563426)  No      -36.1381
----------------------------------------------------
MSE: 1.1135, GCV: 1.1193, RSQ: 0.9538, GRSQ: 0.9536
       nb_subsets    gcv    rss
x3     0.09          0.36   0.36
x1     0.27          0.23   0.23
x0     0.36          0.22   0.22
x2     0.18          0.09   0.09
x4     0.09          0.09   0.09
x9     0.00          0.00   0.00
x8     0.00          0.00   0.00
x7     0.00          0.00   0.00
x6     0.00          0.00   0.00
x5     0.00          0.00   0.00
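
The three criteria do not have to agree on the ordering of the features. As a small illustration, the sketch below copies the rounded values from the table above into a plain dictionary and ranks the features under each criterion. The helper `top_features` is hypothetical and is not part of py-earth.

```python
# Values copied from the printed importance table above (rounded to 2 decimals).
importance_table = {
    'x3': {'nb_subsets': 0.09, 'gcv': 0.36, 'rss': 0.36},
    'x1': {'nb_subsets': 0.27, 'gcv': 0.23, 'rss': 0.23},
    'x0': {'nb_subsets': 0.36, 'gcv': 0.22, 'rss': 0.22},
    'x2': {'nb_subsets': 0.18, 'gcv': 0.09, 'rss': 0.09},
    'x4': {'nb_subsets': 0.09, 'gcv': 0.09, 'rss': 0.09},
    'x9': {'nb_subsets': 0.00, 'gcv': 0.00, 'rss': 0.00},
    'x8': {'nb_subsets': 0.00, 'gcv': 0.00, 'rss': 0.00},
    'x7': {'nb_subsets': 0.00, 'gcv': 0.00, 'rss': 0.00},
    'x6': {'nb_subsets': 0.00, 'gcv': 0.00, 'rss': 0.00},
    'x5': {'nb_subsets': 0.00, 'gcv': 0.00, 'rss': 0.00},
}


def top_features(table, criterion, k=3):
    """Return the k feature names with the largest value for one criterion."""
    ranked = sorted(table, key=lambda name: table[name][criterion], reverse=True)
    return ranked[:k]


for criterion in ('rss', 'gcv', 'nb_subsets'):
    print(criterion, top_features(importance_table, criterion))
```

On these numbers, 'rss' and 'gcv' both place x3 first, while 'nb_subsets' places x0 first. The features x5 to x9, which do not enter the formula for y, score 0 under all three criteria.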

import numpy
import matplotlib.pyplot as plt

from sklearn.ensemble import RandomForestRegressor
from pyearth import Earth

# Create some fake data
numpy.random.seed(2)
m = 10000
n = 10

X = numpy.random.uniform(size=(m, n))
y = (10 * numpy.sin(numpy.pi * X[:, 0] * X[:, 1]) +
     20 * (X[:, 2] - 0.5) ** 2 +
     10 * X[:, 3] +
     5 * X[:, 4] + numpy.random.uniform(size=m))
# Fit an Earth model
criteria = ('rss', 'gcv', 'nb_subsets')
model = Earth(max_degree=3,
              max_terms=10,
              minspan_alpha=.5,
              feature_importance_type=criteria,
              verbose=True)
model.fit(X, y)
# Fit a random forest on the same data for comparison
rf = RandomForestRegressor()
rf.fit(X, y)
# Print the model
print(model.trace())
print(model.summary())
print(model.summary_feature_importances(sort_by='gcv'))

# Plot the feature importances
importances = model.feature_importances_
importances['random_forest'] = rf.feature_importances_
criteria = criteria + ('random_forest',)
idx = 1

fig = plt.figure(figsize=(20, 10))
labels = ['$x_{}$'.format(i) for i in range(n)]
for crit in criteria:
    plt.subplot(2, 2, idx)
    plt.bar(numpy.arange(len(labels)),
            importances[crit],
            align='center',
            color='red')
    plt.xticks(numpy.arange(len(labels)), labels)
    plt.title(crit)
    plt.ylabel('importances')
    idx += 1
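# Raw strings keep the LaTeX backslashes intact; X is drawn from a uniform
# distribution, so the title says Unif(0, 1) rather than N(0, 1)
title = (r'$x_0,...x_9 \sim Unif(0, 1)$' '\n'
         r'$y= 10sin(\pi x_{0}x_{1}) + 20(x_2 - 0.5)^2 + 10x_3 + 5x_4 + Unif(0, 1)$')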
fig.suptitle(title, fontsize="x-large")
plt.show()

Total running time of the script: (0 minutes 1.583 seconds)
