Illustration of the GMM plan for the Mixture Wasserstein distance between two GMMs in 1D, together with the two maps T_mean and T_rand. T_mean is the barycentric projection of the GMM coupling. T_rand sends each point to a random Gaussian image, chosen among the maps between pairs of components according to the coupling and the GMMs. See [69] for details.

.. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.
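To make the notion of a "GMM plan" concrete, here is a minimal standard-library sketch of the component-level coupling for the parameters used in this example. It assumes the 1D closed form W2^2(N(m0, v0), N(m1, v1)) = (m0 - m1)^2 + (sqrt(v0) - sqrt(v1))^2 as the cost between components. Instead of a general optimal transport solver, it uses the north-west corner rule, which is only optimal here because the components are sorted by mean and the cost matrix passes the Monge check in the code. The helper names (`w2_squared_1d`, `is_monge`, `north_west_corner`) are hypothetical and not part of POT; this is not how the library computes the plan.

```python
import math

# Parameters copied from the example; variances are the 1x1 covariances.
m_s, v_s, w_s = [1.0, 2.0], [0.05, 0.06], [0.4, 0.6]
m_t, v_t, w_t = [3.0, 4.2, 5.0], [0.03, 0.07, 0.04], [0.4, 0.2, 0.4]


def w2_squared_1d(m0, v0, m1, v1):
    """Squared 2-Wasserstein distance between N(m0, v0) and N(m1, v1) in 1D."""
    return (m0 - m1) ** 2 + (math.sqrt(v0) - math.sqrt(v1)) ** 2


def is_monge(cost):
    """Check cost[i][j] + cost[i+1][j+1] <= cost[i][j+1] + cost[i+1][j]."""
    return all(
        cost[i][j] + cost[i + 1][j + 1] <= cost[i][j + 1] + cost[i + 1][j]
        for i in range(len(cost) - 1)
        for j in range(len(cost[0]) - 1)
    )


def north_west_corner(a, b, tol=1e-12):
    """Greedy monotone coupling of weights a and b (optimal for Monge costs)."""
    a, b = list(a), list(b)
    plan = [[0.0] * len(b) for _ in a]
    i = j = 0
    while i < len(a) and j < len(b):
        mass = min(a[i], b[j])
        plan[i][j] += mass
        a[i] -= mass
        b[j] -= mass
        if a[i] <= tol:  # source component i is exhausted
            i += 1
        if b[j] <= tol:  # target component j is exhausted
            j += 1
    return plan


# Pairwise cost between source component k and target component l.
cost = [
    [w2_squared_1d(m_s[k], v_s[k], m_t[l], v_t[l]) for l in range(len(m_t))]
    for k in range(len(m_s))
]
assert is_monge(cost), "north-west corner rule is not guaranteed optimal"

plan = north_west_corner(w_s, w_t)
support = [
    (k, l) for k in range(len(m_s)) for l in range(len(m_t)) if plan[k][l] > 1e-9
]
mw2_squared = sum(plan[k][l] * cost[k][l] for k, l in support)
print("support of the component plan:", support)
print("squared Mixture Wasserstein cost:", mw2_squared)
```

The support found this way, pairs (0, 0), (1, 1) and (1, 2), agrees with the "where plan > 0" output printed by the example. For mixtures whose cost matrix is not Monge, or in dimension greater than one, a proper discrete OT solver is needed.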
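Generate the GMMOT plan and plot it

    import numpy as np
    import matplotlib.pyplot as plt
    from ot.plot import plot1D_mat, rescale_for_imshow_plot
    from ot.gmm import gmm_ot_plan_density, gmm_pdf, gmm_ot_apply_map

    # Number of source/target components, dimension, and GMM parameters
    ks = 2
    kt = 3
    d = 1
    eps = 0.1
    m_s = np.array([[1], [2]])
    m_t = np.array([[3], [4.2], [5]])
    C_s = np.array([[[0.05]], [[0.06]]])
    C_t = np.array([[[0.03]], [[0.07]], [[0.04]]])
    w_s = np.array([0.4, 0.6])
    w_t = np.array([0.4, 0.2, 0.4])

    # Evaluation grids for the source (x) and target (y) axes
    n = 500
    a_x, b_x = 0, 3
    x = np.linspace(a_x, b_x, n)
    a_y, b_y = 2, 6
    y = np.linspace(a_y, b_y, n)

    # Density of the GMM OT plan on the grid, and the two marginal densities
    plan_density = gmm_ot_plan_density(
        x[:, None], y[:, None], m_s, m_t, C_s, C_t, w_s, w_t, plan=None, atol=2e-2
    )
    a = gmm_pdf(x[:, None], m_s, C_s, w_s)
    b = gmm_pdf(y[:, None], m_t, C_t, w_t)

    plt.figure(figsize=(8, 8))
    plot1D_mat(
        a,
        b,
        plan_density,
        title="GMM OT plan",
        plot_style="xy",
        a_label="Source distribution",
        b_label="Target distribution",
    )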
(<Axes: title={'center': 'Source distribution'}>, <Axes: title={'center': 'Target distribution'}>, <Axes: title={'center': 'GMM OT plan'}>)

Generate GMMOT maps and plot them over the plan
    plt.figure(figsize=(8, 8))
    ax_s, ax_t, ax_M = plot1D_mat(
        a,
        b,
        plan_density,
        plot_style="xy",
        title="GMM OT plan with T_mean and T_rand maps",
        a_label="Source distribution",
        b_label="Target distribution",
    )

    # Barycentric map T_mean, drawn as a curve over the plan
    T_mean = gmm_ot_apply_map(x[:, None], m_s, m_t, C_s, C_t, w_s, w_t, method="bary")[:, 0]
    x_rescaled, T_mean_rescaled = rescale_for_imshow_plot(x, T_mean, n, a_y=a_y, b_y=b_y)
    ax_M.plot(
        x_rescaled, T_mean_rescaled, label="T_mean", alpha=0.5, linewidth=5, color="aqua"
    )

    # Random map T_rand, drawn as a scatter plot (fixed seed for reproducibility)
    T_rand = gmm_ot_apply_map(
        x[:, None], m_s, m_t, C_s, C_t, w_s, w_t, method="rand", seed=0
    )[:, 0]
    x_rescaled, T_rand_rescaled = rescale_for_imshow_plot(x, T_rand, n, a_y=a_y, b_y=b_y)
    ax_M.scatter(
        x_rescaled, T_rand_rescaled, label="T_rand", alpha=0.5, s=20, color="orange"
    )

    ax_M.legend(loc="upper left", fontsize=13)
where plan > 0 (array([0, 1, 1]), array([0, 1, 2]))
<matplotlib.legend.Legend object at 0x7f5907a75c30>
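The "where plan > 0" output says the component plan is supported on the pairs (0, 0), (1, 1) and (1, 2). Together with the marginals w_s and w_t, this pins the plan down to [[0.4, 0, 0], [0, 0.2, 0.4]]. The following standard-library sketch shows how T_mean and T_rand could be evaluated in 1D from that plan, following one reading of [69]: each pair of components (k, l) has an affine Gaussian-to-Gaussian map, T_mean averages these maps with weights plan[k][l] * pdf_k(x), and T_rand draws one of them with the same weights. The function names are hypothetical, and whether `gmm_ot_apply_map` follows exactly these formulas is an assumption.

```python
import math
import random

# Parameters from the example; the plan is reconstructed from the printed
# support and the marginal weights (assumption stated in the text above).
m_s, v_s, w_s = [1.0, 2.0], [0.05, 0.06], [0.4, 0.6]
m_t, v_t = [3.0, 4.2, 5.0], [0.03, 0.07, 0.04]
plan = [[0.4, 0.0, 0.0], [0.0, 0.2, 0.4]]


def gauss_pdf(x, m, v):
    """Density of N(m, v) at x, with v the variance."""
    return math.exp(-((x - m) ** 2) / (2 * v)) / math.sqrt(2 * math.pi * v)


def component_map(x, k, l):
    """1D optimal map from source component k to target component l."""
    return m_t[l] + math.sqrt(v_t[l] / v_s[k]) * (x - m_s[k])


def pair_weights(x):
    """Unnormalised weights plan[k][l] * pdf_k(x) for the active pairs."""
    dens = [gauss_pdf(x, m_s[k], v_s[k]) for k in range(len(m_s))]
    pairs = [
        (k, l)
        for k in range(len(m_s))
        for l in range(len(m_t))
        if plan[k][l] > 0
    ]
    return pairs, [plan[k][l] * dens[k] for k, l in pairs]


def t_mean(x):
    """Barycentric map: weighted average of the component maps at x."""
    # Note: far from every source mean all weights underflow to zero.
    pairs, weights = pair_weights(x)
    total = sum(weights)  # equals the source GMM density at x
    return sum(w * component_map(x, k, l) for (k, l), w in zip(pairs, weights)) / total


def t_rand(x, rng):
    """Random map: pick one component map with probability proportional to its weight."""
    pairs, weights = pair_weights(x)
    k, l = rng.choices(pairs, weights=weights)[0]
    return component_map(x, k, l)


rng = random.Random(0)
for x in (1.0, 1.5, 2.0):
    print(x, t_mean(x), t_rand(x, rng))
```

Near x = 1 almost all the weight sits on pair (0, 0), so T_mean stays close to the first target mean. Near x = 2 the weight is split between pairs (1, 1) and (1, 2), so T_mean lands between the second and third target means while T_rand jumps to one or the other. This is consistent with the description of the two maps at the top of the example.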
Total running time of the script: (0 minutes 0.305 seconds)