Source: https://pythonot.github.io/auto_examples/backends/plot_dual_ot_pytorch.html
Dual OT solvers for entropic and quadratic regularized OT with PyTorch — POT Python Optimal Transport 0.9.5 documentation
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 3

import numpy as np
import matplotlib.pyplot as pl
import torch
import ot
import ot.plot
Data generation

Plot data
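The data-generation cell did not survive extraction. Below is a minimal sketch consistent with the plots that follow: two 2D Gaussian point clouds, one centered near the origin and one shifted away from it. The sample sizes, means, and covariances are illustrative guesses, and plain NumPy is used in place of POT's `ot.datasets` helpers so the sketch stands alone.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes; the variable names match those used in the code below.
n_source_samples = 100
n_target_samples = 100

# Source: isotropic Gaussian at the origin; target: correlated Gaussian at (4, 4).
mu_s, cov_s = np.array([0.0, 0.0]), np.array([[1.0, 0.0], [0.0, 1.0]])
mu_t, cov_t = np.array([4.0, 4.0]), np.array([[1.0, -0.8], [-0.8, 1.0]])

Xs = rng.multivariate_normal(mu_s, cov_s, size=n_source_samples)
Xt = rng.multivariate_normal(mu_t, cov_t, size=n_target_samples)
```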
pl.figure(1, (10, 5))
pl.clf()
pl.scatter(Xs[:, 0], Xs[:, 1], marker="+", label="Source samples")
pl.scatter(Xt[:, 0], Xt[:, 1], marker="o", label="Target samples")
pl.legend(loc=0)
pl.title("Source and target distributions")
Text(0.5, 1.0, 'Source and target distributions')
Convert data to torch tensors

Estimating dual variables for entropic OT
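The conversion cell is likewise missing from this capture. A minimal sketch, assuming `Xs` and `Xt` are `(n, 2)` NumPy arrays from the data-generation step (stand-in arrays are used here so the snippet runs on its own):

```python
import numpy as np
import torch

# Stand-ins for the arrays produced by the data-generation step.
Xs = np.random.randn(100, 2)
Xt = np.random.randn(100, 2) + 4.0

# torch.tensor copies the NumPy data; float64 is preserved, and no gradient
# is tracked for the samples themselves (only u and v are optimized).
xs = torch.tensor(Xs)
xt = torch.tensor(Xt)
```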
u = torch.randn(n_source_samples, requires_grad=True)
v = torch.randn(n_target_samples, requires_grad=True)  # one dual variable per target sample

reg = 0.5

optimizer = torch.optim.Adam([u, v], lr=1)

# number of iterations
n_iter = 200


losses = []

for i in range(n_iter):
    # minus sign: we maximize the dual objective by minimizing its negative
    loss = -ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=reg)
    losses.append(float(loss.detach()))

    if i % 10 == 0:
        print("Iter: {:3d}, loss={}".format(i, losses[-1]))

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


pl.figure(2)
pl.plot(losses)
pl.grid()
pl.title("Dual objective (negative)")
pl.xlabel("Iterations")

Ge = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=reg)
Iter:   0, loss=0.202049490022473
Iter:  10, loss=-19.5949576048274
Iter:  20, loss=-31.38859407924335
Iter:  30, loss=-35.900700994883316
Iter:  40, loss=-39.04775322975264
Iter:  50, loss=-40.98205744442137
Iter:  60, loss=-41.76287033773109
Iter:  70, loss=-42.09317938106601
Iter:  80, loss=-42.16283642689691
Iter:  90, loss=-42.2036823596773
Iter: 100, loss=-42.21953889248053
Iter: 110, loss=-42.22807022489665
Iter: 120, loss=-42.23272560534595
Iter: 130, loss=-42.2358250550116
Iter: 140, loss=-42.238243196354176
Iter: 150, loss=-42.24043195312365
Iter: 160, loss=-42.242376559312795
Iter: 170, loss=-42.24411141177066
Iter: 180, loss=-42.24563020403166
Iter: 190, loss=-42.24691854142809
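For reference, the loop above maximizes the entropic OT dual. Up to POT's exact normalization (a sketch, assuming sample weights $a_i$, $b_j$ and a cost matrix $C_{ij}$ between `xs` and `xt`):

$$\max_{u, v}\; \sum_i a_i u_i + \sum_j b_j v_j \;-\; \varepsilon \sum_{i,j} a_i b_j \exp\!\left(\frac{u_i + v_j - C_{ij}}{\varepsilon}\right)$$

where $\varepsilon$ is the regularization strength (`reg`); the code minimizes the negative of this objective with Adam.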
Plot the estimated entropic OT plan
pl.figure(3, (10, 5))
pl.clf()
ot.plot.plot2D_samples_mat(Xs, Xt, Ge.detach().numpy(), alpha=0.1)
pl.scatter(Xs[:, 0], Xs[:, 1], marker="+", label="Source samples", zorder=2)
pl.scatter(Xt[:, 0], Xt[:, 1], marker="o", label="Target samples", zorder=2)
pl.legend(loc=0)
pl.title("OT plan with entropic regularization")
Text(0.5, 1.0, 'OT plan with entropic regularization')
Estimating dual variables for quadratic OT
u = torch.randn(n_source_samples, requires_grad=True)
v = torch.randn(n_target_samples, requires_grad=True)  # one dual variable per target sample

reg = 0.01

optimizer = torch.optim.Adam([u, v], lr=1)

# number of iterations
n_iter = 200


losses = []


for i in range(n_iter):
    # minus sign: we maximize the dual objective by minimizing its negative
    loss = -ot.stochastic.loss_dual_quadratic(u, v, xs, xt, reg=reg)
    losses.append(float(loss.detach()))

    if i % 10 == 0:
        print("Iter: {:3d}, loss={}".format(i, losses[-1]))

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


pl.figure(4)
pl.plot(losses)
pl.grid()
pl.title("Dual objective (negative)")
pl.xlabel("Iterations")

Gq = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, reg=reg)
Iter:   0, loss=-0.0018442196020623663
Iter:  10, loss=-19.56815819740162
Iter:  20, loss=-31.14528043241069
Iter:  30, loss=-35.537678548240336
Iter:  40, loss=-38.764773617302076
Iter:  50, loss=-40.73846652972734
Iter:  60, loss=-41.52021186603003
Iter:  70, loss=-41.93972887343822
Iter:  80, loss=-42.00745415035667
Iter:  90, loss=-42.07627382396151
Iter: 100, loss=-42.10385212095502
Iter: 110, loss=-42.113487369686055
Iter: 120, loss=-42.11859124908843
Iter: 130, loss=-42.1215496208666
Iter: 140, loss=-42.12376522981474
Iter: 150, loss=-42.1258767523867
Iter: 160, loss=-42.12795577501417
Iter: 170, loss=-42.13003732914064
Iter: 180, loss=-42.13196074880831
Iter: 190, loss=-42.13366053904776
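The quadratic case replaces the exponential penalty of the entropic dual with a squared hinge. A sketch of the maximized dual, up to the normalization of the regularizer:

$$\max_{u, v}\; \sum_i a_i u_i + \sum_j b_j v_j \;-\; \frac{1}{2\varepsilon} \sum_{i,j} a_i b_j \big(u_i + v_j - C_{ij}\big)_+^2$$

where $(\cdot)_+ = \max(\cdot, 0)$. Unlike the entropic dual, this penalty vanishes wherever $u_i + v_j \le C_{ij}$, which is why quadratic regularization tends to produce sparse plans, visible in the plot below.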
Plot the estimated quadratic OT plan
pl.figure(5, (10, 5))
pl.clf()
ot.plot.plot2D_samples_mat(Xs, Xt, Gq.detach().numpy(), alpha=0.1)
pl.scatter(Xs[:, 0], Xs[:, 1], marker="+", label="Source samples", zorder=2)
pl.scatter(Xt[:, 0], Xt[:, 1], marker="o", label="Target samples", zorder=2)
pl.legend(loc=0)
pl.title("OT plan with quadratic regularization")
Text(0.5, 1.0, 'OT plan with quadratic regularization')

Total running time of the script: (0 minutes 11.460 seconds)

Gallery generated by Sphinx-Gallery

