Illustration of Gaussian Bures-Wasserstein barycenters.
# Authors: Rémi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 2
from matplotlib import colors
from matplotlib.patches import Ellipse
import numpy as np
import matplotlib.pylab as pl
import ot

# Define Gaussian Covariances and distributions
# Covariance matrices of the four 2D Gaussians (each 2x2 and symmetric).
C1 = np.array([[0.5, -0.4], [-0.4, 0.5]])
C2 = np.array([[1, 0.3], [0.3, 1]])
C3 = np.array([[1.5, 0], [0, 0.5]])
C4 = np.array([[0.5, 0], [0, 1.5]])

# Stack along a new leading axis: C[k] is the covariance of Gaussian k.
C = np.stack((C1, C2, C3, C4))

# Means of the four Gaussians, placed on the corners of a square of side 4.
m1 = np.array([0, 0])
m2 = np.array([0, 4])
m3 = np.array([4, 0])
m4 = np.array([4, 4])

# Stack along a new leading axis: m[k] is the mean of Gaussian k.
m = np.stack((m1, m2, m3, m4))

# Plot the distributions
def draw_cov(mu, C, color=None, label=None, nstd=1):
    """Add the covariance ellipse of a 2D Gaussian to the current axes.

    Parameters
    ----------
    mu : sequence
        Center of the ellipse; only ``mu[0]`` and ``mu[1]`` are read.
    C : array-like
        Matrix passed to ``np.linalg.eigh``; its eigenvalues set the ellipse
        size and its leading eigenvector sets the orientation.
    color : matplotlib color, optional
        Used for both the face and the edge of the ellipse.
    label : str, optional
        Label forwarded to the ``Ellipse`` patch.
    nstd : float, optional
        Scale factor: the full width and height of the ellipse are
        ``2 * nstd * sqrt(eigenvalue)``.
    """

    def eigsorted(cov):
        # Reorder the eigenpairs so the largest eigenvalue comes first.
        vals, vecs = np.linalg.eigh(cov)
        order = vals.argsort()[::-1]
        return vals[order], vecs[:, order]

    vals, vecs = eigsorted(C)

    # Angle of the leading eigenvector: arctan2(y, x), converted to degrees.
    theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))
    w, h = 2 * nstd * np.sqrt(vals)

    ell = Ellipse(
        xy=(mu[0], mu[1]),
        width=w,
        height=h,
        alpha=0.5,
        angle=theta,
        facecolor=color,
        edgecolor=color,
        label=label,
        fill=True,
    )
    pl.gca().add_artist(ell)


# Common axis limits [xmin, xmax, ymin, ymax] shared by the figures.
axis = [-1.5, 5.5, -1.5, 5.5]

pl.figure(1, (8, 2))
pl.clf()

# One subplot per input Gaussian. Titles are raw strings so the LaTeX
# backslashes are not parsed as (invalid) Python escape sequences.
pl.subplot(1, 4, 1)
draw_cov(m1, C1, color="C0")
pl.axis(axis)
pl.title(r"$\mathcal{N}(m_1,\Sigma_1)$")

pl.subplot(1, 4, 2)
draw_cov(m2, C2, color="C1")
pl.axis(axis)
pl.title(r"$\mathcal{N}(m_2,\Sigma_2)$")

pl.subplot(1, 4, 3)
draw_cov(m3, C3, color="C2")
pl.axis(axis)
pl.title(r"$\mathcal{N}(m_3,\Sigma_3)$")

pl.subplot(1, 4, 4)
draw_cov(m4, C4, color="C3")
pl.axis(axis)
pl.title(r"$\mathcal{N}(m_4,\Sigma_4)$")
# Out: Text(0.5, 1.0, '$\\mathcal{N}(m_4,\\Sigma_4)$')

# Compute Bures-Wasserstein barycenters and plot them
# basis for bilinear interpolation: one one-hot weight vector per Gaussian
v1 = np.array((1, 0, 0, 0))
v2 = np.array((0, 1, 0, 0))
v3 = np.array((0, 0, 1, 0))
v4 = np.array((0, 0, 0, 1))

# RGB triplets of the colors "C0".."C3", one row per Gaussian. Stored under
# its own name so the imported `colors` module is not rebound to an array
# (which would make re-running this section fail on `colors.to_rgb`).
corner_colors = np.stack(
    (colors.to_rgb("C0"), colors.to_rgb("C1"), colors.to_rgb("C2"), colors.to_rgb("C3"))
)

pl.figure(2, (8, 8))

nb_interp = 6

for i in range(nb_interp):
    tx = float(i) / (nb_interp - 1)

    # These two blends depend only on tx, so compute them once per outer step.
    tmp1 = (1 - tx) * v1 + tx * v2
    tmp2 = (1 - tx) * v3 + tx * v4

    for j in range(nb_interp):
        ty = float(j) / (nb_interp - 1)

        # weights are constructed by bilinear interpolation
        weights = (1 - ty) * tmp1 + ty * tmp2

        # Blend the four colors with the same weights as the barycenter.
        color = np.dot(corner_colors.T, weights)

        # Mean and covariance of the weighted barycenter of the four Gaussians.
        mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, weights)

        draw_cov(mb, Cb, color=color, label=None, nstd=0.3)

pl.axis(axis)
pl.axis("off")
pl.tight_layout()
Total running time of the script: (0 minutes 0.333 seconds)
Gallery generated by Sphinx-Gallery
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML:
3.2
| Encoding:
UTF-8
| Version:
0.7.4