.. _sphx_glr_auto_examples_plot_feature_importance.py:


===========================
Plotting feature importance
===========================

A simple example showing how to compute and display feature importances.
The results are also compared with the feature importances obtained using
random forests.

Feature importance is a measure of the effect of the features on the
outputs. For each feature, the value ranges from 0 to 1, where a higher
value means that the feature has a greater effect on the outputs. Three
criteria are currently supported: 'gcv', 'rss' and 'nb_subsets'. See [1]_,
section 12.3, for more information about the criteria.

.. [1] http://www.milbo.org/doc/earth-notes.pdf


.. image:: /auto_examples/images/sphx_glr_plot_feature_importance_001.png
    :align: center


.. rst-class:: sphx-glr-script-out

 Out::

    Beginning forward pass
    ---------------------------------------------------------------
    iter  parent  var  knot  mse        terms  gcv     rsq    grsq
    ---------------------------------------------------------------
    0     -       -    -     24.101072  1      24.106  0.000  0.000
    1     0       3    -1    15.626212  2      15.637  0.352  0.351
    2     0       1    2446  10.884011  4      10.903  0.548  0.548
    3     0       0    43    6.386236   6      6.404   0.735  0.734
    4     0       2    3272  4.226153   8      4.242   0.825  0.824
    5     0       4    -1    2.136919   9      2.146   0.911  0.911
    6     4       1    9815  1.113502   11     1.119   0.954  0.954
    ---------------------------------------------------------------
    Stopping Condition 0: Reached maximum number of terms

    Beginning pruning pass
    --------------------------------------------
    iter  bf  terms  mse    gcv     rsq     grsq
    --------------------------------------------
    0     -   11     1.11   1.119   0.954   0.954
    1     2   10     1.40   1.405   0.942   0.942
    2     10  9      1.63   1.635   0.932   0.932
    3     4   8      1.68   1.684   0.930   0.930
    4     9   7      2.16   2.168   0.910   0.910
    5     6   6      3.84   3.851   0.841   0.840
    6     7   5      4.29   4.297   0.822   0.822
    7     8   4      6.41   6.425   0.734   0.733
    8     5   3      10.89  10.907  0.548   0.548
    9     3   2      15.63  15.637  0.352   0.351
    10    1   1      24.10  24.106  -0.000  -0.000
    --------------------------------------------
    Selected iteration: 0

    Forward Pass
    ---------------------------------------------------------------
    iter  parent  var  knot  mse        terms  gcv     rsq    grsq
    ---------------------------------------------------------------
    0     -       -    -     24.101072  1      24.106  0.000  0.000
    1     0       3    -1    15.626212  2      15.637  0.352  0.351
    2     0       1    2446  10.884011  4      10.903  0.548  0.548
    3     0       0    43    6.386236   6      6.404   0.735  0.734
    4     0       2    3272  4.226153   8      4.242   0.825  0.824
    5     0       4    -1    2.136919   9      2.146   0.911  0.911
    6     4       1    9815  1.113502   11     1.119   0.954  0.954
    ---------------------------------------------------------------
    Stopping Condition 0: Reached maximum number of terms

    Pruning Pass
    --------------------------------------------
    iter  bf  terms  mse    gcv     rsq     grsq
    --------------------------------------------
    0     -   11     1.11   1.119   0.954   0.954
    1     2   10     1.40   1.405   0.942   0.942
    2     10  9      1.63   1.635   0.932   0.932
    3     4   8      1.68   1.684   0.930   0.930
    4     9   7      2.16   2.168   0.910   0.910
    5     6   6      3.84   3.851   0.841   0.840
    6     7   5      4.29   4.297   0.822   0.822
    7     8   4      6.41   6.425   0.734   0.733
    8     5   3      10.89  10.907  0.548   0.548
    9     3   2      15.63  15.637  0.352   0.351
    10    1   1      24.10  24.106  -0.000  -0.000
    --------------------------------------------
    Selected iteration: 0

    Earth Model
    ----------------------------------------------------
    Basis Function                   Pruned  Coefficient
    ----------------------------------------------------
    (Intercept)                      No      7.68887
    x3                               No      10.0005
    h(x1-0.525027)                   No      5.34483
    h(0.525027-x1)                   No      -9.90971
    h(x0-0.563426)                   No      10.4923
    h(0.563426-x0)                   No      -12.0495
    h(x2-0.496723)                   No      9.90482
    h(0.496723-x2)                   No      10.1095
    x4                               No      5.00753
    h(x1-0.496528)*h(x0-0.563426)    No      -56.5856
    h(0.496528-x1)*h(x0-0.563426)    No      -36.1381
    ----------------------------------------------------
    MSE: 1.1135, GCV: 1.1193, RSQ: 0.9538, GRSQ: 0.9536

        nb_subsets  gcv   rss
    x3  0.09        0.36  0.36
    x1  0.27        0.23  0.23
    x0  0.36        0.22  0.22
    x2  0.18        0.09  0.09
    x4  0.09        0.09  0.09
    x9  0.00        0.00  0.00
    x8  0.00        0.00  0.00
    x7  0.00        0.00  0.00
    x6  0.00        0.00  0.00
    x5  0.00        0.00  0.00

|


.. code-block:: python

    import numpy
    import matplotlib.pyplot as plt
    from sklearn.ensemble import RandomForestRegressor

    from pyearth import Earth

    # Create some fake data
    numpy.random.seed(2)
    m = 10000
    n = 10
    X = numpy.random.uniform(size=(m, n))
    y = (10 * numpy.sin(numpy.pi * X[:, 0] * X[:, 1]) +
         20 * (X[:, 2] - 0.5) ** 2 +
         10 * X[:, 3] +
         5 * X[:, 4] +
         numpy.random.uniform(size=m))

    # Fit an Earth model
    criteria = ('rss', 'gcv', 'nb_subsets')
    model = Earth(max_degree=3,
                  max_terms=10,
                  minspan_alpha=.5,
                  feature_importance_type=criteria,
                  verbose=True)
    model.fit(X, y)
    rf = RandomForestRegressor()
    rf.fit(X, y)

    # Print the model
    print(model.trace())
    print(model.summary())
    print(model.summary_feature_importances(sort_by='gcv'))

    # Plot the feature importances
    importances = model.feature_importances_
    importances['random_forest'] = rf.feature_importances_
    criteria = criteria + ('random_forest',)
    idx = 1
    fig = plt.figure(figsize=(20, 10))
    labels = ['$x_{}$'.format(i) for i in range(n)]
    for crit in criteria:
        plt.subplot(2, 2, idx)
        plt.bar(numpy.arange(len(labels)),
                importances[crit],
                align='center',
                color='red')
        plt.xticks(numpy.arange(len(labels)), labels)
        plt.title(crit)
        plt.ylabel('importances')
        idx += 1
    title = ('$x_0,...,x_9 \\sim Unif(0, 1)$\n'
             '$y = 10 sin(\\pi x_{0}x_{1}) + 20(x_2 - 0.5)^2 '
             '+ 10x_3 + 5x_4 + Unif(0, 1)$')
    fig.suptitle(title, fontsize="x-large")
    plt.show()

**Total running time of the script:** (0 minutes 1.583 seconds)



.. container:: sphx-glr-download

    **Download Python source code:** :download:`plot_feature_importance.py <plot_feature_importance.py>`



.. container:: sphx-glr-download

    **Download IPython notebook:** :download:`plot_feature_importance.ipynb <plot_feature_importance.ipynb>`
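
When several criteria are requested, as in this example, ``feature_importances_``
is a dictionary mapping each criterion name to an array with one importance value
per input feature (this is how the plotting loop above indexes it). A minimal
sketch of how those scores could be inspected programmatically, assuming the
fitted ``model`` from the code above:

.. code-block:: python

    import numpy

    # `model` is assumed to be the Earth model fitted above with
    # feature_importance_type=('rss', 'gcv', 'nb_subsets').
    importances = model.feature_importances_  # dict: criterion -> array of length n

    # Rank the features by their 'gcv' importance, largest first.
    gcv_scores = importances['gcv']
    for i in numpy.argsort(gcv_scores)[::-1]:
        print('x{}: {:.2f}'.format(i, gcv_scores[i]))

The 'rss' and 'nb_subsets' scores can be read the same way, which makes it easy
to place them side by side with ``rf.feature_importances_``.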