1. What is Metric Learning?
Many approaches in machine learning require a measure of distance between data points. Traditionally, practitioners would choose a standard distance metric (Euclidean, City-Block, Cosine, etc.) using a priori knowledge of the domain. However, it is often difficult to design metrics that are well-suited to the particular data and task of interest.
Distance metric learning (or simply, metric learning) aims at automatically constructing task-specific distance metrics from (weakly) supervised data, in a machine learning manner. The learned distance metric can then be used to perform various tasks (e.g., k-NN classification, clustering, information retrieval).
1.1. Problem Setting
Metric learning problems fall into two main categories depending on the type of supervision available about the training data:
Supervised learning: the algorithm has access to a set of data points, each of them belonging to a class (label) as in a standard classification problem. Broadly speaking, the goal in this setting is to learn a distance metric that puts points with the same label close together while pushing away points with different labels.
Weakly supervised learning: the algorithm has access to a set of data points with supervision only at the tuple level (typically pairs, triplets, or quadruplets of data points). A classic example of such weaker supervision is a set of positive and negative pairs: in this case, the goal is to learn a distance metric that puts positive pairs close together and negative pairs far away.
Based on the above (weakly) supervised data, the metric learning problem is generally formulated as an optimization problem where one seeks to find the parameters of a distance function that optimize some objective function measuring the agreement with the training data.
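To make the two supervision settings concrete, here is a minimal sketch of what the input data looks like in each case, using one supervised learner (NCA) and one pair-based learner (MMC) from metric-learn. This assumes the estimator API of recent metric-learn releases; exact signatures may vary between versions.

```python
import numpy as np
from metric_learn import MMC, NCA

# Supervised setting: standard (X, y) format, as in classification.
X = np.array([[0., 0.], [0., 1.], [2., 0.], [2., 1.]])
y = np.array([0, 0, 1, 1])                 # one class label per point
NCA().fit(X, y)                            # pulls same-class points together

# Weakly supervised setting: supervision at the tuple level (here, pairs).
# pairs has shape (n_pairs, 2, n_features); labels are +1 for positive
# (similar) pairs and -1 for negative (dissimilar) pairs.
pairs = np.array([[[0., 0.], [0., 1.]],    # a positive pair
                  [[0., 0.], [2., 1.]]])   # a negative pair
y_pairs = np.array([1, -1])
MMC().fit(pairs, y_pairs)                  # brings positive pairs closer, pushes negative pairs apart
```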
1.2. Mahalanobis Distances
In the metric-learn package, all algorithms currently implemented learn so-called Mahalanobis distances. Given a real-valued parameter matrix \(L\) of shape (num_dims, n_features), where n_features is the number of features describing the data, the Mahalanobis distance associated with \(L\) is defined as follows:

\[D_L(x, x') = \sqrt{(Lx - Lx')^\top(Lx - Lx')}\]
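Evaluating this definition is straightforward; here is a small NumPy sketch with an arbitrary \(L\) (the matrix values are random, purely for illustration):

```python
import numpy as np

rng = np.random.RandomState(0)
L = rng.randn(2, 3)                  # parameter matrix, shape (num_dims, n_features)
x, x_prime = rng.randn(3), rng.randn(3)

# D_L(x, x') = sqrt((Lx - Lx')^T (Lx - Lx')), i.e. the Euclidean
# distance between the transformed points Lx and Lx'
d = np.linalg.norm(L @ x - L @ x_prime)
print(d)
```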
In other words, a Mahalanobis distance is a Euclidean distance after a linear transformation of the feature space defined by \(L\) (taking \(L\) to be the identity matrix recovers the standard Euclidean distance). Mahalanobis distance metric learning can thus be seen as learning a new embedding space of dimension num_dims. Note that when num_dims is smaller than n_features, this achieves dimensionality reduction.
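As a quick illustration of this dimensionality-reducing effect, the sketch below fits a supervised metric learner with a target dimension smaller than the number of features. It assumes the n_components parameter of recent metric-learn releases (called num_dims in older ones):

```python
from sklearn.datasets import load_iris
from metric_learn import NCA

X, y = load_iris(return_X_y=True)          # 150 points, 4 features

# Request a 2-dimensional embedding: the learned L has shape (2, 4)
nca = NCA(n_components=2)
X_embedded = nca.fit(X, y).transform(X)
print(X_embedded.shape)                    # (150, 2)
```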
Strictly speaking, Mahalanobis distances are “pseudo-metrics”: they satisfy three of the properties of a metric (non-negativity, symmetry, triangle inequality) but not necessarily the identity of indiscernibles.
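For instance, whenever \(L\) is rank-deficient, two distinct points can map to the same image and therefore be at distance zero, as this small sketch shows:

```python
import numpy as np

L = np.array([[1., 0.]])   # rank-deficient: projects onto the first coordinate
x, x_prime = np.array([0., 0.]), np.array([0., 5.])
print(np.linalg.norm(L @ x - L @ x_prime))  # 0.0 even though x != x_prime
```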
Note
Mahalanobis distances can also be parameterized by a positive semi-definite (PSD) matrix \(M\):

\[D_M(x, x') = \sqrt{(x - x')^\top M (x - x')}\]
Using the fact that a PSD matrix \(M\) can always be decomposed as \(M=L^\top L\) for some \(L\), one can show that both parameterizations are equivalent. In practice, an algorithm may thus solve the metric learning problem with respect to either \(M\) or \(L\).
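This equivalence is easy to verify numerically. The sketch below builds an arbitrary PSD matrix \(M\), recovers an \(L\) satisfying \(M = L^\top L\) from its eigendecomposition (which, unlike a Cholesky factorization, also works when \(M\) is singular), and checks that both parameterizations give the same distance:

```python
import numpy as np

rng = np.random.RandomState(42)
A = rng.randn(3, 3)
M = A.T @ A                                  # an arbitrary PSD matrix

# M = V diag(w) V^T  =>  L = diag(sqrt(w)) V^T satisfies M = L^T L
w, V = np.linalg.eigh(M)
L = np.diag(np.sqrt(np.clip(w, 0, None))) @ V.T

x, x_prime = rng.randn(3), rng.randn(3)
d_M = np.sqrt((x - x_prime) @ M @ (x - x_prime))
d_L = np.linalg.norm(L @ (x - x_prime))
assert np.isclose(d_M, d_L)                  # same distance either way
```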
1.3. Use-cases
There are many use-cases for metric learning. We list here a few popular examples (for code illustrating some of these use-cases, see the examples section of the documentation):
Nearest neighbors models: the learned metric can be used to improve nearest neighbors learning models for classification, regression, anomaly detection…
Clustering: metric learning provides a way to bias the clusters found by algorithms like K-Means towards the intended semantics.
Information retrieval: the learned metric can be used to retrieve the elements of a database that are semantically closest to a query element.
Dimensionality reduction: metric learning may be seen as a way to reduce the data dimension in a (weakly) supervised setting.
More generally, the learned transformation \(L\) can be used to project the data into a new embedding space before feeding it into another machine learning algorithm.
The API of metric-learn is compatible with scikit-learn, the leading library for machine learning in Python. This makes it easy to pipeline metric learners with other scikit-learn estimators to realize the above use-cases, to perform joint hyperparameter tuning, etc.
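For example, here is a minimal sketch of such a pipeline, using NCA as the metric learner (any other supervised learner from the package could be swapped in):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from metric_learn import NCA

X, y = load_iris(return_X_y=True)

# Learn a metric, then run k-NN in the induced embedding space
pipe = Pipeline([('metric', NCA(max_iter=50)),
                 ('knn', KNeighborsClassifier(n_neighbors=3))])
print(cross_val_score(pipe, X, y, cv=5).mean())
```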
1.4. Further reading
For more information about metric learning and its applications, one can refer to the following resources:
Tutorial: Similarity and Distance Metric Learning with Applications to Computer Vision (2015)
Surveys: A Survey on Metric Learning for Feature Vectors and Structured Data (2013), Metric Learning: A Survey (2012)
Book: Metric Learning (2015)