How to add a custom datafit#

Motivated by generalized linear models but not limited to them, skglm solves problems of the form

`\hat{\beta} \in \arg\min_{\beta \in \mathbb{R}^p} F(X\beta) + \Omega(\beta) := \sum_{i=1}^n f_i([X\beta]_i) + \sum_{j=1}^p \Omega_j(\beta_j) \ .`

Here, `X \in \mathbb{R}^{n \times p}` denotes the design matrix with `n` samples and `p` features, and `\beta \in \mathbb{R}^p` is the coefficient vector.

skglm can solve any problem of this form with an arbitrary smooth datafit `F` and an arbitrary penalty `\Omega` whose proximal operator can be evaluated explicitly, by defining two classes: a Penalty and a Datafit.

They can then be passed to a GeneralizedLinearEstimator.

from skglm import GeneralizedLinearEstimator

clf = GeneralizedLinearEstimator(
    MyDatafit(),
    MyPenalty(),
)

A Datafit is a jitclass that must inherit from the BaseDatafit class:

class BaseDatafit:
    """Base class for datafits."""

    def get_spec(self):
        """Specify the numba types of the class attributes.

        Returns
        -------
        spec: Tuple of (attribute_name, dtype)
            spec to be passed to Numba jitclass to compile the class.
        """

    def params_to_dict(self):
        """Get the parameters to initialize an instance of the class.

        Returns
        -------
        dict_of_params : dict
            The parameters to instantiate an object of the class.
        """

    def initialize(self, X, y):
        """Pre-computations before fitting on X and y.

        Parameters
        ----------
        X : array, shape (n_samples, n_features)
            Design matrix.

        y : array, shape (n_samples,)
            Target vector.
        """

    def initialize_sparse(self, X_data, X_indptr, X_indices, y):
        """Pre-computations before fitting on X and y when X is a sparse matrix.

        Parameters
        ----------
        X_data : array, shape (n_elements,)
            `data` attribute of the sparse CSC matrix X.

        X_indptr : array, shape (n_features + 1,)
            `indptr` attribute of the sparse CSC matrix X.

        X_indices : array, shape (n_elements,)
            `indices` attribute of the sparse CSC matrix X.

        y : array, shape (n_samples,)
            Target vector.
        """

    def value(self, y, w, Xw):
        """Value of datafit at vector w.

        Parameters
        ----------
        y : array_like, shape (n_samples,)
            Target vector.

        w : array_like, shape (n_features,)
            Coefficient vector.

        Xw : array_like, shape (n_samples,)
            Model fit.

        Returns
        -------
        value : float
            The datafit value at vector w.
        """

To define a custom datafit, you need to inherit from the BaseDatafit class and implement the methods required by the targeted solver. These methods are listed in the solver's documentation. Optionally, implementing the variants suffixed with _sparse adds support for sparse datasets (CSC matrices).

This tutorial shows how to implement the Poisson datafit so that it can be fitted with the ProxNewton solver.

A case in point: defining the Poisson datafit#

First, this requires deriving some quantities used by the solvers, such as the gradient and the Hessian matrix of the datafit. With `y \in \mathbb{R}^n` the target vector, the Poisson datafit reads

`f(X\beta) = \frac{1}{n}\sum_{i=1}^n \left( \exp([X\beta]_i) - y_i [X\beta]_i \right) \ .`

Let’s define some useful quantities to simplify our computations. For `z \in \mathbb{R}^n` and `\beta \in \mathbb{R}^p`,

`f(z) = \sum_{i=1}^n f_i(z_i) \qquad F(\beta) = f(X\beta) \ .`

Computing the gradient of `F` and its Hessian matrix yields

`\nabla F(\beta) = X^{\top} \underbrace{\nabla f(X\beta)}_{\text{raw grad}} \qquad \nabla^2 F(\beta) = X^{\top} \underbrace{\nabla^2 f(X\beta)}_{\text{raw hessian}} X \ .`

Besides, it directly follows that

`\nabla f(z) = (f_i'(z_i))_{1 \leq i \leq n} \qquad \nabla^2 f(z) = \text{diag}(f_i''(z_i))_{1 \leq i \leq n} \ .`

We can now apply these definitions to the Poisson datafit:

`f_i(z_i) = \frac{1}{n} \left(\exp(z_i) - y_iz_i\right) \ .`

Therefore,

`f_i'(z_i) = \frac{1}{n}(\exp(z_i) - y_i) \qquad f_i''(z_i) = \frac{1}{n}\exp(z_i) \ .`

Computing raw_grad and raw_hessian for the Poisson datafit yields

`\nabla f(X\beta) = \frac{1}{n}\left(\exp([X\beta]_i) - y_i\right)_{1 \leq i \leq n} \qquad \nabla^2 f(X\beta) = \frac{1}{n}\text{diag}\left(\exp([X\beta]_i)\right)_{1 \leq i \leq n} \ .`
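
Before implementing these formulas, it is worth checking them numerically. The following standalone sketch (not part of skglm; the helper names are ours) compares the closed-form raw gradient and raw Hessian against central finite differences:

import numpy as np

rng = np.random.default_rng(0)
n = 5
z = rng.standard_normal(n)
y = rng.poisson(lam=2., size=n) + 1.  # strictly positive target

def f(z):
    # Poisson datafit evaluated at z = X @ beta
    return np.sum(np.exp(z) - y * z) / n

def raw_grad(z):
    # closed-form gradient derived above
    return (np.exp(z) - y) / n

raw_hess = np.exp(z) / n  # diagonal of the closed-form Hessian

eps = 1e-6
for i in range(n):
    e_i = np.zeros(n)
    e_i[i] = 1.
    fd_grad = (f(z + eps * e_i) - f(z - eps * e_i)) / (2 * eps)
    fd_hess = (raw_grad(z + eps * e_i)[i] - raw_grad(z - eps * e_i)[i]) / (2 * eps)
    assert np.isclose(raw_grad(z)[i], fd_grad, atol=1e-8)
    assert np.isclose(raw_hess[i], fd_hess, atol=1e-8)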

Both raw_grad and raw_hessian are methods used by the ProxNewton solver, but other solvers require different methods to be implemented. For instance, AndersonCD uses the gradient_scalar method, which computes the derivative of the datafit with respect to the `j`-th coordinate of `\beta`.

For the Poisson datafit, this yields

`\frac{\partial F(\beta)}{\partial \beta_j} = \frac{1}{n} \sum_{i=1}^n X_{i,j} \left( \exp([X\beta]_i) - y_i \right) \ .`
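
Note that this is simply the `j`-th coordinate of `X^{\top} \nabla f(X\beta)`, that is, the dot product of the `j`-th column of `X` with the raw gradient. A quick standalone NumPy check of this identity (illustrative, with simulated data):

import numpy as np

rng = np.random.default_rng(0)
n, p, j = 50, 10, 3
X = rng.standard_normal((n, p))
y = rng.poisson(lam=2., size=n) + 1.
beta = rng.standard_normal(p)
Xw = X @ beta

grad_j = X[:, j] @ (np.exp(Xw) - y) / n    # gradient_scalar, as derived
full_grad = X.T @ ((np.exp(Xw) - y) / n)   # X^T @ raw_grad
assert np.isclose(grad_j, full_grad[j])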

When implementing these quantities in the Poisson datafit class, this gives:

import numpy as np

from skglm.datafits import BaseDatafit


class Poisson(BaseDatafit):
    r"""Poisson datafit.

    The datafit reads:

    .. math:: 1 / n_\text{samples} \sum_{i=1}^{n_\text{samples}} \left( \exp((Xw)_i) - y_i (Xw)_i \right)

    Notes
    -----
    The class is jit compiled at fit time using the Numba compiler.
    This allows for faster computations.
    """

    def __init__(self):
        pass

    def get_spec(self):
        pass

    def params_to_dict(self):
        return dict()

    def initialize(self, X, y):
        if np.any(y <= 0):
            raise ValueError(
                "Target vector `y` should only take positive values " +
                "when fitting a Poisson model.")

    def initialize_sparse(self, X_data, X_indptr, X_indices, y):
        if np.any(y <= 0):
            raise ValueError(
                "Target vector `y` should only take positive values " +
                "when fitting a Poisson model.")

    def raw_grad(self, y, Xw):
        """Compute gradient of datafit w.r.t ``Xw``."""
        return (np.exp(Xw) - y) / len(y)

    def raw_hessian(self, y, Xw):
        """Compute Hessian of datafit w.r.t ``Xw``."""
        return np.exp(Xw) / len(y)

    def value(self, y, w, Xw):
        return np.sum(np.exp(Xw) - y * Xw) / len(y)

    def gradient_scalar(self, X, y, w, Xw, j):
        return (X[:, j] @ (np.exp(Xw) - y)) / len(y)

    def full_grad_sparse(self, X_data, X_indptr, X_indices, y, Xw):
        n_features = X_indptr.shape[0] - 1
        grad = np.zeros(n_features, dtype=X_data.dtype)
        for j in range(n_features):
            grad[j] = 0.
            for i in range(X_indptr[j], X_indptr[j + 1]):
                grad[j] += X_data[i] * (
                    np.exp(Xw[X_indices[i]]) - y[X_indices[i]]) / len(y)
        return grad

    def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j):
        grad = 0.
        for i in range(X_indptr[j], X_indptr[j + 1]):
            idx_i = X_indices[i]
            grad += X_data[i] * (np.exp(Xw[idx_i]) - y[idx_i])
        return grad / len(y)

    def intercept_update_step(self, y, Xw):
        return np.sum(self.raw_grad(y, Xw))
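
With the Poisson datafit defined, it can be passed to a GeneralizedLinearEstimator together with a penalty and the ProxNewton solver. Below is a minimal usage sketch; the simulated data and the choice of an L1 penalty are illustrative, not mandated by the tutorial.

import numpy as np
from skglm import GeneralizedLinearEstimator
from skglm.penalties import L1
from skglm.solvers import ProxNewton

# simulate a design matrix and a strictly positive target,
# as required by Poisson.initialize
rng = np.random.default_rng(0)
X = rng.standard_normal((100, 20))
y = rng.poisson(lam=np.exp(X @ rng.standard_normal(20) / 10)) + 1.

clf = GeneralizedLinearEstimator(
    Poisson(),           # the custom datafit defined above
    L1(alpha=0.01),      # any penalty with an explicit prox works
    solver=ProxNewton(),
)
clf.fit(X, y)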

Note that, apart from validating the target, we have not pre-computed any quantities in the initialize method. Usually, it serves to compute datafit attributes specific to a dataset X, y for computational efficiency, for example the pre-computation of X.T @ y in the Quadratic datafit.
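
For illustration, such a pre-computing initialize could look like the following sketch (Quadratic-style; the attribute name Xty is illustrative):

    def initialize(self, X, y):
        # cache X^T y once so that later gradient evaluations can reuse it
        self.Xty = X.T @ y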