2) Statistical analysis on a shape space

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import plotly
# run at the start of every notebook
plotly.offline.init_notebook_mode() 

# Mandatory imports...
import numpy as np
from numpy import random
import torch
from torch import tensor
from torch.nn import Parameter

# Custom modules:
from kendall_triangles import KendallTriangles # Fancy visualization
from model import Model  # Optimization blackbox
from display import plot # Simple plotting routine for triangles

# Finally, emulate complex numbers with pairs of floats:
from torch_complex import rot,normalize,herm,angle,mod2,mod,comp 

A) Computing a Fréchet mean

In the previous notebook, we've seen that the set of triangles "up to similarities" can be endowed with a canonical, spherical metric structure. The arc (or geodesic) distance between two triangle shapes is given by:

In [2]:
def arc( A, B ) :
    "'Geodesic' distance between two triangles A and B."
    a, b = normalize(A), normalize(B)
    # Clamp round-off errors above 1, which would make acos return NaN:
    return mod( herm(a,b) ).clamp(max=1.).acos()
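The `torch_complex` helpers are not listed in this notebook. Below is a minimal NumPy re-implementation of the same formula using native complex numbers. It assumes that `normalize` centers a triangle and scales it to unit norm, that `herm` is the Hermitian product $\sum_i \bar a_i b_i$, and that `mod` is the modulus. Under these assumptions, it gives a few reference values: similar triangles are at distance $0$, the two mirror-image equilateral triangles are at the maximal distance $\pi/2$, and any flat triangle is at distance $\pi/4$ from both of them. The names `np_normalize` and `np_arc` are hypothetical, chosen so that they do not shadow the notebook's own functions.

```python
import numpy as np

def np_normalize(z):
    """Center a triangle (array of 3 complex vertices) and scale it to unit norm."""
    z = z - z.mean()
    return z / np.linalg.norm(z)

def np_arc(A, B):
    """Geodesic distance between the shapes of triangles A and B (complex arrays)."""
    a, b = np_normalize(A), np_normalize(B)
    c = abs(np.vdot(a, b))          # np.vdot conjugates its first argument
    return np.arccos(min(c, 1.0))   # clip round-off above 1 before arccos

omega = np.exp(2j * np.pi / 3)
equi = np.array([1, omega, omega**2])   # equilateral triangle, counter-clockwise
flat = np.array([-1.0, 1.0, 0.0])       # degenerate (collinear) triangle

# A rotated, scaled and translated copy has exactly the same shape:
moved = 2.5 * np.exp(0.7j) * equi + (3 - 1j)
print(np_arc(equi, moved))              # close to 0, up to round-off
print(np_arc(equi, np.conj(equi)))      # mirror image: close to pi/2
print(np_arc(equi, flat))               # close to pi/4
```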

We now implement some (simple) statistical operations in this nonlinear space. Our key message: the standard statistician's toolbox can be generalized to Riemannian manifolds, both from a theoretical and a practical point of view.

First, let us generate some arbitrary population of shapes:

In [3]:
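N = 100  # number of triangles
A_k = torch.stack((
    .05*torch.randn(N,2)+tensor([-.5,-.5]),  # 1st vertex: small noise around (-.5,-.5)
    .05*torch.randn(N,2)+tensor([+.5,0. ]),  # 2nd vertex: small noise around (.5, 0)
    .50*torch.randn(N,2)+tensor([ 0., 1.]),  # 3rd vertex: large noise around (0, 1)
), 1) + 3*torch.randn(N,1,2)                 # random global translation of each triangle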

And display the first 20 triangles in the 2D plane:

In [4]:
plt.figure()
plot(A_k[:20], "green")
plt.axis("equal")
plt.show()

In the shape space:

In [5]:
M = KendallTriangles()

M.add_markers(A_k, "green", "A_k", .5) 

M.generate_spherical_sampled_data()
M.show()
M.show_glyphs()
M.show_pieces()
M.iplot('Our triangles on the Kendall sphere')
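The embedding itself is hidden inside `KendallTriangles`. The following is a hedged sketch of one standard construction, which is not necessarily the convention that class uses. It first maps a triangle to its Helmert coordinates $(u_1, u_2)$, which removes translations. It then sends these coordinates to $\big(|u_1|^2-|u_2|^2,\ 2\,\mathrm{Re}(\bar u_1 u_2),\ 2\,\mathrm{Im}(\bar u_1 u_2)\big) / (|u_1|^2+|u_2|^2)$, which also removes rotations and scalings. With this normalization, the two orientations of the equilateral triangle land on the poles and flat triangles land on the equator. The great-circle angle between two image points is twice the arc distance defined above. The function name `kendall_sphere` is hypothetical.

```python
import numpy as np

def kendall_sphere(Z):
    """Hypothetical map from a triangle (3 complex vertices) to the unit 2-sphere."""
    zA, zB, zC = Z
    # Helmert coordinates: two complex numbers that no longer depend on translations.
    u1 = (zB - zA) / np.sqrt(2)
    u2 = (2 * zC - zA - zB) / np.sqrt(6)
    size2 = abs(u1)**2 + abs(u2)**2      # squared size, divided out below
    w = np.conj(u1) * u2                 # unchanged if u1 and u2 are rotated together
    return np.array([abs(u1)**2 - abs(u2)**2, 2 * w.real, 2 * w.imag]) / size2

omega = np.exp(2j * np.pi / 3)
equi = np.array([1, omega, omega**2])          # equilateral, counter-clockwise
flat = np.array([-1, 1, 0], dtype=complex)     # collinear vertices

print(kendall_sphere(equi))           # north pole (0, 0, 1), up to round-off
print(kendall_sphere(np.conj(equi)))  # mirror image: south pole (0, 0, -1)
print(kendall_sphere(flat))           # degenerate triangle: on the equator
```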

Given such an empirical distribution $A_k$, what is the simplest statistic that we could compute? The mean, obviously. But beware: on a sphere, there is no such thing as a "$+$" operator.
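To see why the naive coordinate-wise average fails on shapes, consider a centered, unit-norm triangle $a$ and its copy $-a$, rotated by a half-turn. Both represent the same shape, yet their average is the zero vector, which is not a triangle at all. Here is a short NumPy check; storing triangles as complex arrays is an assumption of this sketch, since the notebook uses pairs of floats.

```python
import numpy as np

omega = np.exp(2j * np.pi / 3)
a_pre = np.array([1, omega, omega**2]) / np.sqrt(3)  # centered, unit-norm equilateral triangle
b_pre = -a_pre                                       # rotated by pi: same shape as a_pre

# Same shape: the modulus of the Hermitian product is 1, i.e. arc distance 0.
print(abs(np.vdot(a_pre, b_pre)))

naive = (a_pre + b_pre) / 2      # coordinate-wise "mean" of two identical shapes
print(np.linalg.norm(naive))     # 0.0: the average collapses to a single point
```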

One of the most sensible ways of generalizing the mean to an arbitrary metric space $(X,\text{d})$ is the least-squares definition. For any order $p>0$, we thus define the $p$-Fréchet mean of the $A_k$'s through:

$$ F_p(A_1,\dots,A_N) ~=~ \arg\min_{x\in X} \,\sum_{k=1}^N \text{d}(x,A_k)^p $$

Exercise 1: Show that in a Euclidean vector space, the Fréchet means of order 1 and 2 coincide with the median (the geometric median, in dimension greater than 1) and the arithmetic mean of the distribution, respectively.
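Here is a quick numerical illustration of Exercise 1 on the real line. It is not a proof. Minimizing the Fréchet cost by brute force over a fine grid recovers the median for $p=1$ and the arithmetic mean for $p=2$, up to the grid step. The exponential sample, its size, and the grid resolution are arbitrary choices for this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
samples = rng.exponential(size=101)    # skewed data, so that mean and median differ

# Candidate points x on a fine grid covering the data:
grid = np.linspace(samples.min(), samples.max(), 20001)
step = grid[1] - grid[0]

def frechet_argmin(p):
    """Grid point minimizing sum_k |x - A_k|^p (brute force)."""
    cost = (np.abs(grid[:, None] - samples[None, :]) ** p).sum(axis=1)
    return grid[np.argmin(cost)]

print(frechet_argmin(1), np.median(samples))  # agree up to the grid step
print(frechet_argmin(2), samples.mean())      # agree up to the grid step
```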

In [6]:
class FrechetMean(Model) :
    "Find the Frechet mean of a population of polygons."
    def __init__(self, guess, p=2) :
        "Defines an initial guess and the Frechet exponent."
        super(Model, self).__init__()
        self.x = Parameter(guess)
        self.p = p

    def __call__(self) :
        "Returns the current guess."
        return self.x
    
    def cost(self, targets) :
        "Mean of arc(target, guess)**p over the targets (same minimizer as the sum in F_p)."
        return ( arc(targets, self())**self.p ).mean()

Using PyTorch, computing a Fréchet mean is that simple:

In [7]:
# Initial guess: any non-degenerate triangle should be ok
A0 = tensor([ [-1.,0.], [1.,0.], [0.,1.] ])

# Frechet means of order 1 and 2
FMedian = FrechetMean( A0, p=1 )
FMean   = FrechetMean( A0, p=2 ) 
FMedian.fit( A_k )
FMean.fit(   A_k )

# For a clean display, let's normalize everybody
median = normalize( FMedian())
mean   = normalize( FMean()  )
a_k    = normalize( A_k      )
It  1:0.502. It  2:0.225. It  3:0.219. It  4:0.218. It  5:0.218. 
It  6:0.218. It  7:0.218. It  8:0.218. It  9:0.218. It 10:0.218. 
It 11:0.218. It 12:0.218. It 13:0.218. It 14:0.218. It 15:0.218. 
It 16:0.218. It 17:0.218. It 18:0.218. It 19:0.218. It 20:0.218. 
It 21:0.218. It 22:0.218. It 23:0.218. It 24:0.218. It 25:0.218. 
It 26:0.218. It 27:0.218. It 28:0.218. It 29:0.218. It 30:0.218. 
It 31:0.218. It 32:0.218. It 33:0.218. It 34:0.218. It 35:0.218. 
It 36:0.218. It 37:0.218. It 38:0.218. It 39:0.218. It 40:0.218. 
It 41:0.218. It 42:0.218. 
b'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'
It  1:0.275. It  2:0.068. It  3:0.063. It  4:0.063. It  5:0.063. 

b'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'
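The `Model.fit` black box is not shown here. The `CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH` messages suggest that it wraps SciPy's L-BFGS-B optimizer, but that is an inference. A common alternative on manifolds is Riemannian gradient descent. The sketch below is written for the unit sphere $S^2$ rather than for triangles. It implements the classical fixed-point iteration for the 2-Fréchet (Karcher) mean: average the log-map images of the data in the tangent plane at the current guess, then move along the geodesic with the exponential map. The function names are hypothetical. Since the triangle shape space is itself a 2-sphere (of radius $1/2$ for the `arc` metric), the same scheme could in principle be applied there after an embedding.

```python
import numpy as np

def sphere_log(x, y):
    """Log map of the unit sphere at x: tangent vector pointing to y, of length arc(x, y)."""
    cos_t = np.clip(np.dot(x, y), -1.0, 1.0)
    v = y - cos_t * x                 # component of y orthogonal to x
    n = np.linalg.norm(v)
    if n < 1e-12:                     # y == x (or y == -x, where log is undefined)
        return np.zeros_like(x)
    return np.arccos(cos_t) * v / n

def sphere_exp(x, v):
    """Exp map of the unit sphere at x: walk |v| along the great circle tangent to v."""
    t = np.linalg.norm(v)
    if t < 1e-12:
        return x
    return np.cos(t) * x + np.sin(t) * v / t

def karcher_mean(points, x0, n_iter=100, tol=1e-10):
    """Riemannian gradient descent (unit step) for the 2-Frechet mean on the sphere."""
    x = x0 / np.linalg.norm(x0)
    for _ in range(n_iter):
        # Minus the Riemannian gradient of (1/2N) * sum_k d(x, y_k)^2:
        g = np.mean([sphere_log(x, y) for y in points], axis=0)
        x = sphere_exp(x, g)
        if np.linalg.norm(g) < tol:
            break
    return x

# Usage: four points at colatitude 0.3, symmetric around the north pole.
phi = 0.3
pts = [np.array([np.sin(phi) * np.cos(t), np.sin(phi) * np.sin(t), np.cos(phi)])
       for t in (0, np.pi / 2, np.pi, 3 * np.pi / 2)]
m = karcher_mean(pts, x0=np.array([1.0, 0.0, 1.0]))
print(m)   # close to the north pole (0, 0, 1)
```

With a unit step, each iteration moves the guess to the exponential image of the average tangent vector, so no line search is needed when the data are concentrated.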

Without first aligning the shapes onto a reference such as the mean, the population is hard to read in the 2D plane:

In [8]:
plt.figure()
plot( a_k[:10] , "green")
plot( mean,      "red")
plot( median,    "blue")
plt.axis("equal")
plt.show()
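One way to get a readable picture, which this notebook does not implement, is to rotate every normalized triangle onto the mean before plotting (ordinary Procrustes alignment). This sketch stores triangles as complex arrays, which is an assumption, since the notebook uses pairs of floats. With that representation, the best rotation of $a$ onto a reference $m$ is multiplication by $h/|h|$, where $h = \sum_i \bar a_i m_i$. The squared residual is then $2 - 2|h| = 2 - 2\cos(\text{arc}(a,m))$. The helpers `np_normalize` and `align` are hypothetical names.

```python
import numpy as np

def np_normalize(z):
    """Center a triangle (3 complex vertices) and scale it to unit norm."""
    z = z - z.mean()
    return z / np.linalg.norm(z)

def align(a, ref):
    """Rotate the normalized triangle a so that it best matches ref (least squares).

    Assumes <a, ref> != 0, i.e. the two shapes are not at the maximal distance pi/2.
    """
    h = np.vdot(a, ref)          # Hermitian product sum(conj(a) * ref)
    return a * (h / abs(h))      # optimal rotation e^{i theta} = h / |h|

omega = np.exp(2j * np.pi / 3)
ref = np_normalize(np.array([1, omega, omega**2]))             # e.g. a mean shape
tri = np_normalize(np.array([0.1, 1.2 + 0.3j, -0.2 + 1.1j]))   # arbitrary triangle
aligned = align(tri, ref)
# Squared residual after alignment equals 2 - 2 * |<tri, ref>| = 2 - 2 * cos(arc):
print(np.linalg.norm(ref - aligned)**2, 2 - 2 * abs(np.vdot(tri, ref)))
```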

Thankfully, the Kendall sphere lets us check visually that our results are sensible:

In [9]:
M = KendallTriangles()

M.add_markers(a_k,    "green", "A_k", .5) 
M.add_markers(mean,   "red",   "mean") 
M.add_markers(median, "blue",  "median") 

M.generate_spherical_sampled_data()
M.show()
M.show_glyphs()
M.show_pieces()
M.iplot('Mean and Median of our population')