Showing content from https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_d2.html below:

OT for domain adaptation on empirical distributions — POT Python Optimal Transport 0.9.5 documentation

OT for domain adaptation on empirical distributions

This example introduces domain adaptation in a 2D setting. It makes the domain adaptation problem explicit and introduces some optimal transport approaches to solve it.

Quantities such as the optimal couplings, the largest coupling coefficients, and the transported samples are plotted to give a visual understanding of what the transport methods are doing.

# Authors: Remi Flamary <remi.flamary@unice.fr>
#          Stanislas Chambon <stan.chambon@gmail.com>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 2

import matplotlib.pylab as pl
import ot
import ot.plot
Generate data

Instantiate the different transport algorithms and fit them
/home/circleci/project/ot/bregman/_sinkhorn.py:903: UserWarning: Sinkhorn did not converge. You might want to increase the number of iterations `numItermax` or the regularization parameter `reg`.
  warnings.warn(
/home/circleci/project/ot/bregman/_sinkhorn.py:667: UserWarning: Sinkhorn did not converge. You might want to increase the number of iterations `numItermax` or the regularization parameter `reg`.
  warnings.warn(
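The warning above names two knobs: the iteration budget (`numItermax`) and the regularization strength (`reg`). The sketch below is a plain NumPy version of Sinkhorn's alternating scaling, not POT's implementation, and shows where both enter. The function name, argument names, and toy data are illustrative assumptions.

```python
import numpy as np


def sinkhorn_sketch(a, b, M, reg, num_iter_max=1000, stop_thr=1e-9):
    """Entropic OT by alternating scaling; returns (coupling, converged)."""
    K = np.exp(-M / reg)  # Gibbs kernel; a small reg pushes entries towards 0
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(num_iter_max):
        v = b / (K.T @ u)  # rescale to match the column marginals
        u = a / (K @ v)    # rescale to match the row marginals
        G = u[:, None] * K * v[None, :]
        # Rows match after the u-update, so measure the column marginal error
        err = np.abs(G.sum(axis=0) - b).sum()
        if err < stop_thr:
            return G, True
    return G, False


rng = np.random.default_rng(0)
xs = rng.normal(size=(20, 2))
xt = rng.normal(size=(20, 2)) + 2.0
M = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(axis=2)
M = M / M.max()  # normalize the cost so reg is on a comparable scale
a = np.full(20, 1 / 20)
b = np.full(20, 1 / 20)

# Generous regularization with the default iteration budget
G_loose, ok_loose = sinkhorn_sketch(a, b, M, reg=1.0)
# Small regularization with a very small iteration budget
G_tight, ok_tight = sinkhorn_sketch(a, b, M, reg=0.01, num_iter_max=5)
```

A smaller `reg` gives a sharper coupling but tends to need many more iterations before the marginal error falls below the threshold, which is the situation the warning describes.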
Fig 1: source and target samples, and the matrix of pairwise distances
pl.figure(1, figsize=(10, 10))
pl.subplot(2, 2, 1)
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples")
pl.xticks([])
pl.yticks([])
pl.legend(loc=0)
pl.title("Source samples")

pl.subplot(2, 2, 2)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples")
pl.xticks([])
pl.yticks([])
pl.legend(loc=0)
pl.title("Target samples")

pl.subplot(2, 2, 3)
pl.imshow(M, interpolation="nearest")
pl.xticks([])
pl.yticks([])
pl.title("Matrix of pairwise distances")
pl.tight_layout()
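The matrix `M` shown in the third panel holds the transport cost between every source sample and every target sample. Below is a minimal NumPy sketch of such a matrix, assuming the squared Euclidean metric (the default of `ot.dist`). The helper name and demo arrays are illustrative.

```python
import numpy as np


def pairwise_sqeuclidean(Xs, Xt):
    """Return M with M[i, j] = ||Xs[i] - Xt[j]||^2."""
    diff = Xs[:, None, :] - Xt[None, :, :]  # shape (n_source, n_target, dim)
    return (diff ** 2).sum(axis=2)


Xs_demo = np.array([[0.0, 0.0], [1.0, 0.0]])
Xt_demo = np.array([[0.0, 3.0], [1.0, 1.0], [4.0, 0.0]])
M_demo = pairwise_sqeuclidean(Xs_demo, Xt_demo)
# M_demo has one row per source sample and one column per target sample
```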
Fig 2: optimal couplings for the different methods
pl.figure(2, figsize=(10, 6))

pl.subplot(2, 3, 1)
pl.imshow(ot_emd.coupling_, interpolation="nearest")
pl.xticks([])
pl.yticks([])
pl.title("Optimal coupling\nEMDTransport")

pl.subplot(2, 3, 2)
pl.imshow(ot_sinkhorn.coupling_, interpolation="nearest")
pl.xticks([])
pl.yticks([])
pl.title("Optimal coupling\nSinkhornTransport")

pl.subplot(2, 3, 3)
pl.imshow(ot_lpl1.coupling_, interpolation="nearest")
pl.xticks([])
pl.yticks([])
pl.title("Optimal coupling\nSinkhornLpl1Transport")

pl.subplot(2, 3, 4)
ot.plot.plot2D_samples_mat(Xs, Xt, ot_emd.coupling_, c=[0.5, 0.5, 1])
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples")
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples")
pl.xticks([])
pl.yticks([])
pl.title("Main coupling coefficients\nEMDTransport")

pl.subplot(2, 3, 5)
ot.plot.plot2D_samples_mat(Xs, Xt, ot_sinkhorn.coupling_, c=[0.5, 0.5, 1])
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples")
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples")
pl.xticks([])
pl.yticks([])
pl.title("Main coupling coefficients\nSinkhornTransport")

pl.subplot(2, 3, 6)
ot.plot.plot2D_samples_mat(Xs, Xt, ot_lpl1.coupling_, c=[0.5, 0.5, 1])
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples")
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples")
pl.xticks([])
pl.yticks([])
pl.title("Main coupling coefficients\nSinkhornLpl1Transport")
pl.tight_layout()
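The bottom row of this figure draws a line between a source and a target sample only when their coupling coefficient is non-negligible. The sketch below shows one such selection rule, assuming a threshold relative to the largest entry, similar in spirit to the `thr` argument of `ot.plot.plot2D_samples_mat`. The helper name and demo coupling are illustrative.

```python
import numpy as np


def main_pairs(G, thr=1e-8):
    """Return (i, j) index pairs whose coupling exceeds thr times the largest entry."""
    rows, cols = np.nonzero(G / G.max() > thr)
    return list(zip(rows.tolist(), cols.tolist()))


G_demo = np.array([[0.5, 0.0],
                   [1e-12, 0.5]])
pairs = main_pairs(G_demo)
# Only the two diagonal entries pass; the 1e-12 entry is filtered out
```

Under a rule of this kind, an exact (EMD) coupling, which is sparse, yields few lines, while an entropic coupling, which is dense, yields many faint ones.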
Fig 3: transported samples
# display transported samples
pl.figure(3, figsize=(10, 4))
pl.subplot(1, 3, 1)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.5)
pl.scatter(
    transp_Xs_emd[:, 0],
    transp_Xs_emd[:, 1],
    c=ys,
    marker="+",
    label="Transp samples",
    s=30,
)
pl.title("Transported samples\nEMDTransport")
pl.legend(loc=0)
pl.xticks([])
pl.yticks([])

pl.subplot(1, 3, 2)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.5)
pl.scatter(
    transp_Xs_sinkhorn[:, 0],
    transp_Xs_sinkhorn[:, 1],
    c=ys,
    marker="+",
    label="Transp samples",
    s=30,
)
pl.title("Transported samples\nSinkhornTransport")
pl.xticks([])
pl.yticks([])

pl.subplot(1, 3, 3)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.5)
pl.scatter(
    transp_Xs_lpl1[:, 0],
    transp_Xs_lpl1[:, 1],
    c=ys,
    marker="+",
    label="Transp samples",
    s=30,
)
pl.title("Transported samples\nSinkhornLpl1Transport")
pl.xticks([])
pl.yticks([])

pl.tight_layout()
pl.show()
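The transported samples plotted above can be understood through barycentric mapping: each source sample moves to the coupling-weighted average of the target samples. The NumPy sketch below assumes that rule. The helper name and toy coupling are illustrative and not taken from POT.

```python
import numpy as np


def barycentric_map(G, Xt):
    """Map each source sample to the coupling-weighted mean of the target samples."""
    row_mass = G.sum(axis=1, keepdims=True)  # mass leaving each source sample
    return (G / row_mass) @ Xt


# Toy coupling between 3 source and 2 target samples (each row carries mass 1/3)
G = np.array([[1 / 3, 0.0],
              [0.0, 1 / 3],
              [1 / 6, 1 / 6]])
Xt = np.array([[0.0, 0.0],
               [2.0, 4.0]])

transp_Xs = barycentric_map(G, Xt)
# Rows: the first target point, the second target point, then their midpoint
```

A source sample whose mass goes to a single target lands exactly on it, while a sample whose mass is split lands between its targets. This is consistent with entropic couplings pulling transported samples towards the interior of the target clusters.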

Total running time of the script: (0 minutes 6.911 seconds)
