Sandwich demo

Sandwich demo based on code from http://nbviewer.ipython.org/6576096

Note

In order to show the charts of the examples, you need a graphical matplotlib backend installed. For instance, use pip install pyqt5 to get the Qt graphical interface, or use your favorite one.

import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import pairwise_distances
from sklearn.neighbors import NearestNeighbors

from metric_learn import (LMNN, ITML_Supervised, LSML_Supervised,
                          SDML_Supervised)


def sandwich_demo():
  """Plot the sandwich data before and after several metric learners.

  The top row of the figure shows the raw data together with its
  neighborhood graph. Each of the four panels below shows the same data
  after being transformed by one fitted metric learner. The figure is
  displayed with ``plt.show()``.
  """
  points, classes = sandwich_data()
  neighbors = nearest_neighbors(points, k=2)

  # The first cell of a 3x1 grid spans the whole top row of the figure.
  top = plt.subplot(3, 1, 1)
  plot_sandwich_data(points, classes, top)
  plot_neighborhood_graph(points, neighbors, classes, top)
  top.set_title('input space')
  top.set_aspect('equal')
  top.set_xticks([])
  top.set_yticks([])

  learners = [
      LMNN(),
      ITML_Supervised(n_constraints=200),
      SDML_Supervised(n_constraints=200, balance_param=0.001),
      LSML_Supervised(n_constraints=200),
  ]

  # Cells 3..6 of a 3x2 grid are the two rows underneath the top panel.
  first_cell = 3
  for offset, learner in enumerate(learners):
    learner.fit(points, classes)
    embedded = learner.transform(points)
    embedded_neighbors = nearest_neighbors(embedded, k=2)
    panel = plt.subplot(3, 2, first_cell + offset)
    plot_sandwich_data(embedded, classes, axis=panel)
    plot_neighborhood_graph(embedded, embedded_neighbors, classes, axis=panel)
    panel.set_title(learner.__class__.__name__)
    panel.set_xticks([])
    panel.set_yticks([])
  plt.show()


# TODO: use this somewhere
def visualize_class_separation(X, labels):
  """Draw two side-by-side distance matrices, with rows sorted by label.

  The left image is the pairwise-distance matrix of the points in ``X``;
  the right image is the pairwise-distance matrix of the labels themselves
  (treated as a single-column feature). Both use the same label ordering.
  """
  _, axes = plt.subplots(ncols=2)
  points_ax, labels_ax = axes
  by_label = np.argsort(labels)

  point_dists = pairwise_distances(X[by_label])
  points_ax.imshow(point_dists, interpolation='nearest')

  # Indexing with None adds a trailing axis, giving one column of labels.
  label_dists = pairwise_distances(labels[by_label, None])
  labels_ax.imshow(label_dists, interpolation='nearest')


def nearest_neighbors(X, k=5):
  """Return the indices of the ``k`` nearest neighbors of each row of ``X``.

  A ``NearestNeighbors`` index is fitted on ``X`` and then queried with
  ``X`` itself; only the neighbor indices are returned, not the distances.
  """
  index = NearestNeighbors(n_neighbors=k).fit(X)
  neighbor_ids = index.kneighbors(X, return_distance=False)
  return neighbor_ids


def sandwich_data(num_classes=6, num_points=9, dist=0.7, noise=0.1):
  """Generate the layered 2-D "sandwich" toy dataset.

  Every class forms one horizontal layer: its points share a y center and
  sit on unit-spaced x centers, with Gaussian noise added independently to
  both coordinates. Calling the function without arguments produces the
  same layout as the original hard-coded version (6 layers of 9 points,
  0.7 apart, noise standard deviation 0.1).

  Parameters
  ----------
  num_classes : int, default=6
    Number of distinct classes (layers).
  num_points : int, default=9
    Number of points per class.
  dist : float, default=0.7
    Vertical distance between consecutive layers.
  noise : float, default=0.1
    Standard deviation of the Gaussian noise added to each coordinate.

  Returns
  -------
  data : ndarray of shape (num_classes * num_points, 2)
    Point coordinates, grouped class by class.
  labels : ndarray of shape (num_classes * num_points,)
    Integer class label of each row of ``data``.
  """
  # Centers are shifted so the whole grid is roughly centered on the origin.
  x_centers = np.arange(num_points, dtype=float) - num_points / 2
  y_centers = dist * (np.arange(num_classes, dtype=float) - num_classes / 2)

  # centers[i, k] holds the noiseless (x, y) position of point k in class i.
  centers = np.empty((num_classes, num_points, 2), dtype=float)
  centers[..., 0] = x_centers[np.newaxis, :]
  centers[..., 1] = y_centers[:, np.newaxis]

  # One vectorised draw adds independent noise to every coordinate.
  data = np.random.normal(centers, noise)
  labels = np.repeat(np.arange(num_classes, dtype=int), num_points)
  return data.reshape((-1, 2)), labels


def plot_sandwich_data(x, y, axis=plt, colors='rbgmky'):
  """Scatter-plot 2-D points as hollow circles, one edge color per class.

  Parameters
  ----------
  x : array-like of shape (n_samples, 2)
    Point coordinates.
  y : array-like of shape (n_samples,)
    Class label of each point.
  axis : object with a ``scatter`` method, default=``plt``
    Where to draw; the pyplot module itself or an Axes instance.
  colors : sequence of color codes, default='rbgmky'
    Color for each class, assigned in the sorted order of the unique
    labels. The palette is cycled when there are more classes than
    colors, instead of raising ``IndexError``.
  """
  # Coerce so that boolean masking also works for plain Python sequences.
  x = np.asarray(x)
  y = np.asarray(y)
  for idx, val in enumerate(np.unique(y)):
    xi = x[y == val]
    axis.scatter(*xi.T, s=50, facecolors='none',
                 edgecolors=colors[idx % len(colors)])


def plot_neighborhood_graph(x, nn, y, axis=plt, colors='rbgmky'):
  """Draw a line from every point to the neighbor listed in ``nn[:, 1]``.

  Parameters
  ----------
  x : array-like of shape (n_samples, 2)
    Point coordinates.
  nn : array-like of shape (n_samples, k) with k >= 2
    Neighbor indices into ``x``; only column 1 is used.
    NOTE(review): presumably column 0 is the point itself when ``nn``
    comes from ``nearest_neighbors`` -- confirm.
  y : array-like of shape (n_samples,)
    Class label of each point; determines the color of its edge.
  axis : object with a ``plot`` method, default=``plt``
    Where to draw; the pyplot module itself or an Axes instance.
  colors : sequence of color codes, default='rbgmky'
    Color for each class, assigned in the sorted order of the unique
    labels and cycled when there are more classes than colors.
  """
  x = np.asarray(x)
  nn = np.asarray(nn)
  # Use each label's rank among the unique labels rather than its raw value,
  # so the edge colors match the rank-based marker colors of
  # plot_sandwich_data even when labels are not exactly 0..n-1.
  _, color_idx = np.unique(np.ravel(y), return_inverse=True)
  for i, a in enumerate(x):
    b = x[nn[i, 1]]
    axis.plot((a[0], b[0]), (a[1], b[1]), colors[color_idx[i] % len(colors)])


# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
  sandwich_demo()
input space, LMNN, ITML_Supervised, SDML_Supervised, LSML_Supervised
/Users/william.vezelhes/metric-learn/metric_learn/constraints.py:224: UserWarning: Only generated 199 positive constraints (requested 200)
  warnings.warn("Only generated %d %s constraints (requested %d)" % (
/Users/william.vezelhes/miniconda3/envs/docenvbis/lib/python3.11/site-packages/sklearn/covariance/_graph_lasso.py:329: FutureWarning: The cov_init parameter is deprecated in 1.3 and will be removed in 1.5. It does not have any effect.
  warnings.warn(

Total running time of the script: (0 minutes 3.314 seconds)

Gallery generated by Sphinx-Gallery