Timing comparison with scikit-learn for Lasso

Compare the time skglm and scikit-learn take to solve large-scale Lasso and Elastic Net problems on the news20.binary dataset.
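For reference, the quantity plotted below is the objective suboptimality. The objective is the one evaluated by `compute_obj` in the script, which follows scikit-learn's Elastic Net parametrization. Here ρ stands for `l1_ratio` (ρ = 1 gives the Lasso, ρ = 0.5 the Elastic Net run), n is the number of samples, and α_max is the smallest α for which w = 0 solves the Lasso problem:

```latex
P(w) = \frac{1}{2n}\lVert y - Xw \rVert_2^2
       + \alpha \rho \lVert w \rVert_1
       + \frac{\alpha (1 - \rho)}{2} \lVert w \rVert_2^2,
\qquad
\alpha_{\max} = \frac{\lVert X^\top y \rVert_\infty}{n},
\qquad
\alpha = \frac{\alpha_{\max}}{10}.
```

Each curve shows P(w_k) - P(w*), where w_k is the solution returned with a capped `max_iter` and w* is the skglm solution obtained with `max_iter = 10_000`. For the Elastic Net, w = 0 is optimal only when α ρ ≥ ‖Xᵀy‖_∞ / n, so the same α is also non-trivial there.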

[Figure: objective suboptimality versus time (s), scikit-learn vs skglm; panels: lasso, enet]

import time
import warnings
import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt
from libsvmdata import fetch_libsvm

from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import Lasso as Lasso_sklearn
from sklearn.linear_model import ElasticNet as Enet_sklearn

from skglm import Lasso, ElasticNet

warnings.filterwarnings('ignore', category=ConvergenceWarning)


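def compute_obj(X, y, w, alpha, l1_ratio=1):
    """Elastic Net objective in scikit-learn's parametrization.

    l1_ratio=1 gives the Lasso objective.
    """
    loss = norm(y - X @ w) ** 2 / (2 * len(y))
    penalty = (alpha * l1_ratio * np.sum(np.abs(w))
               + 0.5 * alpha * (1 - l1_ratio) * norm(w) ** 2)
    return loss + penalty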


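X, y = fetch_libsvm("news20.binary")
# alpha_max = ||X^T y||_inf / n is the smallest alpha for which w = 0
# solves the Lasso; use a tenth of it so that the solution is not trivial.
alpha = np.max(np.abs(X.T @ y)) / len(y) / 10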

dict_sklearn = {}
dict_sklearn["lasso"] = Lasso_sklearn(
    alpha=alpha, fit_intercept=False, tol=1e-12)

dict_sklearn["enet"] = Enet_sklearn(
    alpha=alpha, fit_intercept=False, tol=1e-12, l1_ratio=0.5)

dict_ours = {}
dict_ours["lasso"] = Lasso(
    alpha=alpha, fit_intercept=False, tol=1e-12)
dict_ours["enet"] = ElasticNet(
    alpha=alpha, fit_intercept=False, tol=1e-12, l1_ratio=0.5)

models = ["lasso", "enet"]

fig, axarr = plt.subplots(2, 1, constrained_layout=True)

for ax, model, l1_ratio in zip(axarr, models, [1, 0.5]):
    pobj_dict = {}
    pobj_dict["sklearn"] = list()
    pobj_dict["us"] = list()

    time_dict = {}
    time_dict["sklearn"] = list()
    time_dict["us"] = list()

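    # First fit: triggers the Numba compilation of skglm so that it is not
    # counted in the timings below; its long run (max_iter=10_000) also
    # provides the reference solution w_star.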
    dict_ours[model].max_iter = 10_000
    w_star = dict_ours[model].fit(X, y).coef_
    pobj_star = compute_obj(X, y, w_star, alpha, l1_ratio)
    # Time scikit-learn fits for increasing max_iter budgets between 1 and 50
    for n_iter_sklearn in np.unique(np.geomspace(1, 50, num=15).astype(int)):
        dict_sklearn[model].max_iter = n_iter_sklearn

        t_start = time.time()
        w_sklearn = dict_sklearn[model].fit(X, y).coef_
        time_dict["sklearn"].append(time.time() - t_start)
        pobj_dict["sklearn"].append(compute_obj(X, y, w_sklearn, alpha, l1_ratio))

    # Same measurement for skglm, with max_iter from 1 to 9
    for n_iter_us in range(1, 10):
        dict_ours[model].max_iter = n_iter_us
        t_start = time.time()
        w = dict_ours[model].fit(X, y).coef_
        time_dict["us"].append(time.time() - t_start)
        pobj_dict["us"].append(compute_obj(X, y, w, alpha, l1_ratio))

    # Convert the lists to arrays so that the subtraction is elementwise
    ax.semilogy(
        time_dict["sklearn"], np.array(pobj_dict["sklearn"]) - pobj_star,
        label='sklearn')
    ax.semilogy(
        time_dict["us"], np.array(pobj_dict["us"]) - pobj_star, label='skglm')

    ax.set_ylim((1e-10, 1))
    ax.set_title(model)
    ax.legend()
    ax.set_ylabel("Objective suboptimality")

axarr[1].set_xlabel("Time (s)")
plt.show(block=False)
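The measurement protocol above (cap `max_iter`, time the fit, record the objective gap to a reference solution) can be reproduced on a toy problem without either library. The following is a minimal pure-NumPy sketch under stated assumptions: `cd_enet` and `suboptimality_curve` are hypothetical helpers, the solver is plain cyclic coordinate descent (not the implementation used by skglm or scikit-learn), and the data are synthetic rather than news20.binary. It uses `time.perf_counter()` instead of `time.time()` because it is a monotonic clock intended for measuring short intervals. Nothing here is JIT-compiled, so no warm-up fit is needed before timing.

```python
import time

import numpy as np


def objective(X, y, w, alpha, l1_ratio=1.0):
    """Elastic Net objective; same formula as compute_obj in the script."""
    n = len(y)
    loss = np.sum((y - X @ w) ** 2) / (2 * n)
    penalty = (alpha * l1_ratio * np.sum(np.abs(w))
               + 0.5 * alpha * (1 - l1_ratio) * np.sum(w ** 2))
    return loss + penalty


def cd_enet(X, y, alpha, l1_ratio=1.0, max_iter=100):
    """Hypothetical stand-in solver: plain cyclic coordinate descent.

    One iteration is one full pass over the features, so max_iter caps
    the amount of work done, as in the benchmark loops.
    """
    n, p = X.shape
    w = np.zeros(p)
    r = y.astype(float)  # residual y - X @ w (a copy, since w = 0 here)
    lips = np.sum(X ** 2, axis=0) / n  # curvature ||X_j||^2 / n per feature
    for _ in range(max_iter):
        for j in range(p):
            denom = lips[j] + alpha * (1 - l1_ratio)
            if denom == 0.0:  # all-zero feature under a pure L1 penalty
                continue
            old = w[j]
            # Exact minimization of the objective along coordinate j
            c = X[:, j] @ r / n + lips[j] * old
            w[j] = np.sign(c) * max(abs(c) - alpha * l1_ratio, 0.0) / denom
            if w[j] != old:
                r -= X[:, j] * (w[j] - old)  # keep the residual in sync
    return w


def suboptimality_curve(X, y, alpha, l1_ratio, budgets, ref_iter=500):
    """Time the solver for each budget and record P(w_k) - P(w_star)."""
    w_star = cd_enet(X, y, alpha, l1_ratio, max_iter=ref_iter)
    p_star = objective(X, y, w_star, alpha, l1_ratio)
    times, gaps = [], []
    for k in budgets:
        t0 = time.perf_counter()
        w = cd_enet(X, y, alpha, l1_ratio, max_iter=k)
        times.append(time.perf_counter() - t0)
        gaps.append(objective(X, y, w, alpha, l1_ratio) - p_star)
    return np.array(times), np.array(gaps)


# Synthetic data (an assumption; the script uses news20.binary)
rng = np.random.default_rng(0)
X = rng.standard_normal((100, 30))
y = rng.standard_normal(100)
alpha_max = np.max(np.abs(X.T @ y)) / len(y)  # same formula as the script
alpha = alpha_max / 10

budgets = np.unique(np.geomspace(1, 50, num=8).astype(int))
curves = {name: suboptimality_curve(X, y, alpha, l1_ratio, budgets)
          for name, l1_ratio in [("lasso", 1.0), ("enet", 0.5)]}
```

Plotting each `(times, gaps)` pair with `ax.semilogy` gives the same layout as the figure. Because every coordinate update minimizes the objective exactly, the gaps are nonincreasing in the iteration budget, and calling `cd_enet` with any α at or slightly above `alpha_max` (for `l1_ratio=1.0`) returns an all-zero vector.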

Total running time of the script: (2 minutes 17.597 seconds)

Gallery generated by Sphinx-Gallery