Group Logistic regression in Python
Scikit-learn does not ship a Group Logistic regression estimator. This example shows how to implement one with skglm.
# Author: Mathurin Massias
import numpy as np
from skglm import GeneralizedLinearEstimator
from skglm.datafits import LogisticGroup
from skglm.penalties import WeightedGroupL2
from skglm.solvers import GroupProxNewton
from skglm.utils.data import make_correlated_data, grp_converter
n_features = 30
X, y, _ = make_correlated_data(
    n_samples=10, n_features=n_features, random_state=0)
y = np.sign(y)  # binarize the regression target into -1/+1 labels
Classifier creation: combination of penalty, datafit and solver.
grp_size = 3  # groups are made of 3 consecutive features
n_groups = n_features // grp_size
grp_indices, grp_ptr = grp_converter(grp_size, n_features=n_features)
alpha = 0.01
weights = np.ones(n_groups)
penalty = WeightedGroupL2(alpha, weights, grp_ptr, grp_indices)
datafit = LogisticGroup(grp_ptr, grp_indices)
solver = GroupProxNewton(verbose=2)
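As a side note, grp_indices and grp_ptr encode the groups in a compressed, CSR-like layout: grp_indices lists the feature indices group by group, and the slice grp_ptr[g]:grp_ptr[g + 1] selects the features of group g. The snippet below is an illustrative check of that layout, not part of the original script:

# inspect the feature indices of the first two groups
for g in range(2):
    print("group", g, "->", grp_indices[grp_ptr[g]:grp_ptr[g + 1]])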
Train the model
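The fitting step below is a minimal sketch that assembles the three components into a GeneralizedLinearEstimator; it assumes the estimator accepts the datafit, penalty and solver in that order:

clf = GeneralizedLinearEstimator(datafit, penalty, solver)
clf.fit(X, y)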
Iteration 1: 0.6931471806, stopping crit: 5.93e-01
PN iteration 1: 0.1818203070, stopping crit in: 1.38e-01
Early exit
Iteration 2: 0.1818203070, stopping crit: 1.38e-01
PN iteration 1: 0.1130022685, stopping crit in: 4.59e-02
PN iteration 2: 0.0920956247, stopping crit in: 1.58e-02
Early exit
Iteration 3: 0.0920956247, stopping crit: 1.58e-02
PN iteration 1: 0.0846061796, stopping crit in: 6.29e-03
PN iteration 2: 0.0799359080, stopping crit in: 1.83e-03
Early exit
Iteration 4: 0.0799359080, stopping crit: 1.83e-03
PN iteration 1: 0.0794149142, stopping crit in: 1.96e-04
Early exit
Iteration 5: 0.0794149142, stopping crit: 1.96e-04
PN iteration 1: 0.0794078275, stopping crit in: 3.61e-05
Early exit
Iteration 6: 0.0794078275, stopping crit: 3.61e-05
Check that groups are either all 0 or all non-zero.
print(clf.coef_.reshape(-1, grp_size))
[[ 0. 0. 0. ]
[-1.10333799 -0.8924748 -0.77600881]
[-1.68695901 -1.26460422 -1.06153908]
[-0.13524978 0.01904796 -0.02993834]
[ 0. 0. 0. ]
[-0.02193371 -0.02044038 -0.00869443]
[-0.90363392 -0.96337048 -0.49341585]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0.57737623 -0.25394772 0.15405176]]
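To verify the group-sparsity pattern programmatically, one can count the non-zeros per group and check that each group is either entirely zero or entirely non-zero. This is an illustrative sketch using only NumPy, not part of the original script:

coefs = clf.coef_.reshape(-1, grp_size)
for g, group_coef in enumerate(coefs):
    n_nonzero = np.count_nonzero(group_coef)
    # a group should never be only partially zero
    assert n_nonzero in (0, grp_size)
    print(f"group {g}: {'zero' if n_nonzero == 0 else 'non-zero'}")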