Showing content from https://pythonot.github.io/_modules/ot/unbalanced/_sinkhorn.html below:

ot.unbalanced._sinkhorn — POT Python Optimal Transport 0.9.5 documentation

# -*- coding: utf-8 -*-
"""
Regularized Unbalanced OT solvers
"""

# Author: Hicham Janati <hicham.janati@inria.fr>
#         Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
#         Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#
# License: MIT License

from __future__ import division
import warnings

from ..backend import get_backend
from ..utils import list_to_array, get_parameter_pair



def sinkhorn_unbalanced(
    a,
    b,
    M,
    reg,
    reg_m,
    method="sinkhorn",
    reg_type="kl",
    c=None,
    warmstart=None,
    numItermax=1000,
    stopThr=1e-6,
    verbose=False,
    log=False,
    **kwargs,
):
    r"""
    Solve the unbalanced entropic regularization optimal transport problem
    and return the OT plan

    The function solves the following optimization problem:

    .. math::
        W = \arg \min_\gamma \ \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) +
        \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})

        s.t.
             \gamma \geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
    - :math:`\mathbf{c}` is a reference distribution for the regularization
    - KL is the Kullback-Leibler divergence

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-unbalanced>`


    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`
        If `a` is an empty list or array ([]),
        then `a` is set to uniform distribution.
    b : array-like (dim_b,)
        One or multiple unnormalized histograms of dimension `dim_b`.
        If `b` is an empty list or array ([]),
        then `b` is set to uniform distribution.
        If many, compute all the OT costs :math:`(\mathbf{a}, \mathbf{b}_i)_i`
    M : array-like (dim_a, dim_b)
        loss matrix
    reg : float
        Entropy regularization term > 0
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term.
        If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
        then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
        The entropic balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`.
        For semi-relaxed case, use either
        :math:`\mathrm{reg_{m}}=(float("inf"), scalar)` or
        :math:`\mathrm{reg_{m}}=(scalar, float("inf"))`.
        If :math:`\mathrm{reg_{m}}` is an array,
        it must have the same backend as input arrays `(a, b, M)`.
    method : str
        Method used by the solver: either 'sinkhorn', 'sinkhorn_stabilized',
        'sinkhorn_translation_invariant' or 'sinkhorn_reg_scaling'. See those
        functions for their specific parameters. 'sinkhorn_reg_scaling' is not
        implemented yet and falls back to 'sinkhorn' with a warning.
    reg_type : string, optional
        Regularizer term. Can take two values:
        + Negative entropy: 'entropy':
        :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`.
        This is equivalent (up to a constant) to :math:`\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)`.
        + Kullback-Leibler divergence: 'kl':
        :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`.
    c : array-like (dim_a, dim_b), optional (default=None)
        Reference measure for the regularization.
        If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
        If :math:`\texttt{reg_type}='entropy'`, then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`.
    warmstart: tuple of two arrays, of lengths dim_a and dim_b, optional
        Initialization of the dual potentials. If provided, pass the dual potentials
        themselves (that is, the logarithms of the `u`, `v` Sinkhorn scaling vectors).
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record `log` if `True`


    Returns
    -------
    if `b` is a single histogram (1D array):
        - gamma : (dim_a, dim_b) array-like
            Optimal transportation matrix for the given parameters
        - log : dict
            log dictionary returned only if `log` is `True`
    else (`b` holds several histograms as columns):
        - ot_cost : (n_hists,) array-like
            the OT cost between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
        - log : dict
            log dictionary returned only if `log` is `True`

    Examples
    --------

    >>> import ot
    >>> import numpy as np
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> np.round(ot.sinkhorn_unbalanced(a, b, M, 1, 1), 7)
    array([[0.3220536, 0.1184769],
           [0.1184769, 0.3220536]])

    .. _references-sinkhorn-unbalanced:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
        Transport, Advances in Neural Information Processing Systems
        (NIPS) 26, 2013

    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
        Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.

    .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
        Learning with a Wasserstein Loss,  Advances in Neural Information
        Processing Systems (NIPS) 2015

    .. [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022).
       Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe.
       In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.


    See Also
    --------
    ot.unbalanced.sinkhorn_knopp_unbalanced: Unbalanced Classic Sinkhorn :ref:`[10] <references-sinkhorn-unbalanced>`
    ot.unbalanced.sinkhorn_stabilized_unbalanced:
        Unbalanced Stabilized sinkhorn :ref:`[9, 10] <references-sinkhorn-unbalanced>`
    ot.unbalanced.sinkhorn_reg_scaling_unbalanced:
        Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced>`
    ot.unbalanced.sinkhorn_unbalanced_translation_invariant:
        Translation Invariant Unbalanced Sinkhorn :ref:`[73] <references-sinkhorn-unbalanced-translation-invariant>`

    """

    if method.lower() == "sinkhorn":
        return sinkhorn_knopp_unbalanced(
            a,
            b,
            M,
            reg,
            reg_m,
            reg_type,
            c,
            warmstart,
            numItermax=numItermax,
            stopThr=stopThr,
            verbose=verbose,
            log=log,
            **kwargs,
        )

    elif method.lower() == "sinkhorn_stabilized":
        return sinkhorn_stabilized_unbalanced(
            a,
            b,
            M,
            reg,
            reg_m,
            reg_type,
            c,
            warmstart,
            numItermax=numItermax,
            stopThr=stopThr,
            verbose=verbose,
            log=log,
            **kwargs,
        )

    elif method.lower() == "sinkhorn_translation_invariant":
        return sinkhorn_unbalanced_translation_invariant(
            a,
            b,
            M,
            reg,
            reg_m,
            reg_type,
            c,
            warmstart,
            numItermax=numItermax,
            stopThr=stopThr,
            verbose=verbose,
            log=log,
            **kwargs,
        )

    elif method.lower() in ["sinkhorn_reg_scaling"]:
        warnings.warn("Method not implemented yet. Using classic Sinkhorn-Knopp")
        return sinkhorn_knopp_unbalanced(
            a,
            b,
            M,
            reg,
            reg_m,
            reg_type,
            c,
            warmstart,
            numItermax=numItermax,
            stopThr=stopThr,
            verbose=verbose,
            log=log,
            **kwargs,
        )
    else:
        raise ValueError("Unknown method '%s'." % method)
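

"""
Illustrative note, not part of the library API. The `reg_m` argument documented
above may be a scalar, a length-1 object or a length-2 object, and an infinite
value turns the corresponding marginal relaxation into a hard constraint. The
sketch below shows one way such a value could be expanded into a pair and mapped
to the exponents `fi_1`, `fi_2` that the Sinkhorn-Knopp solver in this module
applies in its scaling updates. The helpers `as_pair` and `scaling_exponent` are
hypothetical names introduced for illustration only; the module itself relies on
`get_parameter_pair` from `..utils`.

```python
import math


def as_pair(reg_m):
    # Hypothetical helper (not the library's `get_parameter_pair`):
    # expand a scalar or length-1 sequence into (reg_m1, reg_m2).
    if isinstance(reg_m, (int, float)):
        return float(reg_m), float(reg_m)
    if len(reg_m) == 1:
        return float(reg_m[0]), float(reg_m[0])
    if len(reg_m) == 2:
        return float(reg_m[0]), float(reg_m[1])
    raise ValueError("reg_m must be a scalar or have length 1 or 2")


def scaling_exponent(reg_m_i, reg):
    # Exponent applied to the marginal ratio in each scaling update;
    # an infinite relaxation weight enforces the marginal exactly.
    return 1.0 if math.isinf(reg_m_i) else reg_m_i / (reg_m_i + reg)


reg = 1.0
relaxed = tuple(scaling_exponent(r, reg) for r in as_pair(1.0))
semi_relaxed = tuple(scaling_exponent(r, reg) for r in as_pair((float("inf"), 1.0)))
balanced = tuple(scaling_exponent(r, reg) for r in as_pair([float("inf")]))
```

With `reg = 1`, a scalar `reg_m = 1` gives the exponents (0.5, 0.5), the
semi-relaxed pair `(inf, 1)` gives (1.0, 0.5), and `reg_m = inf` gives
(1.0, 1.0), which corresponds to the balanced updates.
"""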




def sinkhorn_unbalanced2(
    a,
    b,
    M,
    reg,
    reg_m,
    method="sinkhorn",
    reg_type="kl",
    c=None,
    warmstart=None,
    returnCost="linear",
    numItermax=1000,
    stopThr=1e-6,
    verbose=False,
    log=False,
    **kwargs,
):
    r"""
    Solve the entropic regularization unbalanced optimal transport problem and
    return the cost

    The function solves the following optimization problem:

    .. math::
        \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) +
        \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})

        s.t.
             \gamma \geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
    - :math:`\mathbf{c}` is a reference distribution for the regularization
    - KL is the Kullback-Leibler divergence

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-unbalanced2>`


    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`
        If `a` is an empty list or array ([]),
        then `a` is set to uniform distribution.
    b : array-like (dim_b,)
        One or multiple unnormalized histograms of dimension `dim_b`.
        If `b` is an empty list or array ([]),
        then `b` is set to uniform distribution.
        If many, compute all the OT costs :math:`(\mathbf{a}, \mathbf{b}_i)_i`
    M : array-like (dim_a, dim_b)
        loss matrix
    reg : float
        Entropy regularization term > 0
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term.
        If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
        then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
        The entropic balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`.
        For semi-relaxed case, use either
        :math:`\mathrm{reg_{m}}=(float("inf"), scalar)` or
        :math:`\mathrm{reg_{m}}=(scalar, float("inf"))`.
        If :math:`\mathrm{reg_{m}}` is an array,
        it must have the same backend as input arrays `(a, b, M)`.
    method : str
        Method used by the solver: either 'sinkhorn', 'sinkhorn_stabilized',
        'sinkhorn_translation_invariant' or 'sinkhorn_reg_scaling'. See those
        functions for their specific parameters. 'sinkhorn_reg_scaling' is not
        implemented yet and falls back to 'sinkhorn' with a warning.
    reg_type : string, optional
        Regularizer term. Can take two values:
        + Negative entropy: 'entropy':
        :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`.
        This is equivalent (up to a constant) to :math:`\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)`.
        + Kullback-Leibler divergence: 'kl':
        :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`.
    c : array-like (dim_a, dim_b), optional (default=None)
        Reference measure for the regularization.
        If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
        If :math:`\texttt{reg_type}='entropy'`, then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`.
    warmstart: tuple of two arrays, of lengths dim_a and dim_b, optional
        Initialization of the dual potentials. If provided, pass the dual potentials
        themselves (that is, the logarithms of the `u`, `v` Sinkhorn scaling vectors).
    returnCost: string, optional (default = "linear")
        If `returnCost` = "linear", then return the linear part of the unbalanced OT loss.
        If `returnCost` = "total", then return the total unbalanced OT loss.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record `log` if `True`


    Returns
    -------
    ot_cost : (n_hists,) array-like
        the OT cost between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
    log : dict
        log dictionary returned only if `log` is `True`

    Examples
    --------

    >>> import ot
    >>> import numpy as np
    >>> a=[.5, .10]
    >>> b=[.5, .5]
    >>> M=[[0., 1.],[1., 0.]]
    >>> np.round(ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.), 8)
    0.19600125

    .. _references-sinkhorn-unbalanced2:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
        Transport, Advances in Neural Information Processing Systems
        (NIPS) 26, 2013

    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
        Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.

    .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
        Learning with a Wasserstein Loss,  Advances in Neural Information
        Processing Systems (NIPS) 2015

    .. [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022).
       Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe.
       In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.

    See Also
    --------
    ot.unbalanced.sinkhorn_knopp_unbalanced: Unbalanced Classic Sinkhorn :ref:`[10] <references-sinkhorn-unbalanced2>`
    ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn :ref:`[9, 10] <references-sinkhorn-unbalanced2>`
    ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced2>`
    ot.unbalanced.sinkhorn_unbalanced_translation_invariant: Translation Invariant Unbalanced Sinkhorn :ref:`[73] <references-sinkhorn-unbalanced2>`

    """
    M, a, b = list_to_array(M, a, b)

    if len(b.shape) < 2:
        if method.lower() == "sinkhorn":
            res = sinkhorn_knopp_unbalanced(
                a,
                b,
                M,
                reg,
                reg_m,
                reg_type,
                c,
                warmstart,
                numItermax=numItermax,
                stopThr=stopThr,
                verbose=verbose,
                log=True,
                **kwargs,
            )

        elif method.lower() == "sinkhorn_stabilized":
            res = sinkhorn_stabilized_unbalanced(
                a,
                b,
                M,
                reg,
                reg_m,
                reg_type,
                c,
                warmstart,
                numItermax=numItermax,
                stopThr=stopThr,
                verbose=verbose,
                log=True,
                **kwargs,
            )

        elif method.lower() == "sinkhorn_translation_invariant":
            res = sinkhorn_unbalanced_translation_invariant(
                a,
                b,
                M,
                reg,
                reg_m,
                reg_type,
                c,
                warmstart,
                numItermax=numItermax,
                stopThr=stopThr,
                verbose=verbose,
                log=True,
                **kwargs,
            )

        elif method.lower() in ["sinkhorn_reg_scaling"]:
            warnings.warn("Method not implemented yet. Using classic Sinkhorn-Knopp")
            res = sinkhorn_knopp_unbalanced(
                a,
                b,
                M,
                reg,
                reg_m,
                reg_type,
                c,
                warmstart,
                numItermax=numItermax,
                stopThr=stopThr,
                verbose=verbose,
                log=True,
                **kwargs,
            )
        else:
            raise ValueError("Unknown method %s." % method)

        if returnCost == "linear":
            cost = res[1]["cost"]
        elif returnCost == "total":
            cost = res[1]["total_cost"]
        else:
            raise ValueError("Unknown returnCost = {}".format(returnCost))

        if log:
            return cost, res[1]
        else:
            return cost

    else:
        if reg_type == "kl":
            warnings.warn("Reg_type not implemented yet. Use entropy.")

        if method.lower() == "sinkhorn":
            return sinkhorn_knopp_unbalanced(
                a,
                b,
                M,
                reg,
                reg_m,
                reg_type,
                c,
                warmstart,
                numItermax=numItermax,
                stopThr=stopThr,
                verbose=verbose,
                log=log,
                **kwargs,
            )

        elif method.lower() == "sinkhorn_stabilized":
            return sinkhorn_stabilized_unbalanced(
                a,
                b,
                M,
                reg,
                reg_m,
                reg_type,
                c,
                warmstart,
                numItermax=numItermax,
                stopThr=stopThr,
                verbose=verbose,
                log=log,
                **kwargs,
            )

        elif method.lower() == "sinkhorn_translation_invariant":
            return sinkhorn_unbalanced_translation_invariant(
                a,
                b,
                M,
                reg,
                reg_m,
                reg_type,
                c,
                warmstart,
                numItermax=numItermax,
                stopThr=stopThr,
                verbose=verbose,
                log=log,
                **kwargs,
            )

        elif method.lower() in ["sinkhorn_reg_scaling"]:
            warnings.warn("Method not implemented yet. Using classic Sinkhorn-Knopp")
            return sinkhorn_knopp_unbalanced(
                a,
                b,
                M,
                reg,
                reg_m,
                reg_type,
                c,
                warmstart,
                numItermax=numItermax,
                stopThr=stopThr,
                verbose=verbose,
                log=log,
                **kwargs,
            )
        else:
            raise ValueError("Unknown method %s." % method)
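

"""
Illustrative note, not part of the library API. The `returnCost` option above
selects between the linear part of the loss and the total loss. The sketch
below computes both quantities for a given plan. It assumes finite `reg_m`
weights and the generalized KL divergence
KL(x, y) = sum(x * log(x / y) - x + y), which is the convention suggested by
the `reg_type` description. `gen_kl` and `unbalanced_costs` are hypothetical
helpers; the backend `kl_div` used by the library may differ in details such
as numerical safeguards.

```python
import math


def gen_kl(x, y):
    # Generalized KL divergence between non-negative sequences (assumed
    # convention): sum_i x_i * log(x_i / y_i) - x_i + y_i, with 0 * log(0) = 0.
    total = 0.0
    for xi, yi in zip(x, y):
        if xi > 0:
            total += xi * math.log(xi / yi)
        total += yi - xi
    return total


def unbalanced_costs(plan, M, a, b, c, reg, reg_m1, reg_m2):
    # Hypothetical helper: return (linear, total) for a given plan,
    # assuming finite reg_m1 and reg_m2.
    rows, cols = len(a), len(b)
    linear = sum(plan[i][j] * M[i][j] for i in range(rows) for j in range(cols))
    flat_plan = [plan[i][j] for i in range(rows) for j in range(cols)]
    flat_c = [c[i][j] for i in range(rows) for j in range(cols)]
    row_sums = [sum(plan[i]) for i in range(rows)]
    col_sums = [sum(plan[i][j] for i in range(rows)) for j in range(cols)]
    total = linear + reg * gen_kl(flat_plan, flat_c)
    total += reg_m1 * gen_kl(row_sums, a)
    total += reg_m2 * gen_kl(col_sums, b)
    return linear, total


a = [0.5, 0.5]
b = [0.5, 0.5]
M = [[0.0, 1.0], [1.0, 0.0]]
c = [[ai * bj for bj in b] for ai in a]  # reference measure a b^T
# Evaluate the two costs at the (non-optimal) plan a b^T.
linear, total = unbalanced_costs(c, M, a, b, c, reg=1.0, reg_m1=1.0, reg_m2=1.0)
```

For the plan `a b^T` all three KL terms vanish on this input, so the linear
and the total cost coincide (both equal 0.5). For other plans the total cost
adds the non-negative penalty terms to the linear part.
"""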




def sinkhorn_knopp_unbalanced(
    a,
    b,
    M,
    reg,
    reg_m,
    reg_type="kl",
    c=None,
    warmstart=None,
    numItermax=1000,
    stopThr=1e-6,
    verbose=False,
    log=False,
    **kwargs,
):
    r"""
    Solve the entropic regularization unbalanced optimal transport problem and
    return the OT plan

    The function solves the following optimization problem:

    .. math::
        W = \arg \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) +
        \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})

        s.t.
             \gamma \geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
    - :math:`\mathbf{c}` is a reference distribution for the regularization
    - KL is the Kullback-Leibler divergence

    The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-knopp-unbalanced>`


    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`
        If `a` is an empty list or array ([]),
        then `a` is set to uniform distribution.
    b : array-like (dim_b,)
        One or multiple unnormalized histograms of dimension `dim_b`.
        If `b` is an empty list or array ([]),
        then `b` is set to uniform distribution.
        If many, compute all the OT costs :math:`(\mathbf{a}, \mathbf{b}_i)_i`
    M : array-like (dim_a, dim_b)
        loss matrix
    reg : float
        Entropy regularization term > 0
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term.
        If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
        then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
        The entropic balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`.
        For semi-relaxed case, use either
        :math:`\mathrm{reg_{m}}=(float("inf"), scalar)` or
        :math:`\mathrm{reg_{m}}=(scalar, float("inf"))`.
        If :math:`\mathrm{reg_{m}}` is an array,
        it must have the same backend as input arrays `(a, b, M)`.
    reg_type : string, optional
        Regularizer term. Can take two values:
        + Negative entropy: 'entropy':
        :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`.
        This is equivalent (up to a constant) to :math:`\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)`.
        + Kullback-Leibler divergence: 'kl':
        :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`.
    c : array-like (dim_a, dim_b), optional (default=None)
        Reference measure for the regularization.
        If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
        If :math:`\texttt{reg_type}='entropy'`, then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`.
    warmstart: tuple of two arrays, of lengths dim_a and dim_b, optional
        Initialization of the dual potentials. If provided, pass the dual potentials
        themselves (that is, the logarithms of the `u`, `v` Sinkhorn scaling vectors).
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record `log` if `True`


    Returns
    -------
    if `b` is a single histogram (1D array):
        - gamma : (dim_a, dim_b) array-like
            Optimal transportation matrix for the given parameters
        - log : dict
            log dictionary returned only if `log` is `True`
    else (`b` holds several histograms as columns):
        - ot_cost : (n_hists,) array-like
            the OT cost between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
        - log : dict
            log dictionary returned only if `log` is `True`

    Examples
    --------

    >>> import ot
    >>> import numpy as np
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.],[1., 0.]]
    >>> np.round(ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.), 7)
    array([[0.3220536, 0.1184769],
           [0.1184769, 0.3220536]])

    .. _references-sinkhorn-knopp-unbalanced:
    References
    ----------
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.

    .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
        Learning with a Wasserstein Loss,  Advances in Neural Information
        Processing Systems (NIPS) 2015

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT

    """
    M, a, b = list_to_array(M, a, b)
    nx = get_backend(M, a, b)

    dim_a, dim_b = M.shape

    if len(a) == 0:
        a = nx.ones(dim_a, type_as=M) / dim_a
    if len(b) == 0:
        b = nx.ones(dim_b, type_as=M) / dim_b

    if len(b.shape) > 1:
        n_hists = b.shape[1]
    else:
        n_hists = 0

    reg_m1, reg_m2 = get_parameter_pair(reg_m)

    if log:
        dict_log = {"err": []}

    # we assume that no distances are null except those on the diagonal of
    # the distance matrix
    if warmstart is None:
        if n_hists:
            u = nx.ones((dim_a, 1), type_as=M)
            v = nx.ones((dim_b, n_hists), type_as=M)
            a = a.reshape(dim_a, 1)
        else:
            u = nx.ones(dim_a, type_as=M)
            v = nx.ones(dim_b, type_as=M)
    else:
        u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])

    if reg_type.lower() == "entropy":
        warnings.warn(
            "If reg_type = entropy, then the matrix c is overwritten by the all-ones matrix."
        )
        c = nx.ones((dim_a, dim_b), type_as=M)

    if n_hists:
        K = nx.exp(-M / reg)
    else:
        c = a[:, None] * b[None, :] if c is None else c
        K = nx.exp(-M / reg) * c

    fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1
    fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1

    err = 1.0

    for i in range(numItermax):
        uprev = u
        vprev = v

        Kv = nx.dot(K, v)
        u = (a / Kv) ** fi_1
        Ktu = nx.dot(K.T, u)
        v = (b / Ktu) ** fi_2

        if (
            nx.any(Ktu == 0.0)
            or nx.any(nx.isnan(u))
            or nx.any(nx.isnan(v))
            or nx.any(nx.isinf(u))
            or nx.any(nx.isinf(v))
        ):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn("Numerical errors at iteration %s" % i)
            u = uprev
            v = vprev
            break

        err_u = nx.max(nx.abs(u - uprev)) / max(
            nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0
        )
        err_v = nx.max(nx.abs(v - vprev)) / max(
            nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.0
        )
        err = 0.5 * (err_u + err_v)
        if log:
            dict_log["err"].append(err)
        # print progress whenever `verbose` is set, independently of `log`
        if verbose:
            if i % 50 == 0:
                print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
            print("{:5d}|{:8e}|".format(i, err))
        if err < stopThr:
            break

    if log:
        dict_log["logu"] = nx.log(u + 1e-300)
        dict_log["logv"] = nx.log(v + 1e-300)

    if n_hists:  # return only loss
        res = nx.einsum("ik,ij,jk,ij->k", u, K, v, M)
        if log:
            return res, dict_log
        else:
            return res

    else:  # return OT matrix
        plan = u[:, None] * K * v[None, :]

        if log:
            linear_cost = nx.sum(plan * M)
            dict_log["cost"] = linear_cost

            total_cost = linear_cost + reg * nx.kl_div(plan, c)
            if reg_m1 != float("inf"):
                total_cost = total_cost + reg_m1 * nx.kl_div(nx.sum(plan, 1), a)
            if reg_m2 != float("inf"):
                total_cost = total_cost + reg_m2 * nx.kl_div(nx.sum(plan, 0), b)
            dict_log["total_cost"] = total_cost

            return plan, dict_log
        else:
            return plan
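

"""
Illustrative note, not part of the library API. The scaling loop of
`sinkhorn_knopp_unbalanced` above can be sketched in pure Python for a single
pair of 1D histograms. The sketch assumes the reference measure c = a b^T
(the `reg_type='kl'` default), one finite `reg_m` shared by both marginals,
and a simplified absolute stopping test in place of the relative error used
above. The names `sinkhorn_unbalanced_sketch`, `num_iter` and `tol` are
hypothetical.

```python
import math


def sinkhorn_unbalanced_sketch(a, b, M, reg, reg_m, num_iter=1000, tol=1e-9):
    # Pure-Python sketch for 1D histograms with c = a b^T and finite reg_m.
    n, m = len(a), len(b)
    # Gibbs kernel weighted by the reference measure: K = exp(-M / reg) * c
    K = [[math.exp(-M[i][j] / reg) * a[i] * b[j] for j in range(m)] for i in range(n)]
    fi = reg_m / (reg_m + reg)  # exponent of the relaxed scaling update
    u, v = [1.0] * n, [1.0] * m
    for _ in range(num_iter):
        u_prev, v_prev = u, v
        Kv = [sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        u = [(a[i] / Kv[i]) ** fi for i in range(n)]
        Ktu = [sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
        v = [(b[j] / Ktu[j]) ** fi for j in range(m)]
        # Simplified criterion: largest absolute change of the scalings.
        err = max(
            max(abs(x - y) for x, y in zip(u, u_prev)),
            max(abs(x - y) for x, y in zip(v, v_prev)),
        )
        if err < tol:
            break
    # Transport plan: diag(u) K diag(v)
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


plan = sinkhorn_unbalanced_sketch(
    [0.5, 0.5], [0.5, 0.5], [[0.0, 1.0], [1.0, 0.0]], reg=1.0, reg_m=1.0
)
```

On this input the sketch should agree with the plan shown in the docstring
example above, roughly [[0.3220536, 0.1184769], [0.1184769, 0.3220536]].
Because the marginals are relaxed, the total mass of the plan is below 1.
"""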




def sinkhorn_stabilized_unbalanced(
    a,
    b,
    M,
    reg,
    reg_m,
    reg_type="kl",
    c=None,
    warmstart=None,
    tau=1e5,
    numItermax=1000,
    stopThr=1e-6,
    verbose=False,
    log=False,
    **kwargs,
):
    r"""
    Solve the entropic regularization unbalanced optimal transport
    problem and return the OT plan

    The function solves the following optimization problem using log-domain
    stabilization as proposed in :ref:`[10] <references-sinkhorn-stabilized-unbalanced>`:

    .. math::
        W = \arg \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) +
        \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})

        s.t.
             \gamma \geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
    - :math:`\mathbf{c}` is a reference distribution for the regularization
    - KL is the Kullback-Leibler divergence

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-stabilized-unbalanced>`


    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`
        If `a` is an empty list or array ([]),
        then `a` is set to uniform distribution.
    b : array-like (dim_b,)
        One or multiple unnormalized histograms of dimension `dim_b`.
        If `b` is an empty list or array ([]),
        then `b` is set to uniform distribution.
        If many, compute all the OT costs :math:`(\mathbf{a}, \mathbf{b}_i)_i`
    M : array-like (dim_a, dim_b)
        loss matrix
    reg : float
        Entropy regularization term > 0
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term.
        If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
        then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
        The entropic balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`.
        For semi-relaxed case, use either
        :math:`\mathrm{reg_{m}}=(float("inf"), scalar)` or
        :math:`\mathrm{reg_{m}}=(scalar, float("inf"))`.
        If :math:`\mathrm{reg_{m}}` is an array,
        it must have the same backend as input arrays `(a, b, M)`.
    reg_type : string, optional
        Regularizer term. Can take two values:
        + Negative entropy: 'entropy':
        :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`.
        This is equivalent (up to a constant) to :math:`\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)`.
        + Kullback-Leibler divergence: 'kl':
        :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`.
    c : array-like (dim_a, dim_b), optional (default=None)
        Reference measure for the regularization.
        If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
        If :math:`\texttt{reg_type}='entropy'`, then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`.
    warmstart: tuple of two arrays, of lengths dim_a and dim_b, optional
        Initialization of the dual potentials. If provided, pass the dual potentials
        themselves (that is, the logarithms of the `u`, `v` Sinkhorn scaling vectors).
    tau : float
        threshold for max value in `u` or `v` for log scaling
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record `log` if `True`


    Returns
    -------
    if `b` is a single histogram (1D array):
        - gamma : (dim_a, dim_b) array-like
            Optimal transportation matrix for the given parameters
        - log : dict
            log dictionary returned only if `log` is `True`
    else (`b` holds several histograms as columns):
        - ot_cost : (n_hists,) array-like
            the OT cost between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
        - log : dict
            log dictionary returned only if `log` is `True`

    Examples
    --------

    >>> import ot
    >>> import numpy as np
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.],[1., 0.]]
    >>> np.round(ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.), 7)
    array([[0.3220536, 0.1184769],
           [0.1184769, 0.3220536]])

    .. _references-sinkhorn-stabilized-unbalanced:
    References
    ----------
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.

    .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
        Learning with a Wasserstein Loss,  Advances in Neural Information
        Processing Systems (NIPS) 2015

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT

    """
    a, b, M = list_to_array(a, b, M)
    nx = get_backend(M, a, b)

    dim_a, dim_b = M.shape

    if len(a) == 0:
        a = nx.ones(dim_a, type_as=M) / dim_a
    if len(b) == 0:
        b = nx.ones(dim_b, type_as=M) / dim_b

    if len(b.shape) > 1:
        n_hists = b.shape[1]
    else:
        n_hists = 0

    reg_m1, reg_m2 = get_parameter_pair(reg_m)

    if log:
        dict_log = {"err": []}

    # we assume that no distances are null except those of the diagonal of
    # distances
    if warmstart is None:
        if n_hists:
            u = nx.ones((dim_a, n_hists), type_as=M)
            v = nx.ones((dim_b, n_hists), type_as=M)
            a = a.reshape(dim_a, 1)
        else:
            u = nx.ones(dim_a, type_as=M)
            v = nx.ones(dim_b, type_as=M)
    else:
        u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])

    if reg_type == "entropy":
        warnings.warn(
            "If reg_type = entropy, then the matrix c is overwritten by the one matrix."
        )
        c = nx.ones((dim_a, dim_b), type_as=M)

    if n_hists:
        M0 = M
    else:
        c = a[:, None] * b[None, :] if c is None else c
        M0 = M - reg * nx.log(c)
    K = nx.exp(-M0 / reg)

    fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1
    fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1

    cpt = 0
    err = 1.0
    alpha = nx.zeros(dim_a, type_as=M)
    beta = nx.zeros(dim_b, type_as=M)
    ones_a = nx.ones(dim_a, type_as=M)
    ones_b = nx.ones(dim_b, type_as=M)

    while err > stopThr and cpt < numItermax:
        uprev = u
        vprev = v

        Kv = nx.dot(K, v)
        f_alpha = nx.exp(-alpha / (reg + reg_m1)) if reg_m1 != float("inf") else ones_a
        f_beta = nx.exp(-beta / (reg + reg_m2)) if reg_m2 != float("inf") else ones_b

        if n_hists:
            f_alpha = f_alpha[:, None]
            f_beta = f_beta[:, None]
        u = ((a / (Kv + 1e-16)) ** fi_1) * f_alpha
        Ktu = nx.dot(K.T, u)
        v = ((b / (Ktu + 1e-16)) ** fi_2) * f_beta
        absorbing = False
        if nx.any(u > tau) or nx.any(v > tau):
            absorbing = True
            if n_hists:
                alpha = alpha + reg * nx.log(nx.max(u, 1))
                beta = beta + reg * nx.log(nx.max(v, 1))
            else:
                alpha = alpha + reg * nx.log(nx.max(u))
                beta = beta + reg * nx.log(nx.max(v))
            K = nx.exp((alpha[:, None] + beta[None, :] - M0) / reg)
            v = nx.ones(v.shape, type_as=v)
        Kv = nx.dot(K, v)

        if (
            nx.any(Ktu == 0.0)
            or nx.any(nx.isnan(u))
            or nx.any(nx.isnan(v))
            or nx.any(nx.isinf(u))
            or nx.any(nx.isinf(v))
        ):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn("Numerical errors at iteration %s" % cpt)
            u = uprev
            v = vprev
            break
        if (cpt % 10 == 0 and not absorbing) or cpt == 0:
            # we can speed up the process by checking the error only every
            # 10th iteration
            err = nx.max(nx.abs(u - uprev)) / max(
                nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0
            )
            if log:
                dict_log["err"].append(err)
            if verbose:
                if cpt % 200 == 0:
                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
                print("{:5d}|{:8e}|".format(cpt, err))
        cpt = cpt + 1

    if err > stopThr:
        warnings.warn(
            "Stabilized Unbalanced Sinkhorn did not converge. "
            + "Try a larger entropy `reg` or a lower mass `reg_m`. "
            + "Or a larger absorption threshold `tau`."
        )
    if n_hists:
        logu = alpha[:, None] / reg + nx.log(u)
        logv = beta[:, None] / reg + nx.log(v)
    else:
        logu = alpha / reg + nx.log(u)
        logv = beta / reg + nx.log(v)
    if log:
        dict_log["logu"] = logu
        dict_log["logv"] = logv
    if n_hists:  # return only loss
        res = nx.logsumexp(
            nx.log(M + 1e-100)[:, :, None]
            + logu[:, None, :]
            + logv[None, :, :]
            - M0[:, :, None] / reg,
            axis=(0, 1),
        )
        res = nx.exp(res)
        if log:
            return res, dict_log
        else:
            return res

    else:  # return OT matrix
        plan = nx.exp(logu[:, None] + logv[None, :] - M0 / reg)
        if log:
            linear_cost = nx.sum(plan * M)
            dict_log["cost"] = linear_cost

            total_cost = linear_cost + reg * nx.kl_div(plan, c)
            if reg_m1 != float("inf"):
                total_cost = total_cost + reg_m1 * nx.kl_div(nx.sum(plan, 1), a)
            if reg_m2 != float("inf"):
                total_cost = total_cost + reg_m2 * nx.kl_div(nx.sum(plan, 0), b)
            dict_log["total_cost"] = total_cost

            return plan, dict_log
        else:
            return plan
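

# Illustrative sketch, not part of the POT API: the helper name, the scalar
# `reg_m`, and the pure-Python list arithmetic are assumptions made for
# readability. It runs the plain (non-stabilized) scaling update
# u = (a / Kv) ** fi, v = (b / K^T u) ** fi for 1-D histograms, without the
# absorption step or the numerical safeguards of the solver above.
# With a = b = [.5, .5], M = [[0, 1], [1, 0]] and reg = reg_m = 1 it should
# agree with the plan shown in the docstring example to about 7 decimals.
def _sketch_unbalanced_scaling(a, b, M, reg, reg_m, n_iter=200):
    import math

    dim_a, dim_b = len(a), len(b)
    # Kernel with the default 'kl' reference c = a b^T folded in:
    # K = c * exp(-M / reg), equivalent to exp(-M0 / reg) above.
    K = [
        [a[i] * b[j] * math.exp(-M[i][j] / reg) for j in range(dim_b)]
        for i in range(dim_a)
    ]
    # Exponent of the relaxed update; 1 recovers balanced Sinkhorn.
    fi = 1.0 if math.isinf(reg_m) else reg_m / (reg_m + reg)
    u, v = [1.0] * dim_a, [1.0] * dim_b
    for _ in range(n_iter):
        Kv = [sum(K[i][j] * v[j] for j in range(dim_b)) for i in range(dim_a)]
        u = [(a[i] / Kv[i]) ** fi for i in range(dim_a)]
        Ktu = [sum(K[i][j] * u[i] for i in range(dim_a)) for j in range(dim_b)]
        v = [(b[j] / Ktu[j]) ** fi for j in range(dim_b)]
    # Transport plan gamma_ij = u_i * K_ij * v_j.
    return [[u[i] * K[i][j] * v[j] for j in range(dim_b)] for i in range(dim_a)]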




def sinkhorn_unbalanced_translation_invariant(
    a,
    b,
    M,
    reg,
    reg_m,
    reg_type="kl",
    c=None,
    warmstart=None,
    numItermax=1000,
    stopThr=1e-6,
    verbose=False,
    log=False,
    **kwargs,
):
    r"""
    Solve the entropic regularization unbalanced optimal transport problem and
    return the OT plan

    The function solves the following optimization problem:

    .. math::
        W = \arg \min_\gamma \ \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot \mathrm{KL}(\gamma, \mathbf{c}) +
        \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})

        s.t.
             \gamma \geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term, the KL divergence
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
    - KL is the Kullback-Leibler divergence

    The algorithm used for solving the problem is the translation invariant Sinkhorn algorithm as proposed in :ref:`[73] <references-sinkhorn-unbalanced-translation-invariant>`


    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`
    b : array-like (dim_b,) or array-like (dim_b, n_hists)
        One or multiple unnormalized histograms of dimension `dim_b`
        If many, compute all the OT distances (a, b_i)
    M : array-like (dim_a, dim_b)
        loss matrix
    reg : float
        Entropy regularization term > 0
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term.
        If reg_m is a scalar or an indexable object of length 1,
        then the same reg_m is applied to both marginal relaxations.
        The entropic balanced OT can be recovered using `reg_m=float("inf")`.
        For semi-relaxed case, use either
        `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`.
        If reg_m is an array, it must have the same backend as input arrays (a, b, M).
    reg_type : string, optional
        Regularizer term. Can take two values:
        'entropy' (negative entropy)
        :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or
        'kl' (Kullback-Leibler)
        :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`.
    c : array-like (dim_a, dim_b), optional (default=None)
        Reference measure for the regularization.
        If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
        If :math:`\texttt{reg_type}='entropy'`, then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`.
    warmstart: tuple of two arrays, of shapes (dim_a,) and (dim_b,), optional
        Initialization of the dual potentials. If provided, pass the dual potentials
        themselves (that is, the logarithm of the `u`, `v` Sinkhorn scaling vectors).
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True


    Returns
    -------
    if `b` is a single histogram of shape (dim_b,):
        - gamma : (dim_a, dim_b) array-like
            Optimal transportation matrix for the given parameters
        - log : dict
            log dictionary returned only if `log` is `True`
    else (`b` has shape (dim_b, n_hists)):
        - ot_distance : (n_hists,) array-like
            the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
        - log : dict
            log dictionary returned only if `log` is `True`

    Examples
    --------

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.],[1., 0.]]
    >>> ot.unbalanced.sinkhorn_unbalanced_translation_invariant(a, b, M, 1., 1.)
    array([[0.32205357, 0.11847689],
           [0.11847689, 0.32205357]])

    .. _references-sinkhorn-unbalanced-translation-invariant:
    References
    ----------
    .. [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022).
       Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe.
       In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
    """

    M, a, b = list_to_array(M, a, b)
    nx = get_backend(M, a, b)

    dim_a, dim_b = M.shape

    if len(a) == 0:
        a = nx.ones(dim_a, type_as=M) / dim_a
    if len(b) == 0:
        b = nx.ones(dim_b, type_as=M) / dim_b

    if len(b.shape) > 1:
        n_hists = b.shape[1]
    else:
        n_hists = 0

    reg_m1, reg_m2 = get_parameter_pair(reg_m)

    if log:
        dict_log = {"err": []}

    # we assume that no distances are null except those of the diagonal of
    # distances
    if warmstart is None:
        if n_hists:
            u = nx.ones((dim_a, 1), type_as=M)
            v = nx.ones((dim_b, n_hists), type_as=M)
            a = a.reshape(dim_a, 1)
        else:
            u = nx.ones(dim_a, type_as=M)
            v = nx.ones(dim_b, type_as=M)
    else:
        u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])

    u_, v_ = u, v

    if reg_type == "entropy":
        warnings.warn(
            "If reg_type = entropy, then the matrix c is overwritten by the one matrix."
        )
        c = nx.ones((dim_a, dim_b), type_as=M)

    if n_hists:
        M0 = M
    else:
        c = a[:, None] * b[None, :] if c is None else c
        M0 = M - reg * nx.log(c)
    K = nx.exp(-M0 / reg)

    fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1
    fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1

    k1 = (
        reg * reg_m1 / ((reg + reg_m1) * (reg_m1 + reg_m2))
        if reg_m1 != float("inf")
        else 0
    )
    k2 = (
        reg * reg_m2 / ((reg + reg_m2) * (reg_m1 + reg_m2))
        if reg_m2 != float("inf")
        else 0
    )

    k_rho1 = k1 * reg_m1 / reg if reg_m1 != float("inf") else 0
    k_rho2 = k2 * reg_m2 / reg if reg_m2 != float("inf") else 0

    if reg_m1 == float("inf") and reg_m2 == float("inf"):
        xi1, xi2 = 0, 0
        fi_12 = 1
    elif reg_m1 == float("inf"):
        xi1 = 0
        xi2 = reg / reg_m2
        fi_12 = reg_m2
    elif reg_m2 == float("inf"):
        xi1 = reg / reg_m1
        xi2 = 0
        fi_12 = reg_m1
    else:
        xi1 = (reg_m2 * reg) / (reg_m1 * (reg + reg_m1 + reg_m2))
        xi2 = (reg_m1 * reg) / (reg_m2 * (reg + reg_m1 + reg_m2))
        fi_12 = reg_m1 * reg_m2 / (reg_m1 + reg_m2)

    xi_rho1 = xi1 * reg_m1 / reg if reg_m1 != float("inf") else 0
    xi_rho2 = xi2 * reg_m2 / reg if reg_m2 != float("inf") else 0

    reg_ratio1 = reg / reg_m1 if reg_m1 != float("inf") else 0
    reg_ratio2 = reg / reg_m2 if reg_m2 != float("inf") else 0

    err = 1.0

    for i in range(numItermax):
        uprev = u
        vprev = v

        Kv = nx.dot(K, v_)
        u_hat = (a / Kv) ** fi_1 * nx.sum(b * v_**reg_ratio2) ** k_rho2
        u_ = u_hat * nx.sum(a * u_hat ** (-reg_ratio1)) ** (-xi_rho1)

        Ktu = nx.dot(K.T, u_)
        v_hat = (b / Ktu) ** fi_2 * nx.sum(a * u_ ** (-reg_ratio1)) ** k_rho1
        v_ = v_hat * nx.sum(b * v_hat ** (-reg_ratio2)) ** (-xi_rho2)

        if (
            nx.any(Ktu == 0.0)
            or nx.any(nx.isnan(u_))
            or nx.any(nx.isnan(v_))
            or nx.any(nx.isinf(u_))
            or nx.any(nx.isinf(v_))
        ):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn("Numerical errors at iteration %s" % i)
            u = uprev
            v = vprev
            break

        t = (nx.sum(a * u_ ** (-reg_ratio1)) / nx.sum(b * v_ ** (-reg_ratio2))) ** (
            fi_12 / reg
        )

        u = u_ * t
        v = v_ / t

        err_u = nx.max(nx.abs(u - uprev)) / max(
            nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0
        )
        err_v = nx.max(nx.abs(v - vprev)) / max(
            nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.0
        )
        err = 0.5 * (err_u + err_v)
        if log:
            dict_log["err"].append(err)
        if verbose:
            if i % 50 == 0:
                print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
            print("{:5d}|{:8e}|".format(i, err))

        if err < stopThr:
            break

    if log:
        dict_log["logu"] = nx.log(u + 1e-300)
        dict_log["logv"] = nx.log(v + 1e-300)

    if n_hists:  # return only loss
        res = nx.einsum("ik,ij,jk,ij->k", u, K, v, M)
        if log:
            return res, dict_log
        else:
            return res

    else:  # return OT matrix
        plan = u[:, None] * K * v[None, :]

        if log:
            linear_cost = nx.sum(plan * M)
            dict_log["cost"] = linear_cost

            total_cost = linear_cost + reg * nx.kl_div(plan, c)
            if reg_m1 != float("inf"):
                total_cost = total_cost + reg_m1 * nx.kl_div(nx.sum(plan, 1), a)
            if reg_m2 != float("inf"):
                total_cost = total_cost + reg_m2 * nx.kl_div(nx.sum(plan, 0), b)
            dict_log["total_cost"] = total_cost

            return plan, dict_log
        else:
            return plan
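

# Illustrative sketch, not part of the POT API (the helper name is
# hypothetical, and it assumes 1-D scaling vectors given as plain lists).
# It measures how much the plan u_i * K_ij * v_j changes when u is multiplied
# and v divided by the same positive scalar t, as in the `u = u_ * t`,
# `v = v_ / t` step above. The change is zero up to floating-point rounding,
# so that rescaling leaves the returned plan unchanged.
def _sketch_rescaling_gap(u, K, v, t):
    def plan(u_, v_):
        return [
            [u_[i] * K[i][j] * v_[j] for j in range(len(v_))]
            for i in range(len(u_))
        ]

    before = plan(u, v)
    after = plan([ui * t for ui in u], [vj / t for vj in v])
    # Largest entrywise difference between the two plans.
    return max(
        abs(before[i][j] - after[i][j])
        for i in range(len(u))
        for j in range(len(v))
    )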




def barycenter_unbalanced_stabilized(
    A,
    M,
    reg,
    reg_m,
    weights=None,
    tau=1e3,
    numItermax=1000,
    stopThr=1e-6,
    verbose=False,
    log=False,
):
    r"""Compute the entropic unbalanced Wasserstein barycenter of :math:`\mathbf{A}` with stabilization.

    The function solves the following optimization problem:

    .. math::
       \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`)
    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
    - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
    - reg_m is the marginal relaxation hyperparameter

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10] <references-barycenter-unbalanced-stabilized>`

    Parameters
    ----------
    A : array-like (dim, n_hists)
        `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
    M : array-like (dim, dim)
        ground metric matrix for OT.
    reg : float
        Entropy regularization term > 0
    reg_m : float
        Marginal relaxation term > 0
    tau : float
        Stabilization threshold for log domain absorption.
    weights : array-like (n_hists,) optional
        Weight of each distribution (barycentric coordinates)
        If None, uniform weights are used.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record `log` if `True`


    Returns
    -------
    a : (dim,) array-like
        Unbalanced Wasserstein barycenter
    log : dict
        log dictionary returned only if `log` is `True`


    .. _references-barycenter-unbalanced-stabilized:
    References
    ----------
    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré,
        G. (2015). Iterative Bregman projections for regularized transportation
        problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.


    """
    A, M = list_to_array(A, M)
    nx = get_backend(A, M)

    dim, n_hists = A.shape
    if weights is None:
        weights = nx.ones(n_hists, type_as=A) / n_hists
    else:
        assert len(weights) == A.shape[1]

    if log:
        log = {"err": []}

    fi = reg_m / (reg_m + reg)

    u = nx.ones((dim, n_hists), type_as=A) / dim
    v = nx.ones((dim, n_hists), type_as=A) / dim

    K = nx.exp(-M / reg)

    cpt = 0
    err = 1.0
    alpha = nx.zeros(dim, type_as=A)
    beta = nx.zeros(dim, type_as=A)
    q = nx.ones(dim, type_as=A) / dim
    for i in range(numItermax):
        qprev = nx.copy(q)
        Kv = nx.dot(K, v)
        f_alpha = nx.exp(-alpha / (reg + reg_m))
        f_beta = nx.exp(-beta / (reg + reg_m))
        f_alpha = f_alpha[:, None]
        f_beta = f_beta[:, None]
        u = ((A / (Kv + 1e-16)) ** fi) * f_alpha
        Ktu = nx.dot(K.T, u)
        q = (Ktu ** (1 - fi)) * f_beta
        q = nx.dot(q, weights) ** (1 / (1 - fi))
        Q = q[:, None]
        v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta
        absorbing = False
        if nx.any(u > tau) or nx.any(v > tau):
            absorbing = True
            alpha = alpha + reg * nx.log(nx.max(u, 1))
            beta = beta + reg * nx.log(nx.max(v, 1))
            K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg)
            v = nx.ones(v.shape, type_as=v)
        Kv = nx.dot(K, v)
        if (
            nx.any(Ktu == 0.0)
            or nx.any(nx.isnan(u))
            or nx.any(nx.isnan(v))
            or nx.any(nx.isinf(u))
            or nx.any(nx.isinf(v))
        ):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn("Numerical errors at iteration %s" % i)
            q = qprev
            break
        if (i % 10 == 0 and not absorbing) or i == 0:
            # we can speed up the process by checking the error only every
            # 10th iteration
            err = nx.max(nx.abs(q - qprev)) / max(
                nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.0
            )
            if log:
                log["err"].append(err)
            if verbose:
                if i % 50 == 0:
                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
                print("{:5d}|{:8e}|".format(i, err))
            if err < stopThr:
                break

    if err > stopThr:
        warnings.warn(
            "Stabilized Unbalanced Sinkhorn did not converge. "
            + "Try a larger entropy `reg` or a lower mass `reg_m`. "
            + "Or a larger absorption threshold `tau`."
        )
    if log:
        log["niter"] = i
        log["logu"] = nx.log(u + 1e-300)
        log["logv"] = nx.log(v + 1e-300)
        return q, log
    else:
        return q




def barycenter_unbalanced_sinkhorn(
    A,
    M,
    reg,
    reg_m,
    weights=None,
    numItermax=1000,
    stopThr=1e-6,
    verbose=False,
    log=False,
):
    r"""Compute the entropic unbalanced Wasserstein barycenter of :math:`\mathbf{A}`.

    The function solves the following optimization problem:

    .. math::
       \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`)
    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
    - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
    - reg_m is the marginal relaxation hyperparameter

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10] <references-barycenter-unbalanced-sinkhorn>`

    Parameters
    ----------
    A : array-like (dim, n_hists)
        `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
    M : array-like (dim, dim)
        ground metric matrix for OT.
    reg : float
        Entropy regularization term > 0
    reg_m: float
        Marginal relaxation term > 0
    weights : array-like (n_hists,) optional
        Weight of each distribution (barycentric coordinates)
        If None, uniform weights are used.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record `log` if `True`


    Returns
    -------
    a : (dim,) array-like
        Unbalanced Wasserstein barycenter
    log : dict
        log dictionary returned only if `log` is `True`


    .. _references-barycenter-unbalanced-sinkhorn:
    References
    ----------
    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
        (2015). Iterative Bregman projections for regularized transportation
        problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.


    """
    A, M = list_to_array(A, M)
    nx = get_backend(A, M)

    dim, n_hists = A.shape
    if weights is None:
        weights = nx.ones(n_hists, type_as=A) / n_hists
    else:
        assert len(weights) == A.shape[1]

    if log:
        log = {"err": []}

    K = nx.exp(-M / reg)

    fi = reg_m / (reg_m + reg)

    v = nx.ones((dim, n_hists), type_as=A)
    u = nx.ones((dim, 1), type_as=A)
    q = nx.ones(dim, type_as=A)
    err = 1.0

    for i in range(numItermax):
        uprev = nx.copy(u)
        vprev = nx.copy(v)
        qprev = nx.copy(q)

        Kv = nx.dot(K, v)
        u = (A / Kv) ** fi
        Ktu = nx.dot(K.T, u)
        q = nx.dot(Ktu ** (1 - fi), weights)
        q = q ** (1 / (1 - fi))
        Q = q[:, None]
        v = (Q / Ktu) ** fi

        if (
            nx.any(Ktu == 0.0)
            or nx.any(nx.isnan(u))
            or nx.any(nx.isnan(v))
            or nx.any(nx.isinf(u))
            or nx.any(nx.isinf(v))
        ):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn("Numerical errors at iteration %s" % i)
            u = uprev
            v = vprev
            q = qprev
            break
        # compute change in barycenter
        err = nx.max(nx.abs(q - qprev)) / max(
            nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.0
        )
        if log:
            log["err"].append(err)
        # if barycenter did not change + at least 10 iterations - stop
        if err < stopThr and i > 10:
            break

        if verbose:
            if i % 10 == 0:
                print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
            print("{:5d}|{:8e}|".format(i, err))

    if log:
        log["niter"] = i
        log["logu"] = nx.log(u + 1e-300)
        log["logv"] = nx.log(v + 1e-300)
        return q, log
    else:
        return q
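

# Illustrative sketch, not part of the POT API: the helper name and the
# nested-list layout (A[i][k] is bin i of histogram k) are assumptions. It
# repeats the update order of the loop above in pure Python for a finite
# `reg_m`, without the numerical-error checks or the stopping test:
# u_k = (a_k / K v_k) ** fi, then the barycenter q as a weighted power mean
# of K^T u_k, then v_k = (q / K^T u_k) ** fi.
def _sketch_unbalanced_barycenter(A, M, reg, reg_m, weights=None, n_iter=100):
    import math

    dim, n_hists = len(A), len(A[0])
    if weights is None:
        weights = [1.0 / n_hists] * n_hists
    K = [[math.exp(-M[i][j] / reg) for j in range(dim)] for i in range(dim)]
    fi = reg_m / (reg_m + reg)
    v = [[1.0] * n_hists for _ in range(dim)]
    q = [1.0] * dim
    for _ in range(n_iter):
        # u_k = (a_k / (K v_k)) ** fi for each input histogram k
        u = [
            [
                (A[i][k] / sum(K[i][j] * v[j][k] for j in range(dim))) ** fi
                for k in range(n_hists)
            ]
            for i in range(dim)
        ]
        Ktu = [
            [sum(K[i][j] * u[i][k] for i in range(dim)) for k in range(n_hists)]
            for j in range(dim)
        ]
        # q = (sum_k w_k * (K^T u_k) ** (1 - fi)) ** (1 / (1 - fi))
        q = [
            sum(weights[k] * Ktu[j][k] ** (1 - fi) for k in range(n_hists))
            ** (1 / (1 - fi))
            for j in range(dim)
        ]
        # v_k = (q / (K^T u_k)) ** fi
        v = [[(q[j] / Ktu[j][k]) ** fi for k in range(n_hists)] for j in range(dim)]
    return q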




def barycenter_unbalanced(
    A,
    M,
    reg,
    reg_m,
    method="sinkhorn",
    weights=None,
    numItermax=1000,
    stopThr=1e-6,
    verbose=False,
    log=False,
    **kwargs,
):
    r"""Compute the entropic unbalanced Wasserstein barycenter of :math:`\mathbf{A}`.

    The function solves the following optimization problem:

    .. math::
       \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`)
    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
    - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
    - reg_m is the marginal relaxation hyperparameter

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10] <references-barycenter-unbalanced>`

    Parameters
    ----------
    A : array-like (dim, n_hists)
        `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
    M : array-like (dim, dim)
        ground metric matrix for OT.
    reg : float
        Entropy regularization term > 0
    reg_m: float
        Marginal relaxation term > 0
    weights : array-like (n_hists,) optional
        Weight of each distribution (barycentric coordinates)
        If None, uniform weights are used.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True


    Returns
    -------
    a : (dim,) array-like
        Unbalanced Wasserstein barycenter
    log : dict
        log dictionary returned only if `log` is `True`


    .. _references-barycenter-unbalanced:
    References
    ----------
    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
        (2015). Iterative Bregman projections for regularized transportation
        problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.

    """

    if method.lower() == "sinkhorn":
        return barycenter_unbalanced_sinkhorn(
            A,
            M,
            reg,
            reg_m,
            weights=weights,
            numItermax=numItermax,
            stopThr=stopThr,
            verbose=verbose,
            log=log,
            **kwargs,
        )

    elif method.lower() == "sinkhorn_stabilized":
        return barycenter_unbalanced_stabilized(
            A,
            M,
            reg,
            reg_m,
            weights=weights,
            numItermax=numItermax,
            stopThr=stopThr,
            verbose=verbose,
            log=log,
            **kwargs,
        )
    elif method.lower() in ["sinkhorn_reg_scaling", "sinkhorn_translation_invariant"]:
        warnings.warn("Method not implemented yet. Using classic Sinkhorn Knopp")
        return barycenter_unbalanced(
            A,
            M,
            reg,
            reg_m,
            weights=weights,
            numItermax=numItermax,
            stopThr=stopThr,
            verbose=verbose,
            log=log,
            **kwargs,
        )
    else:
        raise ValueError("Unknown method '%s'." % method)

