
ot.lp.dmmot — POT Python Optimal Transport 0.9.5 documentation

# -*- coding: utf-8 -*-
"""
d-MMOT solvers for optimal transport
"""

# Author: Ronak Mehta <ronakrm@cs.wisc.edu>
#         Xizheng Yu <xyu354@wisc.edu>
#
# License: MIT License

import numpy as np
from ..backend import get_backend


def dist_monge_max_min(i):
    r"""
    A tensor :math:`c` is Monge if for all valid :math:`i_1, \ldots, i_d` and
    :math:`j_1, \ldots, j_d`,

    .. math::
        c(s_1, \ldots, s_d) + c(t_1, \ldots, t_d) \leq c(i_1, \ldots, i_d) +
        c(j_1, \ldots, j_d)

    where :math:`s_k = \min(i_k, j_k)` and :math:`t_k = \max(i_k, j_k)`.

    Our focus is on a specific cost, which is known to be Monge:

    .. math::
        c(i_1, i_2, \ldots, i_d) = \max\{i_k : k \in [d]\} -
        \min\{i_k : k \in [d]\}.

    When :math:`d = 2`, this cost reduces to :math:`c(i_1, i_2) = |i_1 - i_2|`,
    which agrees with the classical EMD cost. This choice of :math:`c` is
    called the generalized EMD cost.

    Parameters
    ----------
    i : list
        The list of integer indexes.

    Returns
    -------
    cost : numeric value
        The ground cost (generalized EMD cost) at the given index.

    References
    ----------
    .. [56] Jeffery Kline. Properties of the d-dimensional earth mover's
        problem. Discrete Applied Mathematics, 265: 128-141, 2019.
    .. [57] Wolfgang W. Bein, Peter Brucker, James K. Park, and Pramod K.
        Pathak. A monge property for the d-dimensional transportation problem.
        Discrete Applied Mathematics, 58(2):97-109, 1995. ISSN 0166-218X. doi:
        https://doi.org/10.1016/0166-218X(93)E0121-E. URL
        https://www.sciencedirect.com/ science/article/pii/0166218X93E0121E.
        Workshop on Discrete Algorithms.
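
    Examples
    --------
    A small illustrative check of the generalized EMD cost (the index values
    below are chosen for illustration only):

    >>> dist_monge_max_min([2, 5, 3])
    3
    >>> dist_monge_max_min([4, 4])
    0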
    """

    return max(i) - min(i)



def dmmot_monge_1dgrid_loss(A, verbose=False, log=False):
    r"""
    Compute the discrete multi-marginal optimal transport of distributions A.

    This function operates on distributions whose supports are real numbers on
    the real line.

    The algorithm solves both primal and dual d-MMOT programs concurrently to
    produce the optimal transport plan as well as the total (minimal) cost.
    The ground cost is assumed to be Monge, and the transport plan returned by
    the algorithm is the same for every Monge cost.

    The algorithm accepts :math:`d` distributions (i.e., histograms)
    :math:`a_{1}, \ldots, a_{d} \in \mathbb{R}_{+}^{n}` with :math:`e^{\prime}
    a_{j}=1` for all :math:`j \in [d]`. Although the algorithm is stated for
    histograms with the same number of bins, it can easily be adapted to
    accept inputs :math:`a_{i} \in \mathbb{R}_{+}^{n_{i}}` with
    :math:`n_{i} \neq n_{j}` [56].

    The function solves the following optimization problem [55]:

    .. math::
        \begin{align}\begin{aligned}
            \underset{\gamma\in\mathbb{R}^{n^{d}}_{+}} {\textrm{min}}
            \sum_{i_1,\ldots,i_d} c(i_1,\ldots, i_d)\, \gamma(i_1,\ldots,i_d)
            \quad \textrm{s.t.}
            \sum_{i_2,\ldots,i_d} \gamma(i_1,\ldots,i_d) &= a_1(i_1),
            (\forall i_1\in[n])\\
            \qquad\vdots\\
            \sum_{i_1,\ldots,i_{d-1}} \gamma(i_1,\ldots,i_d) &= a_{d}(i_{d}),
            (\forall i_d\in[n]).
            \end{aligned}
        \end{align}


    Parameters
    ----------
    A : nx.ndarray, shape (dim, n_hists)
        The input array containing the :math:`d` input distributions, one per
        column, each with :math:`n` bins.
    verbose : bool, optional
        If True, print debugging information during execution. Default=False.
    log : bool, optional
        If True, record log. Default is False.

    Returns
    -------
    obj : float
        The value of the primal objective function evaluated at the solution.
    log : dict
        A dictionary containing the log of the discrete mmot problem:
        - 'A': a dictionary that maps tuples of indices to the corresponding
        primal variables. The tuples are the indices of the entries that are
        set to their minimum value during the algorithm.
        - 'primal objective': a float, the value of the objective function
        evaluated at the solution.
        - 'dual': a list of arrays, the dual variables corresponding to
        the input arrays. The i-th element of the list is the dual variable
        corresponding to the i-th dimension of the input arrays.
        - 'dual objective': a float, the value of the dual objective function
        evaluated at the solution.


    References
    ----------
    .. [55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, &
        Vikas Singh (2023). Efficient Discrete Multi Marginal Optimal
        Transport Regularization. In The Eleventh International
        Conference on Learning Representations.
    .. [56] Jeffery Kline. Properties of the d-dimensional earth mover's
        problem. Discrete Applied Mathematics, 265: 128-141, 2019.
    .. [58] Leonid V Kantorovich. On the translocation of masses. Dokl. Akad.
        Nauk SSSR, 37:227-229, 1942.

    See Also
    --------
    ot.lp.dmmot_monge_1dgrid_optimize : Optimize the d-Dimensional Earth
    Mover's Distance (d-MMOT)
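
    Examples
    --------
    A minimal illustrative call on two 3-bin histograms (the values below are
    chosen for illustration only; for two histograms the result coincides
    with the classical 1D EMD):

    >>> import numpy as np
    >>> A = np.array([[0.5, 0.0],
    ...               [0.5, 0.5],
    ...               [0.0, 0.5]])  # shape (n_bins, n_hists)
    >>> dmmot_monge_1dgrid_loss(A)  # doctest: +SKIP
    1.0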
    """

    nx = get_backend(A)
    A_copy = A
    A = nx.to_numpy(A)

    AA = [np.copy(A[:, j]) for j in range(A.shape[1])]

    dims = tuple(len(a) for a in AA)
    xx = {}
    dual = [np.zeros(d) for d in dims]

    idx = [0] * len(AA)
    obj = 0

    if verbose:
        print("i minval oldidx\t\tobj\t\tvals")

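    # Greedy primal-dual sweep over the index grid: assign the smallest
    # remaining mass among the current bins to the current index tuple,
    # subtract it from every histogram, advance the index of the exhausted
    # histogram, and update the corresponding dual variable.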
    while all([i < _ for _, i in zip(dims, idx)]):
        vals = [v[i] for v, i in zip(AA, idx)]
        minval = min(vals)
        i = vals.index(minval)
        xx[tuple(idx)] = minval
        obj += (dist_monge_max_min(idx)) * minval
        for v, j in zip(AA, idx):
            v[j] -= minval
        # oldidx = nx.copy(idx)
        oldidx = idx.copy()
        idx[i] += 1
        if idx[i] < dims[i]:
            temp = (
                dist_monge_max_min(idx)
                - dist_monge_max_min(oldidx)
                + dual[i][idx[i] - 1]
            )
            dual[i][idx[i]] += temp
        if verbose:
            print(i, minval, oldidx, obj, "\t", vals)

    # the above terminates when any entry in idx equals the corresponding
    # value in dims this leaves other dimensions incomplete; the remaining
    # terms of the dual solution must be filled-in
    for j, i in enumerate(idx):
        if i < len(dual[j]):
            dual[j][i:] = dual[j][i]

    dualobj = sum([np.dot(A[:, i], arr) for i, arr in enumerate(dual)])
    obj = nx.from_numpy(obj)

    log_dict = {
        "A": xx,
        "primal objective": obj,
        "dual": dual,
        "dual objective": dualobj,
    }

    # define forward/backward relations for pytorch
    obj = nx.set_gradients(obj, (A_copy), (dual))

    if log:
        return obj, log_dict
    else:
        return obj




def dmmot_monge_1dgrid_optimize(
    A,
    niters=100,
    lr_init=1e-5,
    lr_decay=0.995,
    print_rate=100,
    verbose=False,
    log=False,
):
    r"""Minimize the d-dimensional EMD using gradient descent.

    Discrete Multi-Marginal Optimal Transport (d-MMOT): Let :math:`a_1, \ldots,
    a_d\in\mathbb{R}^n_{+}` be discrete probability distributions. Here,
    the d-MMOT is the LP,

    .. math::
        \begin{align}\begin{aligned}
            \underset{x\in\mathbb{R}^{n^{d}}_{+}} {\textrm{min}}
            \sum_{i_1,\ldots,i_d} c(i_1,\ldots, i_d)\, x(i_1,\ldots,i_d) \quad
            \textrm{s.t.}
            \sum_{i_2,\ldots,i_d} x(i_1,\ldots,i_d) &= a_1(i_1),
            (\forall i_1\in[n])\\
            \qquad\vdots\\
            \sum_{i_1,\ldots,i_{d-1}} x(i_1,\ldots,i_d) &= a_{d}(i_{d}),
            (\forall i_d\in[n]).
            \end{aligned}
        \end{align}

    The dual linear program of the d-MMOT problem is:

    .. math::
        \underset{z_j\in\mathbb{R}^n, j\in[d]}{\textrm{maximize}}\qquad\sum_{j}
        a_j'z_j\qquad \textrm{subject to}\qquad z_{1}(i_1)+\cdots+z_{d}(i_{d})
        \leq c(i_1,\ldots,i_{d}),


    where the indices in the constraints include all :math:`i_j \in [n]`,
    :math:`j \in [d]`. Denote by :math:`\phi(a_1,\ldots,a_d)` the optimal
    objective value of the d-MMOT LP, and let :math:`z^*` be an optimal
    solution to the dual program. Then,

    .. math::
        \begin{align}
            \nabla \phi(a_1,\ldots,a_{d}) &= z^*,
            ~~\text{and for any $t\in \mathbb{R}$,}~~
            \phi(a_1,a_2,\ldots,a_{d}) = \sum_{j}a_j'
            (z_j^* + t\, \eta), \nonumber \\
            \text{where } \eta &:= (z_1^{*}(n)\,e, z_2^{*}(n)\,e, \cdots,
            z_{d}^{*}(n)\,e)
        \end{align}

    Using the dual variables naturally provided by the algorithm in
    ot.lp.dmmot_monge_1dgrid_loss as gradients, gradient-descent steps move
    the input distributions so as to minimize their d-MMOT distance.

    Parameters
    ----------
    A : nx.ndarray, shape (dim, n_hists)
        The input array containing the :math:`d` input distributions, one per
        column, each with :math:`n` bins.
    niters : int, optional (default=100)
        The maximum number of iterations for the optimization algorithm.
    lr_init : float, optional (default=1e-5)
        The initial learning rate (step size) for the optimization algorithm.
    lr_decay : float, optional (default=0.995)
        The learning rate decay rate in each iteration.
    print_rate : int, optional (default=100)
        The rate at which to print the objective value and gradient norm
        during the optimization algorithm.
    verbose : bool, optional
        If True, print debugging information during execution. Default=False.
    log : bool, optional
        If True, record log. Default is False.

    Returns
    -------
    a : list of ndarrays, each of shape (n,)
        The updated histograms after optimization, returned as a list of
        :math:`d` approximate barycenters, each of length :math:`n`.
    log : dict
        Log dictionary, returned only if ``log`` is True.

    References
    ----------
    .. [55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, &
        Vikas Singh (2023). Efficient Discrete Multi Marginal Optimal
        Transport Regularization. In The Eleventh International
        Conference on Learning Representations.
    .. [59] Michael C Ferris and Olvi L Mangasarian. Finite perturbation of
        convex programs. Applied Mathematics and Optimization, 23(1):263-273,
        1991.
    .. [60] Olvi L Mangasarian and RR Meyer. Nonlinear perturbation of linear
        programs. SIAM Journal on Control and Optimization, 17(6):745-752,
        1979.

    See Also
    --------
    ot.lp.dmmot_monge_1dgrid_loss: d-Dimensional Earth Mover's Solver
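
    Examples
    --------
    A minimal illustrative sketch (the histograms below are chosen for
    illustration only); the function prints progress and returns the list of
    updated histograms:

    >>> import numpy as np
    >>> A = np.array([[0.5, 0.0],
    ...               [0.5, 0.5],
    ...               [0.0, 0.5]])  # shape (n_bins, n_hists)
    >>> barys = dmmot_monge_1dgrid_optimize(A, niters=10)  # doctest: +SKIP
    >>> len(barys)  # doctest: +SKIP
    2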
    """

    # function body here
    nx = get_backend(A)
    A = nx.to_numpy(A)
    n, d = A.shape  # n is dim, d is n_hists

    def dualIter(A, lr):
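        # One dual iteration: evaluate the d-MMOT loss, stack the dual
        # variables into a gradient for A, and take a gradient step.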
        funcval, log_dict = dmmot_monge_1dgrid_loss(A, verbose=verbose, log=True)
        grad = np.column_stack(log_dict["dual"])
        A_new = np.reshape(A, (n, d)) - grad * lr
        return funcval, A_new, grad, log_dict

    def renormalize(A):
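        # Shift each column to be nonnegative and rescale it to sum to one,
        # i.e. renormalize the histograms back onto the probability simplex.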
        A = np.reshape(A, (n, d))
        for i in range(A.shape[1]):
            if min(A[:, i]) < 0:
                A[:, i] -= min(A[:, i])
            A[:, i] /= np.sum(A[:, i])
        return A

    def listify(A):
        return [A[:, i] for i in range(A.shape[1])]

    lr = lr_init

    funcval, _, grad, log_dict = dualIter(A, lr)
    gn = np.linalg.norm(grad)

    print(f"Initial:\t\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}")

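    # Gradient descent on the histograms: each iteration renormalizes the
    # columns, takes a step along the dual variables (which serve as
    # gradients of the d-MMOT objective), and decays the learning rate.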
    for i in range(niters):
        A = renormalize(A)
        funcval, A, grad, log_dict = dualIter(A, lr)
        gn = np.linalg.norm(grad)

        if i % print_rate == 0:
            print(f"Iter {i:2.0f}:\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}")

        lr *= lr_decay

    A = renormalize(A)
    a = listify(A)

    if log:
        return a, log_dict
    else:
        return a

