Showing content from https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_OT.html below:

2D examples of exact and entropic unbalanced optimal transport — POT Python Optimal Transport 0.9.5 documentation

This example shows how to compute unbalanced optimal transport (UOT) and partial OT with POT.

UOT aims at solving the following optimization problem:

\[ \begin{aligned} W = \min_{\gamma} \quad & \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) + \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) \\ \text{s.t.} \quad & \gamma \geq 0 \end{aligned} \]

where \(\mathbf{M}\) is the cost matrix, \(\mathbf{a}\) and \(\mathbf{b}\) are the source and target sample weights, \(\Omega\) is a regularization term, and \(\mathrm{div}\) is a divergence. For entropic UOT, \(\mathrm{reg}>0\) and \(\mathrm{div}\) should be the Kullback-Leibler divergence. For exact UOT, \(\mathrm{reg}=0\) and \(\mathrm{div}\) can be either the Kullback-Leibler or the quadratic divergence. Using the \(\ell_1\) norm as \(\mathrm{div}\) gives the so-called partial OT.
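To make the objective concrete, the following is a minimal NumPy sketch that evaluates it for a given plan in the exact case (\(\mathrm{reg}=0\)) with a KL penalty. The helpers `gen_kl` and `uot_objective_kl` are hypothetical names, not POT functions, and the generalized KL form \(\sum_i x_i \log(x_i / y_i) - x_i + y_i\) is an assumption about the convention; POT's internal definition may differ in details.

```python
import numpy as np


def gen_kl(x, y):
    """Generalized KL divergence sum(x * log(x / y) - x + y).

    Assumed convention: 0 * log(0) = 0, and y > 0 wherever x > 0.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    mask = x > 0  # skip zero entries, whose contribution to x*log(x/y) is 0
    log_term = np.sum(x[mask] * np.log(x[mask] / y[mask]))
    return float(log_term - x.sum() + y.sum())


def uot_objective_kl(gamma, M, a, b, reg_m):
    """Transport cost plus KL penalties on both marginals (reg = 0)."""
    transport_cost = float(np.sum(gamma * M))  # Frobenius inner product <gamma, M>
    row_penalty = gen_kl(gamma.sum(axis=1), a)  # div(gamma 1, a)
    col_penalty = gen_kl(gamma.sum(axis=0), b)  # div(gamma^T 1, b)
    return transport_cost + reg_m * (row_penalty + col_penalty)


# Toy problem: two source and two target points with uniform weights.
a = np.array([0.5, 0.5])
b = np.array([0.5, 0.5])
M = np.array([[0.0, 1.0], [1.0, 0.0]])

balanced = np.diag(a)     # matches both marginals exactly at zero cost
empty = np.zeros((2, 2))  # transports nothing, so it pays only the penalties

print(uot_objective_kl(balanced, M, a, b, reg_m=0.05))
print(uot_objective_kl(empty, M, a, b, reg_m=0.05))
```

In this toy case the balanced plan pays neither transport cost nor marginal penalty, while the empty plan pays only the two penalty terms. This illustrates how \(\mathrm{reg_m}\) sets the price of leaving mass untransported.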

# Author: Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
# License: MIT License

import numpy as np
import matplotlib.pyplot as pl
import ot

Generate data

n = 40  # nb samples

mu_s = np.array([-1, -1])
cov_s = np.array([[1, 0], [0, 1]])

mu_t = np.array([4, 4])
cov_t = np.array([[1, -0.8], [-0.8, 1]])

np.random.seed(0)
xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)

n_noise = 10

xs = np.concatenate((xs, (np.random.rand(n_noise, 2) - 4)), axis=0)
xt = np.concatenate((xt, (np.random.rand(n_noise, 2) + 6)), axis=0)

n = n + n_noise

a, b = np.ones((n,)) / n, np.ones((n,)) / n  # uniform distribution on samples

# loss matrix
M = ot.dist(xs, xt)
M /= M.max()
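By default `ot.dist` returns pairwise squared Euclidean distances. As an illustration only, the loss matrix step can be sketched in plain NumPy; `sq_euclidean_cost` is a hypothetical helper, not part of POT.

```python
import numpy as np


def sq_euclidean_cost(x, y):
    """Pairwise squared Euclidean distances between rows of x and rows of y."""
    diff = x[:, None, :] - y[None, :, :]  # shape (n_x, n_y, dim)
    return np.sum(diff ** 2, axis=-1)     # shape (n_x, n_y)


x = np.array([[0.0, 0.0], [1.0, 0.0]])
y = np.array([[0.0, 0.0], [3.0, 4.0]])

C = sq_euclidean_cost(x, y)
C_scaled = C / C.max()  # same normalization as M /= M.max()
print(C)
print(C_scaled)
```

Dividing by the maximum puts every cost in [0, 1], so the `reg` and `reg_m` values are expressed relative to that scale rather than to the raw coordinates.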

Compute entropic KL-regularized UOT, KL- and l2-regularized UOT, and partial OT

reg = 0.005
reg_m_kl = 0.05
reg_m_l2 = 5
mass = 0.7

entropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl)
kl_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_kl, div="kl")
l2_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_l2, div="l2")
partial_ot = ot.partial.partial_wasserstein(a, b, M, m=mass)
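Unlike balanced OT, these plans are not required to match the marginals `a` and `b`. One way to see this is to compare the total transported mass and the marginal gaps of each plan. The sketch below uses a hypothetical helper, `plan_summary`, on a toy plan; it is not a POT function.

```python
import numpy as np


def plan_summary(P, a, b):
    """Total mass of a plan and the L1 gaps between its marginals and (a, b)."""
    P = np.asarray(P, dtype=float)
    return {
        "total_mass": float(P.sum()),
        "row_gap": float(np.abs(P.sum(axis=1) - a).sum()),  # ||P 1 - a||_1
        "col_gap": float(np.abs(P.sum(axis=0) - b).sum()),  # ||P^T 1 - b||_1
    }


a = np.array([0.5, 0.5])
b = np.array([0.5, 0.5])
P = np.array([[0.5, 0.0], [0.0, 0.2]])  # toy plan that moves 0.7 of the mass

summary = plan_summary(P, a, b)
print(summary)
```

Applied to the plans computed above, `plan_summary(partial_ot, a, b)["total_mass"]` should be close to `mass`, since partial OT fixes the amount of transported mass, whereas the mass of the UOT plans depends on the chosen `reg_m`.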

Plot the results

pl.figure(2)
transp = [partial_ot, l2_uot, kl_uot, entropic_kl_uot]
# Backslashes in the LaTeX commands are escaped; "\n" remains a newline.
title = [
    "partial OT \n m=" + str(mass),
    "$\\ell_2$-UOT \n $\\mathrm{reg_m}$=" + str(reg_m_l2),
    "kl-UOT \n $\\mathrm{reg_m}$=" + str(reg_m_kl),
    "entropic kl-UOT \n $\\mathrm{reg_m}$=" + str(reg_m_kl),
]

for p in range(4):
    # Top row: samples with the transport plan drawn as line segments
    pl.subplot(2, 4, p + 1)
    P = transp[p]
    if P.sum() > 0:
        P = P / P.max()  # rescale so the largest entry is 1
    for i in range(n):
        for j in range(n):
            if P[i, j] > 0:
                # line opacity is proportional to the transported mass
                pl.plot(
                    [xs[i, 0], xt[j, 0]],
                    [xs[i, 1], xt[j, 1]],
                    color="C2",
                    alpha=P[i, j] * 0.3,
                )
    # faint markers: all samples; solid markers: sized by the marginals of P
    pl.scatter(xs[:, 0], xs[:, 1], c="C0", alpha=0.2)
    pl.scatter(xt[:, 0], xt[:, 1], c="C1", alpha=0.2)
    pl.scatter(xs[:, 0], xs[:, 1], c="C0", s=P.sum(1).ravel() * (1 + p) * 2)
    pl.scatter(xt[:, 0], xt[:, 1], c="C1", s=P.sum(0).ravel() * (1 + p) * 2)
    pl.title(title[p])
    pl.yticks(())
    pl.xticks(())
    if p < 1:
        pl.ylabel("mappings")
    # Bottom row: the rescaled transport plan shown as an image
    pl.subplot(2, 4, p + 5)
    pl.imshow(P, cmap="jet")
    pl.yticks(())
    pl.xticks(())
    if p < 1:
        pl.ylabel("transport plans")
pl.show()
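The double loop above visits every pair `(i, j)` and draws a line only where the plan is positive. As a sketch of the same selection logic without the plotting, the hypothetical helper `plan_edges` below lists the positive entries of a plan together with the opacity the loop would assign to them. It mirrors the normalization by `P.max()` and the 0.3 opacity factor used in the example.

```python
import numpy as np


def plan_edges(P, max_alpha=0.3):
    """Return (i, j, alpha) for each positive entry of the rescaled plan."""
    P = np.asarray(P, dtype=float)
    if P.sum() > 0:
        P = P / P.max()  # same rescaling as in the plotting loop
    rows, cols = np.nonzero(P > 0)
    return [(int(i), int(j), float(P[i, j] * max_alpha)) for i, j in zip(rows, cols)]


P = np.array([[0.4, 0.0], [0.1, 0.2]])
edges = plan_edges(P)
for i, j, alpha in edges:
    print(i, j, alpha)
```

Listing the edges first makes it easy to check how sparse a plan is before drawing it. Exact and partial plans typically have far fewer positive entries than entropic ones.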
