# Author: Remi Flamary <remi.flamary@polytechnique.edu> # # License: MIT License # sphinx_gallery_thumbnail_number = 3 import numpy as np import matplotlib.pyplot as pl import torch import ot import ot.plotData generation Plot data
# Figure 1: scatter the source and target point clouds on one wide canvas.
pl.figure(1, figsize=(10, 5))
pl.clf()
for cloud, marker, label in (
    (Xs, "+", "Source samples"),
    (Xt, "o", "Target samples"),
):
    # Column 0 is plotted on the horizontal axis, column 1 on the vertical.
    pl.scatter(cloud[:, 0], cloud[:, 1], marker=marker, label=label)
pl.legend(loc=0)
pl.title("Source and target distributions")
Text(0.5, 1.0, 'Source and target distributions')

Convert data to torch tensors

Estimating dual variables for entropic OT
# Dual potentials to optimise. u has one entry per source sample; v is the
# target-side potential, so it is sized from xt rather than from the source
# sample count (the two only coincide when both sets have the same size).
u = torch.randn(n_source_samples, requires_grad=True)
v = torch.randn(xt.shape[0], requires_grad=True)

# Regularisation weight, shared by the loss and the plan reconstruction below.
reg = 0.5

optimizer = torch.optim.Adam([u, v], lr=1)

# number of iterations
n_iter = 200

# History of the (negated) dual objective, one value per iteration.
losses = []

for i in range(n_iter):
    # minus because we maximize the dual loss
    loss = -ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=reg)
    losses.append(float(loss.detach()))

    if i % 10 == 0:
        print("Iter: {:3d}, loss={}".format(i, losses[-1]))

    loss.backward()
    optimizer.step()
    # Reset gradients so the next backward pass does not accumulate onto them.
    optimizer.zero_grad()

# Figure 2: evolution of the negated dual objective over the iterations.
pl.figure(2)
pl.plot(losses)
pl.grid()
pl.title("Dual objective (negative)")
pl.xlabel("Iterations")

# Build the transport plan from the optimised dual potentials.
Ge = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=reg)
Iter:   0, loss=0.202049490022473
Iter:  10, loss=-19.5949576048274
Iter:  20, loss=-31.38859407924335
Iter:  30, loss=-35.900700994883316
Iter:  40, loss=-39.04775322975264
Iter:  50, loss=-40.98205744442137
Iter:  60, loss=-41.76287033773109
Iter:  70, loss=-42.09317938106601
Iter:  80, loss=-42.16283642689691
Iter:  90, loss=-42.2036823596773
Iter: 100, loss=-42.21953889248053
Iter: 110, loss=-42.22807022489665
Iter: 120, loss=-42.23272560534595
Iter: 130, loss=-42.2358250550116
Iter: 140, loss=-42.238243196354176
Iter: 150, loss=-42.24043195312365
Iter: 160, loss=-42.242376559312795
Iter: 170, loss=-42.24411141177066
Iter: 180, loss=-42.24563020403166
Iter: 190, loss=-42.24691854142809

Plot the estimated entropic OT plan
# Figure 3: draw the entropic plan between the samples, then overlay the
# two point clouds on top of it.
pl.figure(3, figsize=(10, 5))
pl.clf()

# The plotting helper is handed a NumPy array, so detach the plan from the
# autograd graph before converting it.
entropic_plan = Ge.detach().numpy()
ot.plot.plot2D_samples_mat(Xs, Xt, entropic_plan, alpha=0.1)

# zorder=2 is passed so the markers are drawn with a higher z-order.
for cloud, marker, label in (
    (Xs, "+", "Source samples"),
    (Xt, "o", "Target samples"),
):
    pl.scatter(cloud[:, 0], cloud[:, 1], marker=marker, label=label, zorder=2)

pl.legend(loc=0)
# NOTE(review): this title repeats the one used for figure 1 although the
# figure shows the entropic plan - confirm whether that is intended.
pl.title("Source and target distributions")
Text(0.5, 1.0, 'Source and target distributions')

Estimating dual variables for quadratic OT
# Dual potentials to optimise. u has one entry per source sample; v is the
# target-side potential, so it is sized from xt rather than from the source
# sample count (the two only coincide when both sets have the same size).
u = torch.randn(n_source_samples, requires_grad=True)
v = torch.randn(xt.shape[0], requires_grad=True)

# Regularisation weight, shared by the loss and the plan reconstruction below.
reg = 0.01

optimizer = torch.optim.Adam([u, v], lr=1)

# number of iterations
n_iter = 200

# History of the (negated) dual objective, one value per iteration.
losses = []

for i in range(n_iter):
    # minus because we maximize the dual loss
    loss = -ot.stochastic.loss_dual_quadratic(u, v, xs, xt, reg=reg)
    losses.append(float(loss.detach()))

    if i % 10 == 0:
        print("Iter: {:3d}, loss={}".format(i, losses[-1]))

    loss.backward()
    optimizer.step()
    # Reset gradients so the next backward pass does not accumulate onto them.
    optimizer.zero_grad()

# Figure 4: evolution of the negated dual objective over the iterations.
pl.figure(4)
pl.plot(losses)
pl.grid()
pl.title("Dual objective (negative)")
pl.xlabel("Iterations")

# Build the transport plan from the optimised dual potentials.
Gq = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, reg=reg)
Iter:   0, loss=-0.0018442196020623663
Iter:  10, loss=-19.56815819740162
Iter:  20, loss=-31.14528043241069
Iter:  30, loss=-35.537678548240336
Iter:  40, loss=-38.764773617302076
Iter:  50, loss=-40.73846652972734
Iter:  60, loss=-41.52021186603003
Iter:  70, loss=-41.93972887343822
Iter:  80, loss=-42.00745415035667
Iter:  90, loss=-42.07627382396151
Iter: 100, loss=-42.10385212095502
Iter: 110, loss=-42.113487369686055
Iter: 120, loss=-42.11859124908843
Iter: 130, loss=-42.1215496208666
Iter: 140, loss=-42.12376522981474
Iter: 150, loss=-42.1258767523867
Iter: 160, loss=-42.12795577501417
Iter: 170, loss=-42.13003732914064
Iter: 180, loss=-42.13196074880831
Iter: 190, loss=-42.13366053904776

Plot the estimated quadratic OT plan
# Figure 5: draw the quadratically regularised plan between the samples,
# then overlay the two point clouds on top of it.
pl.figure(5, figsize=(10, 5))
pl.clf()

# The plotting helper is handed a NumPy array, so detach the plan from the
# autograd graph before converting it.
quadratic_plan = Gq.detach().numpy()
ot.plot.plot2D_samples_mat(Xs, Xt, quadratic_plan, alpha=0.1)

# zorder=2 is passed so the markers are drawn with a higher z-order.
for cloud, marker, label in (
    (Xs, "+", "Source samples"),
    (Xt, "o", "Target samples"),
):
    pl.scatter(cloud[:, 0], cloud[:, 1], marker=marker, label=label, zorder=2)

pl.legend(loc=0)
pl.title("OT plan with quadratic regularization")
Text(0.5, 1.0, 'OT plan with quadratic regularization')
Total running time of the script: (0 minutes 11.460 seconds)
Gallery generated by Sphinx-Gallery
RetroSearch is an open source project built by @garambo | Open a GitHub Issue
Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo
HTML:
3.2
| Encoding:
UTF-8
| Version:
0.7.4