This example illustrates the use of EWCA as proposed in [52].
[52] Collas, A., Vayer, T., Flamary, R., & Breloy, A. (2023). Entropic Wasserstein Component Analysis.
# Author: Antoine Collas <antoine.collas@inria.fr>
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 2

import matplotlib
import matplotlib.patches as patches
import matplotlib.pylab as pl
import numpy as np
from matplotlib import ticker as mticker
from ot.dr import ewca
from sklearn.datasets import make_blobs

# %%
# Generate data
# -------------
# Two 2-D Gaussian blobs centred at (esp, esp) and (-esp, -esp).
n_samples = 20
esp = 0.8
cluster_std = 0.4
centers = np.array(
    [
        [esp, esp],
        [-esp, -esp],
    ]
)

# Fixed seed so the generated points are reproducible.
rng = np.random.RandomState(42)

# shuffle=False keeps the samples in the order make_blobs produces them;
# the plotting code below splits them at n_samples // 2.
X, y = make_blobs(
    n_samples=n_samples,
    n_features=2,
    centers=centers,
    cluster_std=cluster_std,
    shuffle=False,
    random_state=rng,
)

# Centre the data: subtract the per-feature mean.
X = X - X.mean(axis=0)

# %%
# Plot data
# ---------
fig = pl.figure(figsize=(4, 4))
cmap = matplotlib.colormaps.get_cmap("tab10")

# One scatter call per half of the samples so that each half gets its own
# legend entry. Colours are looked up at ``label + 1`` in the tab10 map.
half = n_samples // 2
for legend_label, rows in (
    ("Class 1", slice(None, half)),
    ("Class 2", slice(half, None)),
):
    pl.scatter(
        X[rows, 0],
        X[rows, 1],
        color=[cmap(label + 1) for label in y[rows]],
        alpha=0.4,
        label=legend_label,
        zorder=30,
        s=50,
    )

# Axis limits and font size shared with the figures that follow.
x_y_lim = 2.5
fs = 15
pl.xlim(-x_y_lim, x_y_lim)
pl.xticks([])
pl.ylim(-x_y_lim, x_y_lim)
pl.yticks([])
pl.legend(fontsize=fs)
pl.title("Data", fontsize=fs)
pl.tight_layout()

# %%
# Compute EWCA
# ------------
# The figures below read ``pi[i, j]`` for every pair of samples and the
# first column of ``U``, so both must be computed here.
# NOTE(review): this call was missing from this copy of the example.
# k=1 because only ``U[:, 0]`` is used afterwards; reg=0.5 is taken from the
# upstream POT gallery example -- confirm against the reference script.
pi, U = ewca(X, k=1, reg=0.5)

# %%
# Plot data, first component, and projected data
# ----------------------------------------------
fig = pl.figure(figsize=(4, 4))

# Dashed grey segment through the origin along the first column of U.
half_length = 3
first_axis = U[:, 0]
pl.plot(
    [half_length * first_axis[0], -half_length * first_axis[0]],
    [half_length * first_axis[1], -half_length * first_axis[1]],
    color="grey",
    linestyle="--",
    lw=3,
    alpha=0.3,
    label=r"$\mathbf{U}$",
)

# Project every sample onto that direction.
axis_col = first_axis[:, None]
X_proj = X @ axis_col @ axis_col.T

# Link x_i to the projection of x_j with opacity pi[i, j] / max(pi).
# Only links with relative weight >= 0.15 are drawn, except that the very
# last pair is always drawn because it carries the legend entry.
last_pair = (n_samples - 1, n_samples - 1)
pi_max = pi.max()
for src in range(n_samples):
    for dst in range(n_samples):
        weight = pi[src, dst] / pi_max
        is_last = (src, dst) == last_pair
        if weight < 0.15 and not is_last:
            continue
        pl.plot(
            [X[src, 0], X_proj[dst, 0]],
            [X[src, 1], X_proj[dst, 1]],
            alpha=weight,
            linestyle="-",
            c="C0",
            label=r"$\pi_{ij}$" if is_last else None,
        )

# Same per-sample colours for the original points and their projections.
point_colors = [cmap(y[k] + 1) for k in range(n_samples)]
pl.scatter(
    X[:, 0],
    X[:, 1],
    color=point_colors,
    alpha=0.4,
    label=r"$\mathbf{x}_i$",
    zorder=30,
    s=50,
)
pl.scatter(
    X_proj[:, 0],
    X_proj[:, 1],
    color=point_colors,
    alpha=0.9,
    s=50,
    marker="+",
    label=r"$\mathbf{U}\mathbf{U}^{\top}\mathbf{x}_i$",
    zorder=30,
)

pl.title("Data and projections", fontsize=fs)
pl.xlim(-x_y_lim, x_y_lim)
pl.xticks([])
pl.ylim(-x_y_lim, x_y_lim)
pl.yticks([])
pl.legend(fontsize=fs, loc="upper left")
pl.tight_layout()

# %%
# Plot transport plan
# -------------------
fig = pl.figure(figsize=(5, 5))

# Heat map of the plan scaled by n_samples * 100; the colour bar labels the
# values as percentages. The square-root PowerNorm is clipped to [0, 100].
plan_norm = matplotlib.colors.PowerNorm(0.5, vmin=0, vmax=100)
image = pl.imshow(
    n_samples * pi * 100,
    cmap=pl.cm.Blues,
    norm=plan_norm,
    aspect="auto",
)

colorbar = fig.colorbar(image, orientation="vertical", shrink=0.8)
# Pin the tick positions before replacing their labels.
colorbar.ax.yaxis.set_major_locator(
    mticker.FixedLocator(colorbar.ax.get_yticks().tolist())
)
colorbar.ax.set_yticklabels([f"{int(tick)}%" for tick in colorbar.get_ticks()])
colorbar.ax.tick_params(labelsize=fs)

# Outline, for each class, the diagonal square spanning the index range of
# that class's samples.
for class_ in np.unique(y):
    members = np.flatnonzero(y == class_)
    idx_min = np.min(members)
    idx_max = np.max(members)
    side = idx_max - idx_min + 1
    pl.gca().add_patch(
        patches.Rectangle(
            (idx_min - 0.5, idx_min - 0.5),
            side,
            side,
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
    )

pl.title("OT plan", fontsize=fs)
pl.ylabel(r"($\mathbf{x}_1, \cdots, \mathbf{x}_n$)")
x_label = (
    r"($\mathbf{U}\mathbf{U}^{\top}\mathbf{x}_1, \cdots,"
    r"\mathbf{U}\mathbf{U}^{\top}\mathbf{x}_n$)"
)
pl.xlabel(x_label)
pl.tight_layout()
pl.axis("scaled")
pl.show()
Total running time of the script: (0 minutes 0.456 seconds)
Gallery generated by Sphinx-Gallery
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML:
3.2
| Encoding:
UTF-8
| Version:
0.7.4