A RetroSearch Logo

Home - News ( United States | United Kingdom | Italy | Germany ) - Football scores

Search Query:

Showing content from https://pythonot.github.io/auto_examples/../gen_modules/../_modules/ot/factored.html below:

Website Navigation


ot.factored — POT Python Optimal Transport 0.9.5 documentation

"""
Factored OT solvers (low rank, cost or OT plan)
"""

# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License

from .backend import get_backend
from .utils import dist, get_lowrank_lazytensor
from .lp import emd
from .bregman import sinkhorn

__all__ = ["factored_optimal_transport"]



[docs]
def factored_optimal_transport(
    Xa,
    Xb,
    a=None,
    b=None,
    reg=0.0,
    r=100,
    X0=None,
    stopThr=1e-7,
    numItermax=100,
    verbose=False,
    log=False,
    **kwargs,
):
    r"""Solves factored OT problem and return OT plans and intermediate distribution

    This function solve the following OT problem [40]_

    .. math::
        \mathop{\arg \min}_\mu \quad  W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b)

    where :

    - :math:`\mu_a` and :math:`\mu_b`  are empirical distributions.
    - :math:`\mu` is an empirical distribution with r samples

    And returns the two OT plans between :math:`\mu_a` and :math:`\mu`, and
    between :math:`\mu` and :math:`\mu_b`, together with the support of
    :math:`\mu`.

    .. note:: This function is backend-compatible and will work on arrays
        from all compatible backends. But the algorithm uses the C++ CPU backend
        which can lead to copy overhead on GPU arrays.

    The problem proposed in :ref:`[40] <references-factored>` is solved by an
    alternating scheme: at each iteration the two OT problems (source to
    intermediate, intermediate to target) are solved for the current support,
    then the support is updated from the two plans. The inner OT problems are
    solved with :any:`ot.lp.emd` when ``reg == 0`` and with
    :any:`ot.bregman.sinkhorn` when ``reg > 0``. The intermediate distribution
    always has uniform weights ``1 / r``.

    Parameters
    ----------
    Xa : (ns,d) array-like, float
        Source samples
    Xb : (nt,d) array-like, float
        Target samples
    a : (ns,) array-like, float, optional
        Source histogram (uniform weight if None)
    b : (nt,) array-like, float, optional
        Target histogram (uniform weight if None)
    reg : float, optional
        Regularization value forwarded to the Sinkhorn solver. When
        ``reg > 0`` the inner problems are solved with
        :any:`ot.bregman.sinkhorn`, otherwise with :any:`ot.lp.emd`.
    r : int, optional
        Number of samples of the intermediate distribution
    X0 : (r,d) array-like, float, optional
        Initial support of the intermediate distribution. If None, it is
        initialized randomly with the backend ``randn`` function.
    stopThr : float, optional
        Stop threshold on the norm of the variation of the support between
        two iterations (>0)
    numItermax : int, optional
        Max number of iterations (must be at least 1)
    verbose : bool, optional
        Currently not used by this function (nothing is printed and the
        value is not forwarded to the inner solvers).
    log : bool, optional
        record log if True
    **kwargs : dict
        Extra keyword arguments forwarded to the inner OT solver
        (:any:`ot.bregman.sinkhorn` if ``reg > 0`` else :any:`ot.lp.emd`).


    Returns
    -------
    Ga: array-like, shape (ns, r)
        Optimal transportation matrix between source and the intermediate
        distribution
    Gb: array-like, shape (r, nt)
        Optimal transportation matrix between the intermediate and target
        distribution
    X: array-like, shape (r, d)
        Support of the intermediate distribution
    log: dict, optional
        Returned only if input log is true. Contains ``"delta_iter"`` (norm
        of the support variation at each iteration, including the last
        one), ``"ua"``, ``"va"``, ``"ub"``, ``"vb"`` (the ``u``/``v``
        entries of the logs of the two last inner solves), ``"costa"``,
        ``"costb"`` (their costs) and ``"lazy_plan"`` (low rank lazy tensor
        built from ``Ga * r`` and ``Gb.T``).

    Raises
    ------
    ValueError
        If ``numItermax`` is lower than 1 (no plan could be computed).


    .. _references-factored:
    References
    ----------
    .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger,
        G., & Weed, J. (2019, April). Statistical optimal transport via factored
        couplings. In The 22nd International Conference on Artificial
        Intelligence and Statistics (pp. 2454-2465). PMLR.

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.bregman.sinkhorn : Entropic regularized OT
    ot.optim.cg : General regularized OT
    """

    # Without at least one iteration the plans are never computed and the
    # return statement would fail with an opaque UnboundLocalError.
    if numItermax < 1:
        raise ValueError(
            f"numItermax must be at least 1, got {numItermax!r}"
        )

    nx = get_backend(Xa, Xb)

    n_a = Xa.shape[0]
    n_b = Xb.shape[0]
    d = Xa.shape[1]

    # default to uniform weights on the source and target samples
    if a is None:
        a = nx.ones((n_a), type_as=Xa) / n_a
    if b is None:
        b = nx.ones((n_b), type_as=Xb) / n_b

    if X0 is None:
        X = nx.randn(r, d, type_as=Xa)
    else:
        X = X0

    # the intermediate distribution always has uniform weights
    w = nx.ones(r, type_as=Xa) / r

    def solve_ot(X1, X2, w1, w2):
        # Solve one OT problem between (X1, w1) and (X2, w2) and return the
        # plan with a log dict exposing at least "u", "v" and "cost".
        M = dist(X1, X2)
        if reg > 0:
            G, inner_log = sinkhorn(w1, w2, M, reg, log=True, **kwargs)
            # the sinkhorn log has no "cost" entry: add the linear cost
            inner_log["cost"] = nx.sum(G * M)
            return G, inner_log
        return emd(w1, w2, M, log=True, **kwargs)

    norm_delta = []

    # solve the barycenter
    for _ in range(numItermax):
        old_X = X

        # solve OT with template
        Ga, loga = solve_ot(Xa, X, a, w)
        Gb, logb = solve_ot(X, Xb, w, b)

        # Average of the two plan-weighted sums of samples; since w is
        # uniform (1 / r), multiplying by r normalizes each row by its mass.
        X = 0.5 * (nx.dot(Ga.T, Xa) + nx.dot(Gb, Xb)) * r

        delta = nx.norm(X - old_X)
        # record before testing convergence so that the trace also contains
        # the value that triggered the stop
        if log:
            norm_delta.append(delta)
        if delta < stopThr:
            break

    if log:
        log_dic = {
            "delta_iter": norm_delta,
            "ua": loga["u"],
            "va": loga["v"],
            "ub": logb["u"],
            "vb": logb["v"],
            "costa": loga["cost"],
            "costb": logb["cost"],
            "lazy_plan": get_lowrank_lazytensor(Ga * r, Gb.T, nx=nx),
        }
        return Ga, Gb, X, log_dic

    return Ga, Gb, X


RetroSearch is an open source project built by @garambo | Open a GitHub Issue

Search and Browse the WWW like it's 1997 | Search results from DuckDuckGo

HTML: 3.2 | Encoding: UTF-8 | Version: 0.7.4