Gradient flows between sampled measures

(Author: Jean Feydy)

In this notebook, we showcase the properties of several geometric divergences defined on the space of probability measures: Kernel Norms (aka. Maximum Mean Discrepancies), Maximum Likelihoods of Gaussian Mixture Models (aka. sum-Hausdorff distances) and Optimal Transport costs (aka. Wasserstein or Earth-Mover's distances).

In [1]:
# Import the standard array-related libraries (MATLAB-like)
import numpy as np
import matplotlib.pyplot as plt
import display # narrow jupyter column

%matplotlib inline
In [2]:
# Import the automatic differentiation + GPU toolbox
import torch
use_cuda = torch.cuda.is_available() # Shall we use the GPU?
tensor   = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
dtype    = tensor

# Let's keep things fast by default. Feel free to increase!
NPOINTS = 1000 if use_cuda else 200
numpy = lambda x : x.detach().cpu().numpy()

To keep things simple and allow us to assess the performance of our methods graphically, we will work with measures $\alpha$ and $\beta$ sampled on the unit square:

\begin{align} \alpha &~=~ \sum_{i=1}^\text{N} \alpha_i\delta_{x_i}, & \beta &~=~ \sum_{j=1}^\text{M} \beta_j\delta_{y_j}, \end{align}

where $\alpha_i$, $\beta_j$ are positive weights associated to the samples $x_i$ and $y_j$ in $\mathbb{R}^2$. In this notebook, we will focus on the case where $\alpha$ and $\beta$ are probability measures:

\begin{align} \sum_{i=1}^\text{N} \alpha_i~=~1~=~ \sum_{j=1}^\text{M} \beta_j. \end{align}

In [3]:
from sampling import draw_samples, display_samples

# α and β are sampled from two png densities
α_i, x_i = draw_samples("data/density_a.png", NPOINTS, dtype)
β_j, y_j = draw_samples("data/density_b.png", NPOINTS, dtype)
In [4]:
plt.figure(figsize=(7,7))
plt.scatter( [10], [10] ) # shameless hack to prevent change of axis

display_samples(plt.gca(), y_j, (.55,.55,.95))
display_samples(plt.gca(), x_i, (.95,.55,.55))

plt.axis("equal")
plt.axis([0,1,0,1])
plt.gca().set_aspect('equal', adjustable='box')

Gradient flows. This notebook is all about studying Cost functions that have distance-like properties on the space of probability measures. A simple way of highlighting the geometry induced by such functionals is to follow their Wasserstein gradient flows, i.e. to integrate the ODE

\begin{align} \dot{x}_i(t)~=~-\tfrac{1}{\alpha_i}\, \nabla_{x_i} \text{Cost}\big(\sum_i \alpha_i \delta_{x_i(t)}, \beta\big) \end{align}

starting from an initial condition $x_i(t=0) = x_i$, performing a weighted gradient descent on the function

\begin{align} \text{Cost}_{\beta}~:~(x_i)\in\mathbb{R}^{\text{N}\cdot d}~\mapsto~ \text{Cost}\big(\sum_i \alpha_i \delta_{x_i}, \beta\big). \end{align}

In [5]:
def gradient_flow(α_i, x_i, β_j, y_j, cost, lr=.05) :
    """
    Flows along the gradient of the cost function, using a simple Euler scheme.
    
    Parameters
    ----------
        α_i : (N,1) torch tensor
            weights of the source measure
        x_i : (N,2) torch tensor
            samples of the source measure
        β_j : (M,1) torch tensor
            weights of the target measure
        y_j : (M,2) torch tensor
            samples of the target measure
        cost : (α_i,x_i,β_j,y_j) -> torch float number,
            real-valued function
        lr : float, default = .05
            learning rate, i.e. time step
    """
    
    # Parameters for the gradient descent
    Nsteps = int(5/lr)+1 
    t_plot      = np.linspace(-0.1, 1.1, 1000)[:,np.newaxis]
    display_its = [int(t/lr) for t in [0, .25, .50, 1., 2., 5.]]
    
    # Make sure that we won't modify the input measures
    α_i, x_i, β_j, y_j = α_i.clone(), x_i.clone(), \
                         β_j.clone(), y_j.clone()

    # We're going to perform gradient descent on Cost(Alpha, Beta) 
    # wrt. the positions x_i of the Dirac masses that make up Alpha:
    x_i.requires_grad_(True)  
    
    plt.figure(figsize=(12,8)) ; k = 1
    for i in range(Nsteps): # Euler scheme ===============
        # Compute cost and gradient
        loss = cost(α_i, x_i, β_j, y_j)
        [g]  = torch.autograd.grad(loss, [x_i])

        if i in display_its : # display
            ax = plt.subplot(2,3,k) ; k = k+1
            ax.scatter( [10], [10] ) # shameless hack

            display_samples(ax, y_j, (.55,.55,.95))
            display_samples(ax, x_i, (.95,.55,.55), 
                g/α_i, width=.25/len(x_i), scale=5)
            
            ax.set_title("t = {:1.2f}".format(lr*i))
            ax.axis("equal") ; ax.axis([0,1,0,1])
            plt.xticks([], []); plt.yticks([], [])
            ax.set_aspect('equal', adjustable='box')
            
        # in-place modification of the tensor's values
        x_i.data -= lr * (g / α_i)

This evolution can be understood as an ideal, model-free machine learning problem where a source distribution $\alpha_t$ is iteratively fitted towards a target (empirical) distribution.

Let us now display the evolution associated to the quadratic spring energy between labeled point clouds.

In [6]:
def L2_cost(α_i, x_i, β_j, y_j) :
    """
    Simplistic L2 cost (aka. spring energy) between sampled point clouds,
    assuming a pairwise correspondence between x_i[k] and y_j[k].
    """
    return .5*(α_i*((x_i-y_j)**2).sum(1,keepdim=True)).sum()

gradient_flow(α_i, x_i, β_j, y_j, L2_cost)

It works! Now, let's move on to costs that are well-defined between unlabeled point clouds with, possibly, different weights and numbers of samples.

A computational building block: the kernel product

Most standard costs between sampled measures can be computed using a kernel product operator

$$ \text{KP} :~ \big((x_i), (y_j), (\beta_j)\big) \in \mathbb{R}^{\text{N}\cdot d}\times \mathbb{R}^{\text{M}\cdot d} \times \mathbb{R}^{\text{M}\cdot 1} ~~ \mapsto ~~ \bigg( \sum_j k(x_i-y_j)\,\beta_j \bigg)_i \in \mathbb{R}^{\text{N}\cdot 1}$$

where $k:\mathbb{R}^d \rightarrow \mathbb{R}$ is a convolution kernel. Mathematically, this operation is known as a discrete convolution: Indeed, if $\beta = \sum_j \beta_j \delta_{y_j}$ is a discrete measure, the convolution product $k\star \beta$ is a function defined on $\mathbb{R}^d$ by

$$\big(k\star\beta \big)(x) ~=~ \sum_j k(x-y_j) \,\beta_j,$$

so that computing the kernel product $\text{KP}\big((x_i), (y_j), (\beta_j)\big)$ is equivalent to computing and sampling $k\star \beta$ on the point cloud $(x_i)$.

In [7]:
def KP(x,y,β_j, kernel = "gaussian", s = 1.) :
    """
    Computes K(x_i,y_j) @ β_j = sum_j k(x_i-y_j) * β_j
    where k is a kernel function (say, a Gaussian) of deviation s.
    """
    x_i = x[:,None,:]  # Shape (N,d) -> Shape (N,1,d)
    y_j = y[None,:,:]  # Shape (M,d) -> Shape (1,M,d)
    xmy = x_i - y_j    # (N,M,d) matrix, xmy[i,j,k] = (x_i[k]-y_j[k])
    if   kernel == "gaussian" : K = torch.exp( - (xmy**2).sum(2) / (2*(s**2)) )
    elif kernel == "laplace"  : K = torch.exp( - xmy.norm(dim=2) / s )
    elif kernel == "energy"   : K = - xmy.norm(dim=2)
    return K @ β_j.view(-1,1) # Matrix-vector product

Using a kernel norm

Total Variation: a first dual norm. Now, which cost function $\text{Cost}(\alpha_t, \beta)$ are we going to choose to drive our simple optimization routine? Given two measures $\alpha$ and $\beta$ on $\mathbb{R}^d$, one of the simplest distances that can be defined is the Total Variation

$$\text{d}_{\text{TV}}(\alpha,\beta) ~=~ \|\alpha-\beta\|_{\infty}^{\star} ~=~ \sup_{\|f\|_{\infty} \leqslant 1} \int f \text{d}\alpha - \int f \text{d}\beta,$$

using the dual norm on $L^{\infty}(\mathbb{R}^d, \mathbb{R})$. Unfortunately, this formula is not suited at all to sampled, discrete probability measures with non-overlapping support: If $\alpha = \sum_i \alpha_i\,\delta_{x_i}$ and $\beta = \sum_j \beta_j\,\delta_{y_j}$ with $\{x_i, \dots\}\cap\{y_j,\dots\} = \emptyset$, one can simply choose a function $f$ such that

$$\forall \, i,~ f(x_i) ~=~+1 ~~~ \text{and} ~~~ \forall \, j, ~f(y_j) ~=~-1$$

to show that

$$\text{d}_{\text{TV}}(\alpha, \beta) ~=~ |\alpha| + |\beta| ~=~ 2 ~~~~ \text{as soon as $\text{supp}(\alpha)$ and $\text{supp}(\beta)$ do not overlap.}$$

The gradient of the Total Variation distance between two sampled measures is thus completely uninformative, being zero for almost all configurations.

Smoothing measures to create overlap. How can we fix this problem? An idea would be to choose a blurring function $g$, and compare the blurred functions $g\star \alpha$ and $g\star \beta$ by using, say, an $L^2$ norm:

$$\text{d}(\alpha, \beta) ~=~ \| g\star(\alpha-\beta)\|_2^2 ~=~ \langle g\star(\alpha-\beta), g\star(\alpha-\beta)\rangle_2.$$

But then, if we define $k = \tilde{g}\star g$, where $\tilde{g} = g \circ (x\mapsto -x)$ is the mirrored blurring function, one gets

$$\text{d}_k(\alpha,\beta) ~=~ \langle g\star(\alpha-\beta), g\star(\alpha-\beta)\rangle_2 ~=~ \langle \alpha-\beta, k\star(\alpha-\beta)\rangle ~=~ \|\alpha-\beta\|_k^2.$$

Assuming a few properties on $k$ (detailed below), $\text{d}_k$ is the quadratic norm associated with the $k$-scalar product between measures:

$$\langle \alpha, \beta \rangle_k ~=~ \langle \alpha, k\star \beta\rangle.$$

More specifically,

\begin{align} \bigg\langle \sum_i \alpha_i \, \delta_{x_i} , \sum_j \beta_j\,\delta_{y_j} \bigg\rangle_k ~&=~\bigg\langle \sum_i \alpha_i \, \delta_{x_i} , \sum_j \beta_j\,\big(k\star\delta_{y_j}\big) \bigg\rangle \\ ~&=~\bigg\langle \sum_i \alpha_i \, \delta_{x_i} , \sum_j \beta_j\,k(\,\cdot\,- y_j) \bigg\rangle ~=~ \sum_{i,j} k(x_i-y_j) \, \alpha_i \beta_j. \end{align}

In [8]:
# PyTorch syntax for the L2 scalar product...
def scal(α, f) :
    return torch.dot(α.view(-1), f.view(-1))

def kernel_scalar_product(α_i, x_i, β_j, y_j, mode = "gaussian", s = 1.) :
    Kxy_β = KP(x_i,y_j,β_j,mode,s)
    return scal( α_i, Kxy_β ) 

Having defined the scalar product, we then simply develop by bilinearity:

$$\tfrac{1}{2}\|\alpha-\beta\|_k^2 ~=~ \tfrac{1}{2}\langle \alpha,\alpha \rangle_k \, -\,\langle \alpha,\beta \rangle_k \,+\,\tfrac{1}{2}\langle \beta,\beta \rangle_k.$$

In [9]:
def kernel_distance(mode = "gaussian", s = 1.) :
    def cost(α_i, x_i, β_j, y_j) :
        D2 =   (.5*kernel_scalar_product(α_i, x_i, α_i, x_i, mode,s) \
               +.5*kernel_scalar_product(β_j, y_j, β_j, y_j, mode,s) \
               -   kernel_scalar_product(α_i, x_i, β_j, y_j, mode,s) )
        return D2    
    return cost

This formula looks good: points interact with each other as soon as $k(x_i-y_j)$ is non-negligible. But if we want to get a genuine norm between measures, which hypotheses should we make on $k$?

This question was studied in the first half of the 20th century by the mathematicians who developed the theory of Reproducing Kernel Hilbert Spaces - RKHS. In our specific translation-invariant case (in which we "hardcode" convolutions), the results can be summed up as follows:

  • Principled kernel norms are the ones associated to kernel functions $k$ whose Fourier transform is real-valued and positive - think, Gaussian kernels:

$$\forall\, \omega \in \mathbb{R}^d, ~ \widehat{k}(\omega) > 0.$$

  • For any such kernel function, there exists a unique blurring kernel function $g$ such that $g\star g = k$: Simply choose

$$\widehat{g}(\omega) ~=~ \sqrt{ \widehat{k}(\omega)}.$$

  • These kernels define a Hilbert norm on a subset of $L^2(\mathbb{R}^d)$:

$$\|f\|_V^2 ~=~ \int_{\omega \in \mathbb{R}^d} \frac{|\widehat{f}(\omega)|^2}{\widehat{k}(\omega)} \,\text{d}\omega ~=~ \langle k^{(-1)} \star f\,,\, f\rangle$$ where $k^{(-1)}$ is the deconvolution kernel associated to $k$. If we define

$$V ~=~ \big\{ f\in L^2(\mathbb{R}^d), ~\|f\|_V < \infty \big\}, $$

then $(V, \|\cdot\|_V)$ is a Hilbert space of functions endowed with the scalar product

$$ \langle f\,,\, g\rangle_V ~=~ \int_{\omega \in \mathbb{R}^d} \frac{\overline{\widehat{f}(\omega)} \,\widehat{g}(\omega)}{\widehat{k}(\omega)} \,\text{d}\omega ~=~ \langle k^{(-1)} \star f\,,\, g\rangle. $$

  • We focus on kernel functions such that for all points $x\in\mathbb{R}^d$, the evaluation at point $x$ is a continuous linear form on $V$. That is,

$$ \delta_x : f\in (V, \|\cdot\|_V) \mapsto f(x) \in (\mathbb{R}, |\cdot|)$$

is well-defined and continuous. A sufficient condition for this is to ask that $\widehat{k}$ be continuous and belong to $L^1(\mathbb{R}^d)$. The Riesz representation theorem then identifies $\delta_x$ with the continuous function $k\star \delta_x : y \mapsto k(y-x)$:

$$ \forall\, f\in V,~~ f(x)~=~\langle \delta_x\,,\, f\rangle ~=~ \langle k\star\delta_x\,,\, f\rangle_V.$$

  • Finite sampled measures can thus be identified with linear forms on $V$. The $k$-norm is nothing but the dual norm of $\|\cdot\|_V$:

$$\forall\, \alpha\in V^{\star}, ~\|\alpha\|_k ~=~ \sqrt{\langle \alpha\,,\, k\star \alpha \rangle} ~=~ \sup_{\|f\|_V = 1} \langle \alpha\,,\, f\rangle.$$

All-in-all, just like the TV distance, the kernel distance can be seen as the dual of a norm on a space of functions. Whereas TV was associated to the infinity norm $\|\cdot\|_{\infty}$ on $L^{\infty}(\mathbb{R}^d)$, the kernel formulas are linked to Sobolev-like norms $\|\cdot\|_V$ on spaces of $k$-smooth functions, denoted by the letter $V$.
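As a quick sanity check of the positivity assumption (a minimal numerical sketch on arbitrary synthetic data, not part of the original pipeline), we can verify that the quadratic form $\langle w, k\star w\rangle$ computed with our Gaussian KP routine stays non-negative for random signed weights:

In [ ]:
# Minimal sanity check (arbitrary synthetic data): since the Gaussian kernel has a
# positive Fourier transform, the quadratic form <w, k*w> should be non-negative
# for *any* signed weight vector w, on any point cloud.
z = torch.rand(100, 2).type(dtype)      # random point cloud in the unit square
w = torch.randn(100, 1).type(dtype)     # random signed weights
print( scal(w, KP(z, z, w, "gaussian", s=.1)).item() )  # expect a value >= 0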

Exercise 1: Using the method of Lagrange multipliers (aka. théorème des extrema liés in the French curriculum), show the last equality above (kernel norms are dual norms on Hilbert spaces of functions).

Solution 1: We are optimizing the linear form $f\mapsto\langle\alpha,f\rangle$ on the unit $V$-sphere, which is a level set of the function

\begin{align} R(f)~=~\|f\|_V^2~=~\langle f, k^{(-1)}\star f\rangle, ~~~ \text{with gradient}~~~ \nabla R(f)~=~2\cdot k^{(-1)}\star f. \end{align}

At the optimum, we thus get a constant $\lambda\in\mathbb{R}$ such that

\begin{align} \alpha~=~ 2\lambda \cdot k^{(-1)}\star f ~~~~~ \text{i.e.} ~~~~~ f~=~\underbrace{\tfrac{1}{2\lambda}}_\mu k\star \alpha. \end{align}

Then, the equation "$\langle f, k^{(-1)}\star f\rangle =1$" gives $\mu = 1/\sqrt{\langle\alpha,k\star\alpha\rangle}$ and finally

\begin{align} \langle\alpha,f\rangle~=~\mu\,\langle\alpha,k\star\alpha\rangle~=~ \sqrt{\langle\alpha,k\star\alpha\rangle}. \end{align}

Exercise 2: Why can we say that RKHS generalize high-order Sobolev spaces? In dimension 1, what is the functional space associated to the Laplace kernel

\begin{align} k(x,y)~=~ e^{-\|x-y\|}~~ ? \end{align}

Solution 2: $H^s$ Sobolev norms are defined through

\begin{align} \|f\|_{H^s}^2~&=~ \|f\|_{L^2}^2~+~\|f'\|_{L^2}^2~+~\cdots~+~\|f^{(s)}\|_{L^2}^2 \\ &=~ \|\widehat{f}\|_{L^2}^2~+~\|\widehat{f'}\|_{L^2}^2~+~\cdots~+~\|\widehat{f^{(s)}}\|_{L^2}^2 \\ &=~ \int_\omega (1+|\omega|^2+\cdots+|\omega|^{2s})\,|\widehat{f}(\omega)|^2\,\text{d}\omega \\ &=~ \langle f, k^{(-1)}_s\star f\rangle, \end{align}

with $\widehat{k_s}(\omega)~=~1/(1+|\omega|^2+\cdots+|\omega|^{2s})$. Kernel norms allow us to generalize this construction to arbitrary (non-rational) spectral profiles, such as that of the Gaussian kernel. Going further, we could even consider kernels which are not translation-invariant, leaving the comfort of Fourier analysis to handle realistic, inhomogeneous situations.

On a side note: in dimension 1, since the Fourier transform of $x\mapsto e^{-|x|}$ is given by $\omega\mapsto 1/(1+\omega^2)$ up to a constant multiplicative factor, we can identify the RKHS $V$ associated with this kernel with the classic Sobolev space $H^1$ of square-integrable functions with square-integrable derivative; the corresponding kernel norm on measures is then the dual $H^{-1}$ norm.

Exercise 3: What can you say about the Energy Distance kernel

\begin{align} k(x,y)~=~ -\|x-y\|~~ ? \end{align}

Does it satisfy the hypotheses above?

Solution 3: In dimension 1, the Fourier transform of $x\mapsto-|x|$ is only defined in a generalized (distributional) sense, and is proportional to $\omega\mapsto 1/\omega^2$. Consequently, this kernel lies a bit outside of the simple theory of positive definite kernels: we can only say that it is conditionally positive definite, and that it defines a meaningful norm between measures which have the same mass - thus avoiding the problem of evaluating the Fourier transform of $k\star(\alpha-\beta)$ at $\omega=0$.
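To illustrate this numerically (a small sketch on arbitrary synthetic data, not part of the original notebook), we can check that the energy kernel's quadratic form stays non-negative once the weights are projected onto the zero-mass subspace:

In [ ]:
# Minimal illustration (arbitrary synthetic data): the energy kernel k(x) = -|x|
# is only *conditionally* positive definite, i.e. <w, k*w> >= 0 is only guaranteed
# for weight vectors w of total mass zero.
z = torch.rand(200, 2).type(dtype)      # random point cloud
w = torch.randn(200, 1).type(dtype)     # random signed weights...
w = w - w.mean()                        # ... projected onto the zero-mass subspace
print( scal(w, KP(z, z, w, "energy")).item() )  # expect a non-negative value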

In [10]:
gradient_flow(α_i, x_i, β_j, y_j, kernel_distance("gaussian", s=.1) )
In [11]:
gradient_flow(α_i, x_i, β_j, y_j, kernel_distance("gaussian", s=.5) )
In [12]:
gradient_flow(α_i, x_i, β_j, y_j, kernel_distance("laplace", s=.5) )
In [13]:
gradient_flow(α_i, x_i, β_j, y_j, kernel_distance("energy") )

Exercise 4: Compare the behaviours of these kernel norms for different formulas and scales. Discuss.

Solution 4: Two parameters influence the final results:

  • the spectral bandwidth of the kernel function $k$; if $k$ is smooth, $\widehat{k}(\omega)$ converges towards 0 at infinity and the kernel norm becomes blind to high frequencies; we only register $\alpha$ roughly onto $\beta$
  • the spatial support of the kernel function; if $k$ is too narrow, the $x_i$'s stop interacting with the $y_j$'s and simply spread out to minimize the auto-correlation term $\langle \alpha,k\star\alpha\rangle$.

Noticeably, we observe a screening effect: some particles feel a very low gradient and only converge slooooowly towards $\beta$. This is best explained with the vocabulary of electrostatics: here, the $x_i$'s are particles with charge $+\alpha_i$, the $y_j$'s have a negative charge $-\beta_j$, the kernel function represents the interaction potential and the kernel norm is the total energy of the system.

The total force on a given particle $x_i$ is then given as the sum of a repulsion term from the other $x_i$'s, and the attraction towards the $y_j$'s. Since particles on the left end of $\alpha$ are repulsed nearly as much as they are attracted, they only move slowly towards the target.

In practice, to match sampled measures, we tend to choose kernel functions with:

  • A null derivative at zero, and a large enough "blurring radius" to prevent overfitting on the precise sampled locations of the Dirac masses.
  • A heavy tail, to prevent isolated parts of $\alpha_t$ and $\beta$ from being "forgotten" by the gradient.

Using a Maximum-likelihood estimator

In the previous section, we've seen how to compare measures by seeing them as linear forms on a Sobolev-like space of functions. Another idea would be to see the measure $\beta$ as the realization of an i.i.d. sampling according to the measure $\alpha$, with likelihood

$$\text{Likelihood}_\alpha (\beta) ~=~ \text{Likelihood}_\alpha \bigg(\sum_j \beta_j \delta_{y_j}\bigg) ~=~ \prod_j \text{Likelihood}_\alpha(y_j)^{\beta_j}.$$

But which value could we attribute to the "likelihood of drawing $y$, given the measure $\alpha$"? Since $\alpha$ is discrete, supported by the $x_i$'s, interpreting it as a density wouldn't be practical at all... Thankfully, there's a simple solution: we could convolve $\alpha$ with a simple density function $k>0$ of mass 1 - say, a Gaussian - and thus end up with a probability measure $k\star \alpha$ which is absolutely continuous wrt. the Lebesgue measure, with density

$$\text{Likelihood}_{k\star\alpha}(y) ~=~ \sum_i k(y-x_i) \,\alpha_i > 0 ~~ \text{ for all $y\in\mathbb{R}^d$}.$$

From a probabilistic point of view, using this density function as a "model" is equivalent to assuming that the random variable $y$ is generated as a sum $x + w$, where $x$ and $w$ are two independent variables with respective laws $\alpha$ and $k\cdot\text{Lebesgue}(\mathbb{R}^d)$. If $k$ is a Gaussian function, we speak of Gaussian Mixture Models.

Given $\alpha$, $\beta$ and a symmetric kernel function $k$, we can then choose to maximize the likelihood

$$\text{Likelihood}_{k\star\alpha} (\beta) ~=~ \prod_j \big( \sum_i k(x_i-y_j) \,\alpha_i \big)^{\beta_j},$$

i.e. to minimize the negative log-likelihood

$$\text{d}_{\text{ML},k}(\alpha,\beta) ~=~ - \sum_j \log\big( \sum_i k(x_i-y_j) \,\alpha_i \big)\, \beta_j.$$

Information theoretic interpretation. Before going any further, we wish to stress the link between maximum-likelihood estimators and the Kullback-Leibler divergence. In fact, if we assume that a measure $\beta_{\text{gen}}$ is absolutely continuous wrt. the Lebesgue measure $\lambda$, then

\begin{align} \text{H}(\beta_{\text{gen}}\,|\, \alpha) ~&=~ \int \log \bigg( \frac{\text{d} \beta_{\text{gen}}}{\text{d}\alpha} \bigg) \,\text{d}\beta_{\text{gen}} ~=~ \int \log \bigg( \frac{\text{d} \beta_{\text{gen}} / \text{d}\lambda }{\text{d}\alpha/ \text{d}\lambda} \bigg) \,\text{d}\beta_{\text{gen}}\\ ~&=~ \int \log \bigg( \frac{\text{d} \beta_{\text{gen}}}{\text{d}\lambda} \bigg) \,\text{d}\beta_{\text{gen}} ~-~ \int \log \bigg( \frac{\text{d} \alpha}{\text{d}\lambda} \bigg) \,\text{d}\beta_{\text{gen}} \end{align}

\begin{align} \text{i.e.}~~ \text{H}(\beta_{\text{gen}}\,|\, \alpha) ~&=~ \text{H}(\beta_{\text{gen}}\,|\, \lambda) ~-~ \int \log \bigg( \frac{\text{d} \alpha}{\text{d}\lambda} \bigg) \,\text{d}\beta_{\text{gen}}\\ \text{so that}~~ ~- \int \log \bigg( \frac{\text{d} \alpha}{\text{d}\lambda} \bigg) \,\text{d}\beta_{\text{gen}} ~&=~ \text{H}(\beta_{\text{gen}}\,|\, \alpha) - \text{H}(\beta_{\text{gen}}\,|\, \lambda) . \end{align}

Hence, as the sampled measure $\beta$ weakly converges towards a measure $\beta_{\text{gen}}$,

$$d_{\text{ML},k}(\alpha,\beta) ~\longrightarrow~ \text{H}(\beta_{\text{gen}}\,|\,k\star \alpha) ~-~ \text{H}(\beta_{\text{gen}}\,|\, \lambda). $$

As a function of $\alpha$, this formula is minimized if and only if $~~k\star \alpha = \beta_{\text{gen}}$.

Practical implementation. As noted by the careful reader, the maximum-likelihood cost $\text{d}_{\text{ML},k}(\alpha,\beta)$ can be computed as the scalar product between the vector of weights $(\beta_j)$ and the pointwise logarithm of the Kernel Product $\text{KP}\big( (y_j), (x_i), (\alpha_i) \big)$ - up to a negative sign. So, is using our KP routine a sensible thing to do? No, it isn't.

Indeed, if a point $y_j$ is far away from the support $\{x_i, \dots\}$ of the measure $\alpha$, $\sum_i k(x_i-y_j)\,\alpha_i$ can be prohibitively small. Just remember how fast a Gaussian kernel decreases to zero! If this sum's order of magnitude is close to the floating point precision (for float32 encoding, around $10^{-7}\simeq e^{-4^2}$), applying to it a logarithmic function is just asking for trouble.

Additive v. Multiplicative formulas. In the previous section, we defined the kernel distance $\text{d}_k$ and never encountered any accuracy problem. This is because, as far as sums are concerned, small "kernel tail" values can be safely discarded, provided the weights' distribution is reasonably balanced. However, when using maximum likelihood estimation, all the values are multiplied with each other: the smaller ones cannot be "neglected" anymore, as they very much determine the magnitude of the whole product. In the log-domain, near-zero values of the density $\big(k\star \alpha\big)(y_j)$ have a large influence on the final result!

Numerical stabilization. We now understand the importance, as far as multiplicative formulas are concerned, of magnitude-independent schemes: programs that do not spiral out of control when applied to values of the order of $10^{-100}$. How do we achieve such robustness? For arbitrary expressions, the only solution may be to increase the memory footprint of floating-point numbers...

But in this specific "Kernel Product" case, a simple trick will do wonders: using a robust log-sum-exp expression. Let's write

$$U_i ~=~ \log(\alpha_i), ~~~~~~~ C_{i,j}~=~ \log\big( k(x_i-y_j)\big) ~~~ \text{(given as a stable explicit formula)}.$$

Then, the log-term in the ML distance can be written as

$$\log\big(k\star\alpha\big)(y_j) ~=~ \log\bigg( \sum_i k(x_i-y_j) \,\alpha_i \bigg) ~=~ \log\bigg( \sum_i \exp \big( C_{i,j} + U_i \big) \bigg).$$

This expression lets us see that the order of magnitude of $\big(k\star\alpha\big)(y_j)$ can be factored out easily. Simply compute

$$M_j~=~ \max_i C_{i,j} + U_i, ~~~~~ \text{and remark that} ~~~~~ \log\big(k\star\alpha\big)(y_j) ~=~ M_j \,+\, \log\bigg( \sum_i \exp \big( C_{i,j} + U_i - M_j \big) \bigg).$$

As the major exponent has been pulled out of the sum, we have effectively solved our accuracy problems. In practice, we can simply use the .logsumexp() reduction provided by recent versions of the PyTorch library.
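As a toy illustration of the issue and of its fix (the values below are arbitrary, chosen only to trigger a float32 underflow):

In [ ]:
# Toy illustration (arbitrary values): summing exponentials of very negative terms.
v = tensor([-200., -201., -202.])    # terms "C_ij + U_i", all very negative
print( v.exp().sum().log().item() )  # naive evaluation: exp() underflows to 0 -> -inf
print( v.logsumexp(0).item() )       # stabilized reduction -> approx. -199.59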

In [14]:
def KP_log(x,y,β_j_log, p = 2, blur = 1.) :
    """
    Log-domain kernel product: returns the (N,1) vector
        blur**p * log( sum_j exp( C_ij + β_j_log[j] ) ),
    with C_ij = -|x_i-y_j|^p / (p*blur**p), computed with a stable log-sum-exp reduction.
    """
    x_i = x[:,None,:]  # Shape (N,d) -> Shape (N,1,d)
    y_j = y[None,:,:]  # Shape (M,d) -> Shape (1,M,d)
    xmy = x_i - y_j    # (N,M,d) matrix, xmy[i,j,k] = (x_i[k]-y_j[k])
    if   p==2 : C =  - (xmy**2).sum(2) / (2*(blur**2))
    elif p==1 : C =  - xmy.norm(dim=2) / blur
    return (blur**p)*(C + β_j_log.view(1,-1)).logsumexp(1,keepdim=True)
In [15]:
def kernel_neglog_likelihood(p=2, blur = 1.) :
    def cost(α_i, x_i, β_j, y_j) :
        loglikelihoods = KP_log(y_j, x_i, α_i.log(), p, blur)
        dAB          = -scal(β_j, loglikelihoods)
        return dAB   
    return cost

Exercise 5: Why did we put a multiplicative factor blur**p in the definition of KP_log? What influence does it have on the gradient flow?

Solution 5: Denoting the blur by the standard letter $\sigma$, we've defined KP_log as a sampler of a function $f$ such that

\begin{align} f(x)~=~ \sigma^p \,\log \int_y \exp(-\tfrac{\|x-y\|^p}{p\,\sigma^p})\,\text{d}\beta(y) \end{align}

Since $\beta$ is a probability measure, $f(x)$ thus scales as $\tfrac{1}{p}\|x-y\|^p$ away from $y\sim\beta$; its gradient will then scale nicely for all values of $\sigma$, and we will get comparable flow dynamics.
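A quick check of this scaling (a hypothetical single-point example, not part of the original experiments): with a single unit-mass target at distance .5, -KP_log returns exactly $\tfrac{1}{p}\|x-y\|^p$, whatever the blur:

In [ ]:
# Hypothetical single-point check: -KP_log(x, y, log 1) = |x-y|^p / p, independently
# of the blur, thanks to the blur**p factor.
x0, y0 = tensor([[0., 0.]]), tensor([[.3, .4]])   # a single source and target point
w0_log = tensor([[1.]]).log()                     # log-weight of a unit mass
for blur in [.01, .1, 1.]:
    print( -KP_log(x0, y0, w0_log, p=2, blur=blur).item() )  # always ~ .125 = .5**2/2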

In [16]:
gradient_flow(α_i, x_i, β_j, y_j, kernel_neglog_likelihood(p=2, blur=.5) )
In [17]:
gradient_flow(α_i, x_i, β_j, y_j, kernel_neglog_likelihood(p=2, blur=.1) )

Exercise 6: How would you describe the behaviour of this Cost functional? In the $\text{blur} \rightarrow 0$ limit, which simple formula do you recognize? What about the $\text{blur} \rightarrow +\infty$ limit? Can you explain the mode collapse observed for large values of the blurring parameter?

Solution 6: As implemented above,

\begin{align} \text{d}_{\text{ML},k}(\alpha,\beta)~=~ \langle \,\beta(y)\,,\, -\sigma^p \log\int \exp\big(-\tfrac{\|x-y\|^p}{p\sigma^p}\big) \,\text{d}\alpha(x)\,\rangle, \end{align} which can be rewritten as \begin{align} \langle \,\beta(y)\,,\, \min_{x\sim\alpha,\sigma^p} \tfrac{1}{p}\|x-y\|^p\,\rangle, \end{align}

where the SoftMin operator $\min_\varepsilon$ is defined through

\begin{align} \min_{x\sim\alpha, \varepsilon} \varphi(x)~ &=~ -\varepsilon \log \int \exp\big(- \varphi(x)/\varepsilon \big) \text{d}\alpha(x) \\ &\xrightarrow{\varepsilon\rightarrow 0}~~ \min_{x\in\text{supp}(\alpha)}\varphi(x)\\ &\xrightarrow{\varepsilon\rightarrow +\infty}~~ \int \varphi\,\text{d}\alpha. \end{align}

As we recognize a smooth interpolation between the min and the sum reduction, we can now make sense of the behavior of the GMM-MaxLikelihood functional:

  1. When $\sigma \rightarrow 0$, \begin{align} \text{d}_{\text{ML},k}(\alpha,\beta)~\rightarrow~ \tfrac{1}{p} \langle \,\beta(y)\,,\, \min_{i=1..\text{N}} \|y-x_i\|^p\,\rangle \end{align} which can be understood as a sum-Hausdorff loss that is only interested in putting some points $x_i$ in the neighborhood of $\beta$.

  2. When $\sigma \rightarrow +\infty$, \begin{align} \text{d}_{\text{ML},k}(\alpha,\beta)~\rightarrow~ \iint \tfrac{1}{p}\|y-x\|^p\,\text{d}\alpha(x)\,\text{d}\beta(y) ~=~ \langle\,\beta\,,\,\tfrac{1}{p}\|\,\cdot\,\|^p\star\alpha\,\rangle, \end{align} which is minimized when $\alpha$ is a Dirac atom located at the median (p=1) or mean (p=2) value of the target $\beta$.
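The interpolation is easy to check numerically. Below is a small sketch (with arbitrary values $\varphi = (1,2,3)$ and uniform weights) of the SoftMin operator going from the min to the mean as $\varepsilon$ grows:

In [ ]:
# Small numerical sketch (arbitrary values): the SoftMin of φ = (1,2,3) under
# uniform weights interpolates between min(φ) = 1 and mean(φ) = 2.
φ     = tensor([1., 2., 3.])
α_log = (tensor([1., 1., 1.]) / 3).log()
softmin = lambda eps : -eps * (α_log - φ/eps).logsumexp(0)
for eps in [.01, .1, 1., 10., 100.]:
    print( eps, softmin(eps).item() )   # goes from ~1.01 (the min) to ~2.00 (the mean)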

Exercise 7: One could be tempted to symmetrize the maximum-likelihood cost, as implemented below. Discuss.

In [18]:
def kernel_sym_neglog_likelihood(p=2, blur=1) :
    def cost(α_i, x_i, β_j, y_j) :
        a_j = -KP_log(y_j, x_i, α_i.log(), p, blur)
        b_i = -KP_log(x_i, y_j, β_j.log(), p, blur)
        return scal(α_i,b_i) + scal(β_j, a_j)  
    return cost
In [19]:
gradient_flow(α_i, x_i, β_j, y_j, kernel_sym_neglog_likelihood(p=1, blur=.1) )
In [20]:
def kernel_full_neglog_likelihood(p=2, blur=1) :
    def cost(α_i, x_i, β_j, y_j) :
        a_i = -KP_log(x_i, x_i, α_i.log(), p, blur)
        a_j = -KP_log(y_j, x_i, α_i.log(), p, blur)
        b_i = -KP_log(x_i, y_j, β_j.log(), p, blur)
        b_j = -KP_log(y_j, y_j, β_j.log(), p, blur)
        
        return scal(α_i, b_i-a_i) + scal(β_j, a_j-b_j) 
    return cost
In [21]:
gradient_flow(α_i, x_i, β_j, y_j, kernel_full_neglog_likelihood(p=2, blur=.05) )
In [22]:
gradient_flow(α_i, x_i, β_j, y_j, kernel_full_neglog_likelihood(p=1, blur=.05) )

Solution 7: Both "fixes" rely on smooth distance fields \begin{align} a(y)~&=~ \min_{x\sim\alpha,\sigma^p}\tfrac{1}{p}\|x-y\|^p, & b(x)~&=~ \min_{y\sim\beta,\sigma^p}\tfrac{1}{p}\|x-y\|^p. \end{align} The first formula, \begin{align} \text{d}_{\text{ML-sym}}(\alpha,\beta)~=~ \langle\,\alpha\,,\,b\,\rangle~+~\langle\,\beta\,,\,a\,\rangle, \end{align} is the sum of terms that mean that "the $x_i$'s should be close to $\beta$" and "the $y_j$'s should be close to $\alpha$"... but still suffers from mode collapse. The second fix, \begin{align} \text{d}_{\text{ML-full}}(\alpha,\beta)~=~ \langle\,\alpha-\beta\,,\,b-a\,\rangle~=~ \langle\,\alpha-\beta\,,\, \log \frac{k\star\alpha}{k\star\beta}\,\rangle, \end{align} with $k(x)=\exp(-\|x\|^p/p\sigma^p)$, is a better attempt, as it mimics the quadratic-like formula of kernel norms.

Unfortunately though, it can be shown that this "log-kernel", Hausdorff-like formula does not define a positive definite divergence between measures. Generically, there exists a measure $\alpha\neq\beta$ such that \begin{align} \text{d}_{\text{ML-full}}(\alpha,\beta) < \text{d}_{\text{ML-full}}(\beta,\beta) = 0. \end{align}

Using an Optimal Transport distance

In the previous two sections, we've seen how to compute kernel distances, which are the duals of Sobolev-like norms on spaces of functions, as well as Maximum Likelihood scores for Gaussian-Laplace Mixture Models, which can be understood as soft generalizations of the integrated Hausdorff/Chamfer distance.

Last but not least, we now show how to compute Optimal Transport plans efficiently, ending up on Wasserstein-like distances between unlabeled measures.

Getting used to Optimal Transport. The modern OT theory relies on a few objects and problems that we now briefly recall. For a complete reference on the subject, depending on your background, you may find Filippo Santambrogio's Optimal Transport for Applied Mathematicians (2015) or Peyré and Cuturi's Computational Optimal Transport (2017) useful.

Kantorovitch problem. Given $\alpha = \sum_i \alpha_i \,\delta_{x_i}$ and $\beta = \sum_j \beta_j\,\delta_{y_j}$ we wish to find a Transport Plan $\pi$ (a measure on the product $\{x_i,\dots\}\times\{y_j,\dots\}$, encoded as an $\text{N}$-by-$\text{M}$ matrix $(\pi(x_i\leftrightarrow y_j))=(\pi_{i,j})$) which is a solution of the following optimization problem:

$$\text{minimize} ~~~ \langle \pi, C \rangle ~=~ \sum_{i,j} \pi_{i,j} C_{i,j}$$ $$\text{subject to:} ~~~~ \forall\,i,j,~~ \pi_{i,j} \geqslant 0, ~~ \sum_j \pi_{i,j} ~=~ \alpha_i, ~~ \sum_i \pi_{i,j} ~=~ \beta_j,$$

where the Cost matrix $C_{i,j} = c(x_i,y_j)$ encodes the cost of moving a unit mass from point $x_i$ to point $y_j$.

Wasserstein distance. If one uses $c(x_i,y_j)=\|x_i-y_j\|^2$, the optimal value of the above problem is called the Wasserstein distance $\text{d}_{\text{Wass}}(\alpha,\beta)$ between measures $\alpha$ and $\beta$. Its theoretical properties are plentiful... But can we compute it efficiently? In the general high-dimensional case: no, we can't. Indeed, the Kantorovitch problem above is a textbook Linear optimization problem, combinatorial by nature. Even though the simplex algorithm or other classical routines output exact solutions, they do so at a prohibitive cost: at least cubic wrt. the number of samples.
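For very small point clouds, we can nevertheless get exact solutions with an off-the-shelf LP solver. The sketch below (an illustration only, assuming SciPy is available; it is not used anywhere else in this notebook) spells out the Kantorovitch problem as a generic linear program:

In [ ]:
# Sketch of the exact Kantorovitch problem, solved as a generic linear program.
# Assumption: SciPy is installed. For illustration only - this does not scale.
from scipy.optimize import linprog

def exact_ot(α, x, β, y):
    α, x = numpy(α).ravel(), numpy(x)
    β, y = numpy(β).ravel(), numpy(y)
    α, β = α / α.sum(), β / β.sum()                  # make sure that the masses match
    N, M = len(α), len(β)
    C = ((x[:,None,:] - y[None,:,:])**2).sum(2)      # quadratic cost matrix C_ij
    A_eq = np.zeros((N + M, N * M))                  # marginal constraints on π
    for i in range(N): A_eq[i,     i*M:(i+1)*M] = 1. # Σ_j π_ij = α_i
    for j in range(M): A_eq[N + j, j::M]        = 1. # Σ_i π_ij = β_j
    res = linprog(C.ravel(), A_eq=A_eq, b_eq=np.concatenate([α, β]),
                  bounds=(0, None))
    return res.fun                                   # optimal cost <π*, C>

# e.g. exact_ot(α_i[:20], x_i[:20], β_j[:20], y_j[:20])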

Entropic regularization. Thankfully, we can however compute approximate transport plans at a much lower cost. Given a small regularization parameter $\varepsilon$, the idea is to add an entropic barrier to the Linear Kantorovitch problem and solve

$$\text{minimize} ~~~ \langle \pi, C \rangle + \varepsilon \text{KL}(\pi, \alpha\otimes\beta) ~=~ \sum_{i,j} \pi_{i,j} C_{i,j} + \varepsilon\,\sum_{i,j} \big[ \pi_{i,j}\, \log \frac{\pi_{i,j}}{\alpha_i\beta_j} \,-\,\pi_{i,j}\,+\,\alpha_i\beta_j\big]$$ $$\text{subject to:} ~~~~ \forall\,i,j,~~ \pi_{i,j} \geqslant 0, ~~ \sum_j \pi_{i,j} ~=~ \alpha_i, ~~ \sum_i \pi_{i,j} ~=~ \beta_j.$$

An important property of the $x\mapsto x\log x - x + 1$ function is that it has a $-\infty$ derivative at location $x=0$. Since the main objective function $\langle \pi, C\rangle$ is linear wrt. the $\pi_{i,j}$'s, this implies that the optimal value of the regularized problem is attained in the relative interior of the simplex, defined by the constraints:

$$ \pi ~>~0, ~~~ \pi 1 ~=~ \alpha, ~~~ \pi^T 1 ~=~\beta.$$

Hence, the optimum is necessarily reached at a critical point of our constrained problem. At the optimum $\pi^{\star}$, the gradient of the objective can thus be written as a linear combination of the equality constraints' gradients:

$$\exists\, (f^\star_i)\in\mathbb{R}^\text{N}, ~(g^\star_j)\in\mathbb{R}^\text{M},~~ \forall\,i,j,~~~ C_{i,j} + \varepsilon \,\log \frac{\pi_{i,j}^{\star}}{\alpha_i\beta_j}~=~ f^\star_i + g^\star_j $$

where $f^\star_i$ is the coefficient associated to the constraint $\sum_j \pi_{i,j} = \alpha_i$, while $g^\star_j$ is linked to $\sum_i \pi_{i,j} = \beta_j$.

All in all, we see that the optimal transport plan $(\pi^{\star}_{i,j}) \in \mathbb{R}_+^{\text{N}\times \text{M}}$ is characterized by a single pair of vectors $(f^\star,g^\star) \in \mathbb{R}^{\text{N}+\text{M}}$:

$$\forall\,i,j,~~~~~ \log \frac{\pi^{\star}_{i,j}}{\alpha_i\beta_j} ~=~ (f^\star_i+g^\star_j-C_{i,j})/\varepsilon$$

$$\text{i.e.}~~~ \pi^{\star}~=~ \text{diag}(\alpha_iU_i)\,K_{i,j}\,\text{diag}(V_j\beta_j)$$

$$\text{with}~~~ U_i~=~ \exp(f^\star_i/\varepsilon), ~~ V_j~=~\exp(g^\star_j/\varepsilon), ~~ K_{i,j}~=~\exp(-C_{i,j}/\varepsilon).$$

Consequences of this "critical point equation" are twofold:

  • The dimension of the space in which we should search for the optimal transport plan is greatly reduced, jumping from $(\text{M}\times \text{N})$ to $(\text{M}+\text{N})$ - the dual variables associated to the equality constraints. Furthermore, the optimal value of the cost can be computed using this cheap formula:

    \begin{align} \text{OT}_\varepsilon(\alpha,\beta)~&=~ \langle \pi^{\star}, C\rangle + \varepsilon \text{KL}(\pi^{\star},\alpha\otimes\beta) ~=~ \sum_{i,j} \pi_{i,j}^\star \,(f_i^\star + g_j^\star) ~=~ \langle \alpha, f^\star \rangle + \langle \beta, g^\star\rangle. \end{align}

  • The optimal transport plan can be expressed as the positive scaling of a positive kernel matrix $K$. At the same time, it should also satisfy the two marginal constraints, which can be written in terms of $(U_i)$ and $(V_j)$:

$$ U_i ~=~ \frac{1}{(K(V\beta))_i}, ~~~~ V_j ~=~ \frac{1}{(K^T(U\alpha))_j}.$$

As was remarked by a long line of authors (from Schrödinger's original work, to economics and statistical physics from the 60's to the 00's, to object recognition in the 90's-00's and more recently in the machine learning literature), this reformulation of (entropy-regularized) Optimal Transport can be linked to the Sinkhorn Theorem: it admits a unique solution $(U,V) \in \mathbb{R}_{>0}^{\text{N}+\text{M}}$, which can be approached iteratively by applying the steps

$$ U^{(0)} = (1,\dots, 1),~ V^{(0)} = (1,\dots, 1), ~~~~~~ V^{(n+1)} ~=~ \frac{1}{K^T(U^{(n)}\alpha)}, ~~~~ U^{(n+1)} ~=~ \frac{1}{K(V^{(n+1)}\beta)}.$$

These are nothing but coordinate ascent steps on the dual maximization problem:

\begin{align} \text{OT}_\varepsilon(\alpha,\beta)~&=~\max_{f,g}~\langle\alpha,f\rangle ~+~\langle\beta,g\rangle~-~\varepsilon\,\langle\alpha\otimes\beta, e^{(f\oplus g-C)/\varepsilon} - 1\rangle. \end{align}

Hence, one can solve the regularized Optimal Transport problem by iterating kernel products (aka. discrete convolutions) and pointwise divisions, on variables which have the same memory footprint as the input measures!
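In code, the matrix-scaling loop described above could look like the following sketch (for exposition only, written with the Gaussian KP routine, i.e. with cost $C_{i,j}=\tfrac{1}{2}\|x_i-y_j\|^2$ and $\varepsilon=\sigma^2$; as discussed right below, it is not the implementation that we actually use):

In [ ]:
# Naive Sinkhorn loop in the "linear" domain - a sketch for exposition only.
# With the Gaussian kernel exp(-|x-y|²/2σ²), it targets the cost C_ij = |x_i-y_j|²/2
# with temperature ε = σ². Beware: K underflows for small values of 'blur'.
def sinkhorn_naive(α_i, x_i, β_j, y_j, blur=.1, nits=100):
    U, V = torch.ones_like(α_i), torch.ones_like(β_j)
    for _ in range(nits):                                # Sinkhorn iterations
        V = 1 / KP(y_j, x_i, U * α_i, "gaussian", blur)  # V <- 1 / Kᵀ(Uα)
        U = 1 / KP(x_i, y_j, V * β_j, "gaussian", blur)  # U <- 1 / K (Vβ)
    ε = blur**2
    f_i, g_j = ε * U.log(), ε * V.log()                  # dual potentials
    return scal(α_i, f_i) + scal(β_j, g_j)               # dual cost, assuming convergence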

Sinkhorn algorithm in the log domain

Is the scheme presented above stable enough? No, it isn't. Indeed, as discussed in the section dedicated to Maximum likelihood estimators, if we use kernels in multiplicative formulas, we should favor log-domain implementations.

In [23]:
D = lambda x : x.detach() # use the formula at convergence for the gradient

def ot_reg(p = 2, blur = .05, scaling=.5 ) :
    def cost(α_i, x_i, β_j, y_j) :
        # ε-scaling heuristic (aka. simulated annealing): 
        # let ε decrease across iterations, from 1 (=diameter) to the target value
        scales = [ tensor([np.exp(e)]) for e in 
                   np.arange(0, np.log(blur), np.log(scaling)) ] + [blur]
        
        # Solve the OT_ε(α,β) problem
        f_i, g_j = torch.zeros_like(α_i), torch.zeros_like(β_j)
        for scale in scales :
            g_j = -KP_log(y_j, D(x_i), D(f_i/scale**p + α_i.log()), p=p, blur=scale)
            f_i = -KP_log(x_i, D(y_j), D(g_j/scale**p + β_j.log()), p=p, blur=scale)
        
        # Return the dual cost OT_ε(α,β), assuming convergence in the Sinkhorn loop
        return scal(α_i, f_i) + scal(β_j, g_j)
    return cost

Exercise 8: Explain why the implementation above is correct and numerically stable.

Solution 8: On the dual variables $f^{(n)}=\varepsilon\log U^{(n)}$ and $g^{(n)}=\varepsilon\log V^{(n)}$, the Sinkhorn iterations read

\begin{align} g_j^{(n+1)}~&=~ -\varepsilon \log K^T(U^{(n)}\alpha)\\ ~&=~ -\varepsilon \log \sum_{i=1}^\text{N} \exp \big( -\|x_i-y_j\|^p/p\varepsilon+f_i^{(n)}/\varepsilon +\log\alpha_i \big)\\ g_j^{(n+1)}~&=~ ~\min_{x\sim\alpha,\varepsilon}\big[ \tfrac{1}{p}\|y_j-x\|^p - f^{(n)}(x) \big], \\~\\ f_i^{(n+1)}~&=~ -\varepsilon \log K(V^{(n+1)}\beta)\\ ~&=~ -\varepsilon \log \sum_{j=1}^\text{M} \exp \big( -\|x_i-y_j\|^p/p\varepsilon+g_j^{(n+1)}/\varepsilon +\log\beta_j \big)\\ f_i^{(n+1)}~&=~ ~\min_{y\sim\beta,\varepsilon}\big[ \tfrac{1}{p}\|x_i-y\|^p - g^{(n+1)}(y) \big], \end{align}

which is what is implemented here with $\varepsilon = \sigma^p$. Crucially, if the log-sum-exp reduction is implemented properly, this code won't suffer from numerical overflows even if $U^{(n)}=e^{\pm 100}$...

Exercise 9: Link this implementation with the log-likelihood cost presented in the previous section. Intuitively, could you explain the behaviour of this algorithm? Why is it much faster than the "standard" Sinkhorn algorithm, with a fixed value of $\varepsilon$? How could you improve it further?

Solution 9: As detailed in the previous answer, the Sinkhorn iterations can now be understood with quantities that are homogeneous to the cost $\tfrac{1}{p}\|x-y\|^p$: the prices $f$ and $g$. The $\min_\varepsilon$ updates now closely resemble those of standard combinatorial methods such as the Auction algorithm, and can be studied accordingly: see Kosowsky and Yuille (1993) and Schmitzer (2016) for reference.

Two intuitions arise from this analysis:

  1. Before reaching convergence, Sinkhorn updates typically make steps of size $\varepsilon$ in the maximization of the dual cost. By using larger values of $\varepsilon$ in the first few iterations, we make larger strides and quickly reach the fine-tuning, end-game regime.
  2. Optimal Transport is fundamentally a multiscale problem: a rough transport plan computed at a coarse scale can always be refined into a finer correspondence. This is what the $\varepsilon$-scaling heuristic is all about, as we lower the amount of blur (or temperature) from one iteration to the other.

To improve this algorithm further, we could remark that the updates at any iteration are typically "$\varepsilon$-smooth". In the first few iterations, we could thus work on subsampled measures and develop a fully-fledged multiscale algorithm.
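To make the annealing schedule explicit, here are the successive blur values generated by the formula used in ot_reg for its default parameters (a simple printout, nothing more):

In [ ]:
# The ε-scaling schedule of ot_reg, made explicit for the default parameters:
blur, scaling = .05, .5
scales = [ float(np.exp(e)) for e in np.arange(0, np.log(blur), np.log(scaling)) ] + [blur]
print(scales)   # ≈ [1., .5, .25, .125, .0625, .05] : σ is halved until it reaches the target blur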

In [24]:
gradient_flow(α_i, x_i, β_j, y_j, ot_reg(p=2, blur=.05) )
In [25]:
gradient_flow(α_i, x_i, β_j, y_j, ot_reg(p=2, blur=.25) )

Exercise 10: In the experiment above, can you explain the entropic bias, which pushes $\alpha_t$ away from $\beta$, onto a medial-axis-like measure with a narrow support?

Solution 10: We know that $\text{OT}_\varepsilon(\alpha,\beta)$ is (roughly) equal to the transport cost associated to a fuzzy transport plan \begin{align} \pi^\star~=~\exp\tfrac{1}{\varepsilon}(f^\star\oplus g^\star-C)\,\cdot\,\alpha\otimes\beta, \end{align} which typically links any point $x_i$ to an $\varepsilon$-ball of points $y_j$ in $\beta$. As we minimize the sum of (squared) lengths associated to this fuzzy "system of springs", points $x_i$ tend to converge towards the median (if p=1) or mean (if p=2) value of their $\varepsilon$-mates, which often lies deep inside the convex hull of $\beta$'s support.

To solve this problem, an idea is to define the Sinkhorn divergence in a way that mimics the bilinear expansion of squared Euclidean norms:

\begin{align} \text{S}_\varepsilon(\alpha,\beta)~=~ \text{OT}_\varepsilon(\alpha,\beta) ~-~\tfrac{1}{2}\text{OT}_\varepsilon(\alpha,\alpha) ~-~\tfrac{1}{2}\text{OT}_\varepsilon(\beta,\beta). \end{align}

Most interestingly, we then get that Sinkhorn divergences interpolate between Optimal Transport and kernel norms: \begin{align} \text{OT}_C(\alpha,\beta) ~~ \xleftarrow{0\leftarrow \varepsilon} ~~ \text{S}_\varepsilon(\alpha,\beta) ~~ \xrightarrow{\varepsilon\rightarrow +\infty} ~~ \tfrac{1}{2}\|\alpha-\beta\|_{-C}^2. \end{align}

In 2018, it was shown that under mild assumptions, $\text{S}_\varepsilon$ defines a symmetric, positive-definite divergence which is convex with respect to each variable and metrizes the convergence in law. In particular, the entropic bias is removed and our gradient flow converges towards $\beta$, up to the high-frequency components lost when seeing both measures through the blurring convolution kernel

\begin{align} k_\varepsilon(x,y)~~ = ~~ e^{-C(x,y)/\varepsilon}. \end{align}

In [26]:
def sinkhorn_divergence(p = 2, blur = .05, scaling=.5 ) :
    def cost(α_i, x_i, β_j, y_j) :
        # ε-scaling heuristic (aka. simulated annealing): 
        # let ε decrease across iterations, from 1 (=diameter) to the target value
        scales = [ tensor([np.exp(e)]) for e in 
                   np.arange(0, np.log(blur), np.log(scaling)) ] + [blur]
        
        # 1) Solve the OT_ε(α,β) problem
        f_i, g_j = torch.zeros_like(α_i), torch.zeros_like(β_j)
        for scale in scales :
            g_j = - KP_log(y_j, D(x_i), D(f_i/scale**p + α_i.log()), p=p, blur=scale)
            f_i = - KP_log(x_i, D(y_j), D(g_j/scale**p + β_j.log()), p=p, blur=scale)
        
        # 2) Solve the OT_ε(α,α) and OT_ε(β,β) problems
        scales_sym = [scale]*3 # Stick to the final, target blur value: symmetric updates converge very quickly
        g_i, f_j = torch.zeros_like(α_i), torch.zeros_like(β_j)
        for scale in scales_sym :
            g_i=.5*(g_i - KP_log(x_i, x_i, g_i/scale**p + α_i.log(), p=p, blur=scale))
            f_j=.5*(f_j - KP_log(y_j, y_j, f_j/scale**p + β_j.log(), p=p, blur=scale))
        # Final step, to get a nice gradient in the backprop pass:
        g_i = - KP_log(x_i, D(x_i), D(g_i/scale**p + α_i.log()), p=p, blur=scale)
        f_j = - KP_log(y_j, D(y_j), D(f_j/scale**p + β_j.log()), p=p, blur=scale)
        
        # Return the "dual" cost :
        # S_ε(α,β) =        OT_ε(α,β)       - ½OT_ε(α,α) - ½OT_ε(β,β)
        #          = (〈α,f_αβ〉+〈β,g_αβ〉) -  〈α,g_αα〉 - 〈β,f_ββ〉
        return scal(α_i, f_i - g_i) + scal(β_j, g_j - f_j)
    return cost

Exercise 11: Explain why the implementation above is correct.

Solution 11: We've already detailed the step 1 ($\text{OT}_\varepsilon(\alpha,\beta)$ problem) and now focus on the symmetric case of the $\alpha\leftrightarrow\alpha$ problem ($\beta\leftrightarrow\beta$ can be handled identically). We know that \begin{align} \text{OT}_\varepsilon(\alpha,\alpha)~& =~\max_{f,g}~\langle\alpha,f+g\rangle ~-~\varepsilon\,\langle\alpha\otimes\alpha, e^{(f\oplus g-C)/\varepsilon} - 1\rangle\\ & =~2\,\max_{g}~\langle\alpha,g\rangle ~-~\tfrac{\varepsilon}{2}\,\langle\alpha\otimes\alpha, e^{(g\oplus g-C)/\varepsilon} - 1\rangle, \end{align} because the $(f,g)$ problem is concave and symmetric with respect to a permutation of the dual potentials: there exists a solution $(f=g,g)$ on the diagonal of the space $\mathbb{R}^\text{N}\times\mathbb{R}^\text{N}$ of dual pairs.

How can we find such a solution efficiently? Given a current estimate $(g^{(n)},g^{(n)})$, we know that defining \begin{align} \overline{g}^{(n+1)}_i~=~\min_{x\sim\alpha,\varepsilon}\big[ \tfrac{1}{p}\|x_i-x\|^p-g^{(n)}(x) \big] \end{align} and jumping to $(g^{(n)},\overline{g}^{(n+1)})$ or $(\overline{g}^{(n+1)},g^{(n)})$ would bring us closer to the optimum, as this standard Sinkhorn update is a coordinate ascent step on the dual problem.

Going further, averaging these two updates by setting \begin{align} g^{(n+1)}~=~\tfrac{1}{2}(g^{(n)}+\overline{g}^{(n+1)}) \end{align} and jumping to $(g^{(n+1)},g^{(n+1)})$ is an even better idea: thanks to the concavity of the dual objective, we know that this competitor is at least as good as $(g^{(n)},\overline{g}^{(n+1)})$ and $(\overline{g}^{(n+1)},g^{(n)})$... And at the same time, it belongs to the diagonal of the space of dual pairs, where the global optimum is known to lie. Empirically, we always converge to a good enough solution in three or four steps.

In [27]:
gradient_flow(α_i, x_i, β_j, y_j, sinkhorn_divergence(p=2, blur=.01) )
In [28]:
gradient_flow(α_i, x_i, β_j, y_j, sinkhorn_divergence(p=2, blur=.2) )
In [29]:
gradient_flow(α_i, x_i, β_j, y_j, sinkhorn_divergence(p=1, blur=.2) )

Exercise 12: Discuss the behaviour of $\text{S}_\varepsilon$ for varying values of the parameters.

Solution 12: When $\varepsilon$ is small, we retrieve the behavior of the "true" Wasserstein distance at an affordable cost. However, as $\varepsilon$ grows, $\text{S}_\varepsilon$ behaves more and more like the kernel norm $\tfrac{1}{2}\|\alpha-\beta\|^2_{-\tfrac{1}{p}\|\cdot\|^p}$:

  • if p=2, $\tfrac{1}{2}\|\alpha-\beta\|^2_{-\|\cdot\|^2/2}~=~\tfrac{1}{2}\|\text{mean}(\alpha)-\text{mean}(\beta)\|_2^2$: the divergence becomes blind to fine details and only registers the means of the two measures with each other.
  • if p=1, $\tfrac{1}{2}\|\alpha-\beta\|^2_{-\|\cdot\|}$ is the Energy distance, a kernel norm that was studied in the first section and exhibits screening artifacts.

Fortunately, in all cases, the entropic bias is alleviated and we do not observe any mode collapse.

Conclusion

In this notebook, we presented three major families of "distance" costs between probability measures:

  • Kernel distances (also known as Maximum Mean Discrepancies), which descend from dual norms on Sobolev-like spaces of functions.
  • Empirical log-likelihoods of mixture models, which are smooth generalizations of Hausdorff-like distances.
  • Sinkhorn divergences, designed as cheap approximations of Optimal Transport costs; they can be linked to dual norms on Lipschitz-like spaces of functions.

Interestingly, the three of them use the same atomic operation: the Kernel Product, possibly implemented in the log-domain. As it is a GPU-friendly operation, using these formulas results in scalable algorithms: use a vanilla PyTorch implementation for clouds of <2,000 samples, and the powerful KeOps library for larger (10,000-1,000,000) problems.
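For reference, a KeOps version of the Gaussian kernel product could look like the sketch below (assuming that the pykeops package is installed; please refer to the KeOps documentation for the authoritative API):

In [ ]:
# Hedged sketch of a KeOps implementation of KP (assumes pykeops is installed;
# see the KeOps documentation for the up-to-date API).
from pykeops.torch import LazyTensor

def KP_keops(x, y, β, s=1.):
    x_i  = LazyTensor(x[:,None,:])                        # (N,1,d) symbolic variable
    y_j  = LazyTensor(y[None,:,:])                        # (1,M,d) symbolic variable
    K_ij = (- ((x_i - y_j)**2).sum(-1) / (2*s**2)).exp()  # symbolic Gaussian kernel
    return K_ij @ β.view(-1,1)   # sum reduction over j, computed without the N-by-M buffer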

But which formula should we use in practical applications? This is the main question, and it can only be answered with respect to specific applications and datasets.

A rule of thumb: kernel distances and log-likelihoods are both cheap, and differ in the way they handle outliers and isolated $x_i$'s or $y_j$'s. With kernel distances, these are nearly forgotten as the tail of the kernel function $k$ goes to zero; in the likelihood case, though, since $-\log k(x) \rightarrow +\infty$ when $x$ grows, outliers tend to have a large influence on the overall cost and its gradient. More recently introduced, Sinkhorn divergences tend to be both interpretable and robust, albeit at a higher computational cost (multiscale, tree-based approaches are very efficient in dimensions 2 or 3, but break down for high-dimensional problems). Computing a full transport plan simply to get a gradient is overkill for most applications, and future works will certainly focus on cheap approximations that interpolate between OT and simpler theories.

To get your own intuition on the subject, feel free to re-run this notebook with measures sampled from your own sketches!