Group Logistic regression in Python

Scikit-learn does not provide a group logistic regression estimator. This example shows how to build one with skglm by combining a datafit, a penalty and a solver.
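For reference, the problem solved here can be written as a mean logistic loss plus a weighted sum of group L2 norms. The sketch below evaluates that objective with plain NumPy. The exact scaling used by skglm's `LogisticGroup` and `WeightedGroupL2` is an assumption on our part (loss averaged over samples, penalty equal to `alpha` times the weighted sum of group norms), and `group_logistic_objective` is a hypothetical helper, not part of skglm.

```python
import numpy as np


# Hypothetical helper (not part of skglm). Assumed objective:
#   mean_i log(1 + exp(-y_i * x_i @ w)) + alpha * sum_g weights[g] * ||w_g||_2
def group_logistic_objective(w, X, y, alpha, weights, grp_ptr, grp_indices):
    """Evaluate the group logistic objective, assuming labels y in {-1, 1}.

    Group g contains the features grp_indices[grp_ptr[g]:grp_ptr[g + 1]].
    """
    margins = y * (X @ w)
    # np.logaddexp(0, -m) computes log(1 + exp(-m)) without overflow
    datafit = np.mean(np.logaddexp(0.0, -margins))
    penalty = 0.0
    for g in range(len(grp_ptr) - 1):
        idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]]
        penalty += weights[g] * np.linalg.norm(w[idx])
    return datafit + alpha * penalty


# Toy data with the same shapes as the example: 10 samples, 10 groups of 3
rng = np.random.default_rng(0)
X = rng.standard_normal((10, 30))
y = np.where(rng.standard_normal(10) >= 0, 1.0, -1.0)
grp_ptr = np.arange(0, 31, 3)
grp_indices = np.arange(30)

# At w = 0 every margin and every group norm is 0, so the value is log(2)
value = group_logistic_objective(np.zeros(30), X, y, 0.01, np.ones(10),
                                 grp_ptr, grp_indices)
print(f"{value:.10f}")  # 0.6931471806
```

The value at zero, log 2, is the same number as the first objective value in the solver log of this example, which is consistent with the solver starting from all-zero coefficients and with the loss being averaged over samples.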

# Author: Mathurin Massias

import numpy as np

from skglm import GeneralizedLinearEstimator
from skglm.datafits import LogisticGroup
from skglm.penalties import WeightedGroupL2
from skglm.solvers import GroupProxNewton
from skglm.utils.data import make_correlated_data, grp_converter

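n_features = 30
# 10 samples and 30 features; the true coefficients (third output) are discarded
X, y, _ = make_correlated_data(
    n_samples=10, n_features=n_features, random_state=0)
y = np.sign(y)  # turn the continuous target into class labels in {-1, 1}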
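One caveat on the label conversion: `np.sign` maps an exact zero to 0, which is not a valid class for the logistic loss. With continuous random targets this is very unlikely, but a defensive alternative (our suggestion, not what the example does) is `np.where`, which always returns -1 or 1.

```python
import numpy as np

# A continuous target that happens to contain an exact zero
y_cont = np.array([-1.3, 0.0, 2.7, -0.2])

print(np.sign(y_cont))  # [-1.  0.  1. -1.]

# Map non-negative values to +1 and negative values to -1
y_labels = np.where(y_cont >= 0, 1.0, -1.0)
print(y_labels)  # [-1.  1.  1. -1.]
```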

Create the classifier by combining a penalty, a datafit and a solver.

grp_size = 3  # each group is made of 3 consecutive features
n_groups = n_features // grp_size
# group g holds the features grp_indices[grp_ptr[g]:grp_ptr[g + 1]]
grp_indices, grp_ptr = grp_converter(grp_size, n_features=n_features)
alpha = 0.01  # regularization strength
weights = np.ones(n_groups)  # same penalty weight for every group
penalty = WeightedGroupL2(alpha, weights, grp_ptr, grp_indices)
datafit = LogisticGroup(grp_ptr, grp_indices)
solver = GroupProxNewton(verbose=2)  # verbose=2 prints the progress log
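The penalty and the datafit both receive the groups as a pair of arrays, in a layout similar to a CSR sparse matrix: the features of group `g` are `grp_indices[grp_ptr[g]:grp_ptr[g + 1]]`. The sketch below builds that pair for consecutive, equal-size groups. `contiguous_groups` is a hypothetical helper, and we assume (without showing skglm's source) that `grp_converter` returns equivalent arrays when given an integer group size.

```python
import numpy as np


# Hypothetical helper (not part of skglm): encode consecutive, equal-size groups
def contiguous_groups(grp_size, n_features):
    """Return (grp_indices, grp_ptr) such that group g holds the features
    grp_indices[grp_ptr[g]:grp_ptr[g + 1]]."""
    if n_features % grp_size != 0:
        raise ValueError("n_features must be a multiple of grp_size")
    n_groups = n_features // grp_size
    grp_indices = np.arange(n_features, dtype=np.int32)  # features in order
    grp_ptr = grp_size * np.arange(n_groups + 1, dtype=np.int32)  # group boundaries
    return grp_indices, grp_ptr


grp_indices, grp_ptr = contiguous_groups(3, 30)
print(grp_ptr)  # [ 0  3  6  9 12 15 18 21 24 27 30]

g = 4
print(grp_indices[grp_ptr[g]:grp_ptr[g + 1]])  # [12 13 14]
```

Because group membership is read through `grp_indices`, non-contiguous groups only require a different ordering of `grp_indices` and matching boundaries in `grp_ptr`.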

Train the model:

clf = GeneralizedLinearEstimator(datafit=datafit, penalty=penalty, solver=solver)
clf.fit(X, y)

Iteration 1: 0.6931471806, stopping crit: 5.93e-01
PN iteration 1: 0.1818203070, stopping crit in: 1.38e-01
Early exit
Iteration 2: 0.1818203070, stopping crit: 1.38e-01
PN iteration 1: 0.1130022685, stopping crit in: 4.59e-02
PN iteration 2: 0.0920956247, stopping crit in: 1.58e-02
Early exit
Iteration 3: 0.0920956247, stopping crit: 1.58e-02
PN iteration 1: 0.0846061796, stopping crit in: 6.29e-03
PN iteration 2: 0.0799359080, stopping crit in: 1.83e-03
Early exit
Iteration 4: 0.0799359080, stopping crit: 1.83e-03
PN iteration 1: 0.0794149142, stopping crit in: 1.96e-04
Early exit
Iteration 5: 0.0794149142, stopping crit: 1.96e-04
PN iteration 1: 0.0794078275, stopping crit in: 3.61e-05
Early exit
Iteration 6: 0.0794078275, stopping crit: 3.61e-05
GeneralizedLinearEstimator(datafit=LogisticGroup, penalty=WeightedGroupL2, alpha=0.01)
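The training log shows outer iterations and inner "PN" (proximal Newton) iterations of `GroupProxNewton`, whose internals are not shown in this example. As a simpler stand-in, here is a plain proximal gradient method (ISTA) written with NumPy only. This is a swapped-in technique, not skglm's solver, and it assumes the objective is the mean logistic loss plus `alpha` times the weighted sum of group L2 norms. Its key step is block soft-thresholding: a group whose gradient-step update has norm at most the threshold is set to zero as a whole, which is why fitted groups end up either entirely zero or entirely nonzero. `block_soft_threshold` and `group_logreg_ista` are hypothetical names.

```python
import numpy as np


# Hypothetical helper (not part of skglm): prox of tau * ||.||_2
def block_soft_threshold(v, tau):
    """Shrink v toward 0, or return exactly 0 if ||v||_2 <= tau."""
    norm = np.linalg.norm(v)
    if norm <= tau:
        return np.zeros_like(v)
    return (1.0 - tau / norm) * v


# Hypothetical solver (plain ISTA, not GroupProxNewton). Assumed objective:
#   mean(log(1 + exp(-y * Xw))) + alpha * sum_g weights[g] * ||w_g||_2
def group_logreg_ista(X, y, alpha, weights, grp_ptr, grp_indices, n_iter=2000):
    """Proximal gradient for group logistic regression, assuming y in {-1, 1}."""
    n_samples, n_features = X.shape
    # The gradient of the mean logistic loss is Lipschitz with constant ||X||_2^2 / (4 n)
    step = 4.0 * n_samples / np.linalg.norm(X, ord=2) ** 2
    w = np.zeros(n_features)
    for _ in range(n_iter):
        margins = y * (X @ w)
        # 1 / (1 + exp(m)), computed stably as exp(-log(1 + exp(m)))
        sig = np.exp(-np.logaddexp(0.0, margins))
        grad = X.T @ (-y * sig) / n_samples
        z = w - step * grad  # gradient step on the datafit
        for g in range(len(grp_ptr) - 1):  # prox step, one group at a time
            idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]]
            w[idx] = block_soft_threshold(z[idx], step * alpha * weights[g])
    return w


# Toy data with the same shapes as the example: 10 samples, 10 groups of 3
rng = np.random.default_rng(0)
X = rng.standard_normal((10, 30))
y = np.where(X[:, :3].sum(axis=1) >= 0, 1.0, -1.0)
grp_ptr = np.arange(0, 31, 3)
grp_indices = np.arange(30)

w = group_logreg_ista(X, y, alpha=0.01, weights=np.ones(10),
                      grp_ptr=grp_ptr, grp_indices=grp_indices)
print(w.reshape(-1, 3))  # one group per row
```

Proximal gradient takes many cheap first-order steps, whereas proximal Newton methods use second-order information to take fewer, more expensive ones; that trade-off is the usual reason to prefer a Newton-type solver for logistic losses.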


Check that each group of coefficients is either entirely zero or entirely nonzero:

print(clf.coef_.reshape(-1, grp_size))
[[ 0.          0.          0.        ]
 [-1.10333799 -0.8924748  -0.77600881]
 [-1.68695901 -1.26460422 -1.06153908]
 [-0.13524978  0.01904796 -0.02993834]
 [ 0.          0.          0.        ]
 [-0.02193371 -0.02044038 -0.00869443]
 [-0.90363392 -0.96337048 -0.49341585]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 [ 0.57737623 -0.25394772  0.15405176]]
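The visual check above can also be automated. The sketch below assumes, as in this example, that groups are made of `grp_size` consecutive features, so reshaping the coefficient vector puts one group per row. `groups_all_or_nothing` is a hypothetical helper, not an skglm function. Here it is run on the coefficients printed above; in the example itself one would pass `clf.coef_`.

```python
import numpy as np


# Hypothetical helper (not part of skglm); assumes consecutive, equal-size groups
def groups_all_or_nothing(coef, grp_size):
    """Return True if every group of grp_size consecutive coefficients is
    either entirely zero or entirely nonzero."""
    blocks = np.asarray(coef).reshape(-1, grp_size)
    n_nonzero = np.count_nonzero(blocks, axis=1)  # nonzeros per group
    return bool(np.all((n_nonzero == 0) | (n_nonzero == grp_size)))


# Coefficients as printed by the example, one group per row
coef = np.array([
    [0.0, 0.0, 0.0],
    [-1.10333799, -0.8924748, -0.77600881],
    [-1.68695901, -1.26460422, -1.06153908],
    [-0.13524978, 0.01904796, -0.02993834],
    [0.0, 0.0, 0.0],
    [-0.02193371, -0.02044038, -0.00869443],
    [-0.90363392, -0.96337048, -0.49341585],
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
    [0.57737623, -0.25394772, 0.15405176],
]).ravel()

print(groups_all_or_nothing(coef, grp_size=3))  # True
# Indices of the groups that were zeroed out entirely
print(np.flatnonzero(np.all(coef.reshape(-1, 3) == 0, axis=1)))  # [0 4 7 8]
```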

Total running time of the script: (0 minutes 13.846 seconds)
