Show U-curve of regularization

Illustrate the sweet spot of regularization: not too much, not too little. We demonstrate it with the Lasso estimator on the rcv1.binary dataset.

import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt
from libsvmdata import fetch_libsvm

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

from skglm import Lasso

First, we load the dataset and keep the first 2000 features. We also retain only 2000 samples in the training set.

X, y = fetch_libsvm("rcv1.binary")

X = X[:, :2000]
X_train, X_test, y_train, y_test = train_test_split(X, y)
X_train, y_train = X_train[:2000], y_train[:2000]
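
Note that train_test_split shuffles the rows by default, so without a fixed seed the split, and therefore the curve, can change from run to run. A minimal sketch of a reproducible variant on synthetic data follows; the names X_demo and y_demo and the sizes are illustrative assumptions, not part of the original example. Passing train_size also makes the test set the full complement of the training set, which differs from the default 75/25 split used above.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical stand-in data; the real example uses rcv1.binary.
rng = np.random.default_rng(0)
X_demo = rng.standard_normal((100, 5))
y_demo = rng.standard_normal(100)

# Fixing random_state makes the shuffled split identical across calls.
# train_size=20 keeps exactly 20 training samples in one step.
split_a = train_test_split(X_demo, y_demo, train_size=20, random_state=0)
split_b = train_test_split(X_demo, y_demo, train_size=20, random_state=0)

X_train_a, X_test_a, y_train_a, y_test_a = split_a
print(X_train_a.shape, X_test_a.shape)
```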

Next, we define the regularization path. For the Lasso, it is well known that there is a value alpha_max at and above which the optimal solution is the zero vector.

alpha_max = norm(X_train.T @ y_train, ord=np.inf) / len(y_train)
alphas = alpha_max * np.geomspace(1, 1e-4)
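
To see where this formula comes from, assume the objective is 1/(2n) * ||y - Xw||^2 + alpha * ||w||_1, which matches the division by len(y_train) above. The zero vector is then optimal exactly when alpha >= ||X^T y||_inf / n. The sketch below checks this numerically on synthetic data with a plain proximal gradient (ISTA) loop. This is a simple stand-in solver written for illustration, not the solver skglm uses, and lasso_ista, X_demo and y_demo are hypothetical names.

```python
import numpy as np


def lasso_ista(X, y, alpha, n_iter=500):
    """Proximal gradient for 1/(2n) * ||y - Xw||^2 + alpha * ||w||_1."""
    n_samples, n_features = X.shape
    w = np.zeros(n_features)
    # Step size 1/L, where L = ||X||_2^2 / n is the gradient's Lipschitz constant.
    step = n_samples / np.linalg.norm(X, ord=2) ** 2
    for _ in range(n_iter):
        grad = X.T @ (X @ w - y) / n_samples
        z = w - step * grad
        # Soft-thresholding: the proximal operator of step * alpha * ||.||_1.
        w = np.sign(z) * np.maximum(np.abs(z) - step * alpha, 0.0)
    return w


# Hypothetical synthetic data, only for the check.
rng = np.random.default_rng(0)
X_demo = rng.standard_normal((50, 20))
y_demo = rng.standard_normal(50)
alpha_max_demo = np.linalg.norm(X_demo.T @ y_demo, ord=np.inf) / len(y_demo)

w_at_max = lasso_ista(X_demo, y_demo, alpha_max_demo)
w_below = lasso_ista(X_demo, y_demo, 0.5 * alpha_max_demo)

print("nonzeros at alpha_max:      ", np.count_nonzero(w_at_max))
print("nonzeros at 0.5 * alpha_max:", np.count_nonzero(w_below))
```

Starting the path at alpha_max therefore wastes no grid points on values of alpha where the model is identically zero.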

Let’s train the estimator along the regularization path and then compute the MSE on train and test data.

mse_train = []
mse_test = []

clf = Lasso(fit_intercept=False, tol=1e-8, warm_start=True)
for alpha in alphas:
    clf.alpha = alpha
    clf.fit(X_train, y_train)

    mse_train.append(mean_squared_error(y_train, clf.predict(X_train)))
    mse_test.append(mean_squared_error(y_test, clf.predict(X_test)))

Finally, we can plot the train and test MSE. Notice the “sweet spot” at around 1e-4, which sits at the boundary between underfitting (regularization too strong) and overfitting (regularization too weak).

plt.close('all')
plt.semilogx(alphas, mse_train, label='train MSE')
plt.semilogx(alphas, mse_test, label='test MSE')
plt.legend()
plt.title("Mean squared error")
plt.xlabel(r"Lasso regularization strength $\lambda$")
plt.show(block=False)
[Figure: Mean squared error (train and test MSE versus Lasso regularization strength)]
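
Rather than reading the sweet spot off the plot, one can locate it as the grid point with the lowest test MSE. Below is a minimal sketch on a synthetic U-shaped curve; best_alpha, demo_alphas and demo_errors are hypothetical names, and with the script above one would pass alphas and mse_test instead.

```python
import numpy as np


def best_alpha(alphas, test_errors):
    """Return the (alpha, error) pair that minimizes the test error on the grid."""
    alphas = np.asarray(alphas, dtype=float)
    test_errors = np.asarray(test_errors, dtype=float)
    best_idx = int(np.argmin(test_errors))
    return alphas[best_idx], test_errors[best_idx]


# Synthetic U-shaped curve in log10(alpha), with its minimum near alpha = 1e-2.
demo_alphas = np.geomspace(1, 1e-4, num=50)
demo_errors = (np.log10(demo_alphas) + 2.0) ** 2 + 0.5

alpha_star, err_star = best_alpha(demo_alphas, demo_errors)
print(f"best alpha on the grid: {alpha_star:.2e} (test error {err_star:.3f})")
```

The value found this way is only as precise as the grid, and picking alpha on the test set gives an optimistic error estimate; a separate validation set or cross-validation is the usual remedy.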

Total running time of the script: (0 minutes 33.175 seconds)
