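This example illustrates domain adaptation in a 2D setting using optimal transport for domain adaptation (OTDA) with Laplacian regularization, and compares it with exact (EMD) and entropic (Sinkhorn) transport.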
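To give some intuition for the Laplacian term, here is a minimal NumPy sketch of a graph-Laplacian smoothness penalty on the barycentric mapping T = n_s * G @ Xt of the source samples. It assumes uniform sample weights and a symmetric k-nearest-neighbour similarity graph S on the source samples (k = 3 is an arbitrary choice), and uses the identity sum_ij S_ij ||T_i - T_j||^2 = 2 tr(T^T L T) with L = D - S. It penalizes couplings that send neighbouring source samples to distant places. `knn_similarity` and `laplacian_penalty` are hypothetical helpers, and this is not POT's exact objective or weighting.

```python
import numpy as np


def knn_similarity(X, k=3):
    """Symmetric 0/1 k-nearest-neighbour graph (hypothetical helper)."""
    d = ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)  # pairwise squared distances
    np.fill_diagonal(d, np.inf)                          # a point is not its own neighbour
    idx = np.argsort(d, axis=1)[:, :k]                   # k closest points per row
    S = np.zeros_like(d)
    rows = np.repeat(np.arange(X.shape[0]), k)
    S[rows, idx.ravel()] = 1.0
    return np.maximum(S, S.T)                            # symmetrise the graph


def laplacian_penalty(G, Xs, Xt, k=3):
    """sum_ij S_ij ||T_i - T_j||^2, with T = n_s * G @ Xt (uniform weights assumed)."""
    S = knn_similarity(Xs, k)
    L = np.diag(S.sum(axis=1)) - S                       # graph Laplacian L = D - S
    T = Xs.shape[0] * G @ Xt                             # barycentric images of the source samples
    return 2.0 * np.trace(T.T @ L @ T)
```

A coupling that sends every source sample to the same point (for example the uniform coupling) has zero penalty, while one that scatters neighbouring source samples across the target domain is penalized.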
```python
# Authors: Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
#
# License: MIT License

import matplotlib.pylab as pl
import ot
```

Generate data

Instantiate the different transport algorithms and fit them
Running the fitting step prints the following warnings:

```
ot/bregman/_sinkhorn.py:903: UserWarning: Sinkhorn did not converge. You might want to increase the number of iterations `numItermax` or the regularization parameter `reg`.
ot/backend.py:1165: RuntimeWarning: overflow encountered in exp
```

Fig 1: plot source and target samples
```python
pl.figure(1, figsize=(10, 5))

# Left panel: source samples, coloured by class
pl.subplot(1, 2, 1)
pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker="+", label="Source samples")
pl.xticks([])
pl.yticks([])
pl.legend(loc=0)
pl.title("Source samples")

# Right panel: target samples, coloured by class
pl.subplot(1, 2, 2)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples")
pl.xticks([])
pl.yticks([])
pl.legend(loc=0)
pl.title("Target samples")
pl.tight_layout()
```

Fig 2: plot optimal couplings and transported samples
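In the standard OTDA formulation, a source sample seen during fitting is transported by a barycentric mapping: it is sent to the average of the target samples, weighted by its row of the coupling. The following is a minimal NumPy sketch of that mapping, assuming a dense coupling with no empty rows; `barycentric_map` is a hypothetical helper, not POT's API.

```python
import numpy as np


def barycentric_map(G, Xt):
    """Map each source sample to the G-weighted mean of the target samples."""
    row_mass = G.sum(axis=1, keepdims=True)  # mass leaving each source sample
    return (G / row_mass) @ Xt               # row-normalised coupling times target points
```

With an exact (EMD) coupling between equally sized, uniformly weighted samples, each row typically has a single nonzero entry, so transported points land exactly on target samples. An entropic (Sinkhorn) coupling spreads each row over many targets, which pulls transported points toward local means.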
```python
param_img = {"interpolation": "nearest"}

pl.figure(2, figsize=(15, 8))

# Top row: optimal couplings (rows index source samples, columns target samples)
pl.subplot(2, 3, 1)
pl.imshow(ot_emd.coupling_, **param_img)
pl.xticks([])
pl.yticks([])
pl.title("Optimal coupling\nEMDTransport")

pl.subplot(2, 3, 2)
pl.imshow(ot_sinkhorn.coupling_, **param_img)
pl.xticks([])
pl.yticks([])
pl.title("Optimal coupling\nSinkhornTransport")

pl.subplot(2, 3, 3)
pl.imshow(ot_emd_laplace.coupling_, **param_img)
pl.xticks([])
pl.yticks([])
pl.title("Optimal coupling\nEMDLaplaceTransport")

# Bottom row: transported source samples over the faded target samples
pl.subplot(2, 3, 4)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.3)
pl.scatter(
    transp_Xs_emd[:, 0],
    transp_Xs_emd[:, 1],
    c=ys,
    marker="+",
    label="Transp samples",
    s=30,
)
pl.xticks([])
pl.yticks([])
pl.title("Transported samples\nEMDTransport")
pl.legend(loc="lower left")

pl.subplot(2, 3, 5)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.3)
pl.scatter(
    transp_Xs_sinkhorn[:, 0],
    transp_Xs_sinkhorn[:, 1],
    c=ys,
    marker="+",
    label="Transp samples",
    s=30,
)
pl.xticks([])
pl.yticks([])
pl.title("Transported samples\nSinkhornTransport")

pl.subplot(2, 3, 6)
pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker="o", label="Target samples", alpha=0.3)
pl.scatter(
    transp_Xs_emd_laplace[:, 0],
    transp_Xs_emd_laplace[:, 1],
    c=ys,
    marker="+",
    label="Transp samples",
    s=30,
)
pl.xticks([])
pl.yticks([])
pl.title("Transported samples\nEMDLaplaceTransport")
pl.tight_layout()

pl.show()
```
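The non-convergence and overflow warnings printed when fitting `SinkhornTransport` are typical of entropic OT with a small regularization, because the entries of exp(-M/reg) span many orders of magnitude. Besides raising `reg` or `numItermax`, as the warning suggests, a common remedy is to run the Sinkhorn iterations in the log domain on dual potentials f and g. The following is a minimal NumPy sketch of that technique. `sinkhorn_log` and `_logsumexp` are hypothetical names, and this is not POT's implementation.

```python
import numpy as np


def _logsumexp(X, axis):
    """Numerically stable log(sum(exp(X))) along one axis."""
    m = X.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(X - m).sum(axis=axis, keepdims=True))).squeeze(axis)


def sinkhorn_log(a, b, M, reg, n_iter=1000, tol=1e-9):
    """Entropic OT via log-domain Sinkhorn; G_ij = exp((f_i + g_j - M_ij) / reg)."""
    f = np.zeros_like(a)  # dual potential for the source marginal a
    g = np.zeros_like(b)  # dual potential for the target marginal b
    log_a, log_b = np.log(a), np.log(b)
    for _ in range(n_iter):
        # Each update enforces one marginal exactly; logsumexp avoids exp overflow
        f = reg * (log_a - _logsumexp((g[None, :] - M) / reg, axis=1))
        g = reg * (log_b - _logsumexp((f[:, None] - M) / reg, axis=0))
        G = np.exp((f[:, None] + g[None, :] - M) / reg)
        # Columns match b after the g-update; stop once rows also match a
        if np.abs(G.sum(axis=1) - a).max() < tol:
            break
    return G
```

The coupling is only formed from exponents that are log-probabilities, so it cannot overflow even when `reg` is much smaller than the costs. Convergence can still be slow for very small `reg`.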