%matplotlib inline
import matplotlib.pyplot as plt
import plotly
# Enable offline Plotly rendering (run once at the start of every notebook)
plotly.offline.init_notebook_mode()
# Standard numerical imports
import numpy as np
from numpy import random
import torch
from torch import tensor
from torch.nn import Parameter
# Custom modules:
from kendall_triangles import KendallTriangles # Fancy visualization
from model import Model # Optimization blackbox
from display import plot # Simple plotting routine for triangles
# Finally, emulate complex numbers with pairs of floats:
from torch_complex import rot,normalize,herm,angle,mod2,mod,comp
In the previous notebook, we saw that the set of triangles "up to similarities" (i.e. modulo translations, rotations and scalings) can be endowed with a canonical, spherical metric structure. The arc (or geodesic) distance between two triangle shapes is given by:
def arc( A, B ) :
    "Geodesic (arc) distance between the shapes of triangles A and B."
    a, b = normalize(A), normalize(B)
    return mod( herm(a,b) ).acos()
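The `torch_complex` helpers are not shown here. Below is a self-contained NumPy sketch that writes the same distance with complex arrays. It assumes that `normalize` centers a triangle and scales it to unit norm, and that `herm` is the Hermitian product $\langle a, b\rangle = \sum_j \overline{a_j}\, b_j$ of the vertices seen as complex numbers. The helper names are hypothetical:

```python
import numpy as np

def to_complex(A):
    """Triangles of shape (..., 3, 2) -> complex vertices of shape (..., 3)."""
    A = np.asarray(A, dtype=float)
    return A[..., 0] + 1j * A[..., 1]

def preshape(z):
    """Remove translation (centering) and scale (unit norm)."""
    z = z - z.mean(axis=-1, keepdims=True)
    return z / np.linalg.norm(z, axis=-1, keepdims=True)

def kendall_arc(A, B):
    """Geodesic distance arccos |<a, b>| between the shapes of A and B."""
    a, b = preshape(to_complex(A)), preshape(to_complex(B))
    h = np.sum(np.conj(a) * b, axis=-1)             # Hermitian product <a, b>
    return np.arccos(np.clip(np.abs(h), 0.0, 1.0))  # clip guards against round-off

A = np.array([[-1., 0.], [1., 0.], [0., 1.]])
t = 0.7
R = np.array([[np.cos(t), -np.sin(t)], [np.sin(t), np.cos(t)]])
B = 2.5 * A @ R.T + np.array([3., -1.])       # rotated, scaled, translated copy of A
C = np.array([[0., 0.], [1., 0.], [0., 3.]])  # a genuinely different shape

print(kendall_arc(A, B))  # ~0: similarities do not change the shape
print(kendall_arc(A, C))  # about 0.77
```

Taking the modulus $|\langle a, b\rangle|$ removes the remaining rotation, because multiplying $b$ by $e^{i\theta}$ only multiplies $\langle a, b\rangle$ by $e^{i\theta}$.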
We now implement some simple statistical operations in this nonlinear space. Our key message: the standard statistician's toolbox can be generalized to Riemannian manifolds, both in theory and in practice.
First, let us generate an arbitrary population of shapes:
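N = 100  # number of triangles
# Result has shape (N, 3, 2): N triangles, 3 vertices, (x, y) coordinates
A_k = torch.stack((
    .05*torch.randn(N,2)+tensor([-.5,-.5]),  # first vertex: tight cluster
    .05*torch.randn(N,2)+tensor([+.5,0. ]),  # second vertex: tight cluster
    .50*torch.randn(N,2)+tensor([ 0., 1.]),  # third vertex: widely spread
), 1) + 3*torch.randn(N,1,2)  # random global translation, invisible in shape space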
We display the first 20 triangles in the 2D plane:
plt.figure()
plot(A_k[:20], "green")
plt.axis("equal")
plt.show()
And in shape space, on the Kendall sphere:
M = KendallTriangles()
M.add_markers(A_k, "green", "A_k", .5)
M.generate_spherical_sampled_data()
M.show()
M.show_glyphs()
M.show_pieces()
M.iplot('Our triangles on the Kendall sphere')
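The Kendall sphere rests on a classical fact. The shapes of labelled triangles (excluding the one whose three vertices coincide) form a sphere of radius $1/2$. On that sphere, the great-circle distance coincides with the arc distance $\arccos|\langle a, b\rangle|$ between pre-shapes. One standard construction goes through Helmert coordinates and the Hopf map, sketched here with NumPy. The helper `kendall_sphere_point` is hypothetical, and `KendallTriangles` may use a rotated or otherwise different parametrization of the same sphere:

```python
import numpy as np

# Helmert sub-matrix for 3 landmarks: orthonormal rows, both orthogonal to (1, 1, 1),
# so that H @ z discards translations without distorting distances.
H = np.stack([np.array([-1., 1., 0.]) / np.sqrt(2.),
              np.array([-1., -1., 2.]) / np.sqrt(6.)])

def kendall_sphere_point(A):
    """Map a triangle of shape (3, 2) to a point of the sphere of radius 1/2."""
    z = A[:, 0] + 1j * A[:, 1]   # vertices as complex numbers
    w = H @ z                     # Helmert coordinates (translation removed)
    w = w / np.linalg.norm(w)     # scale removed
    w1, w2 = w
    c = w1 * np.conj(w2)          # unchanged by rotations w -> exp(i t) * w
    return np.array([(abs(w1)**2 - abs(w2)**2) / 2, c.real, c.imag])

A = np.array([[-1., 0.], [1., 0.], [0., 1.]])
P = kendall_sphere_point(A)
print(np.linalg.norm(P))  # 0.5, up to round-off
```

Take two pre-shapes $a, b$ with images $P, Q$. A short computation gives $|\langle a,b\rangle|^2 = \tfrac12 + 2\,P\cdot Q$, so the arc length $\tfrac12\arccos(4\,P\cdot Q)$ on this sphere equals $\arccos|\langle a,b\rangle|$.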
Given such an empirical distribution $A_k$, what is the simplest statistic that we could compute? The mean, obviously. But beware: on a sphere, there is no such thing as a "$+$" operator.
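Here is a concrete way to see the problem. A centered triangle and its copy rotated by 180 degrees about the centroid have exactly the same shape, yet their coordinate-wise average collapses onto a single point. A minimal NumPy sketch, with a hypothetical `preshape` helper:

```python
import numpy as np

def preshape(A):
    """Center a triangle of shape (3, 2) and scale it to unit Frobenius norm."""
    A = A - A.mean(axis=0)
    return A / np.linalg.norm(A)

a = preshape(np.array([[-1., 0.], [1., 0.], [0., 1.]]))
b = -a   # rotation by 180 degrees about the centroid: same shape as a

naive = (a + b) / 2          # coordinate-wise "+" applied to two representatives
print(np.abs(naive).max())   # 0.0: all three vertices collapse to the origin
```

The result depends on which representative of each shape is fed to "$+$". A definition that relies only on distances avoids this ambiguity.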
One of the most sensible ways of generalizing the mean to an arbitrary metric space $(X,\text{d})$ is the least-squares definition. For any order $p>0$, we thus define the $p$-Fréchet mean of the $A_k$'s as:
$$ F_p(A_1,\dots,A_N) ~=~ \arg\min_{x\in X} \,\sum_{k=1}^N \text{d}(x,A_k)^p $$
Exercise 1: Show that in a Euclidean vector space, the Fréchet means of order 1 and 2 respectively coincide with the (geometric) median and the arithmetic mean of the distribution. In dimension 1, the geometric median is the usual median.
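A quick numeric sanity check in dimension 1 (not a proof) is a brute-force search over a grid of candidates $x$. The sample below is arbitrary, with an odd size so that the median is unique:

```python
import numpy as np

data = np.array([0., 1., 2., 7., 10.])   # median 2, mean 4
xs = np.linspace(-5., 15., 20001)        # candidate values of x, step 0.001

def frechet_cost(xs, data, p):
    """Sum over k of |x - data_k|**p, for every candidate x in xs."""
    return (np.abs(xs[:, None] - data[None, :]) ** p).sum(axis=1)

x1 = xs[np.argmin(frechet_cost(xs, data, p=1))]
x2 = xs[np.argmin(frechet_cost(xs, data, p=2))]
print(x1, np.median(data))  # both about 2
print(x2, np.mean(data))    # both about 4
```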
class FrechetMean(Model) :
    "Find the Frechet mean of a population of polygons."
    def __init__(self, guess, p=2) :
        "Defines an initial guess and the Frechet exponent."
        super(Model, self).__init__()
        self.x = Parameter(guess)
        self.p = p

    def __call__(self) :
        "Returns the current guess."
        return self.x

    def cost(self, targets) :
        "Returns the mean p-distance to the current guess."
        # Averaging instead of summing rescales the cost but leaves the minimizer unchanged.
        return ( arc(targets, self())**self.p ).mean()
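`Model` is described as an optimization blackbox, so its `fit` method is not shown. The sketch below is one plausible stand-in, assuming plain gradient descent with Adam on the raw vertex coordinates. `preshape`, `kendall_arc` and `frechet_fit` are hypothetical names, and complex numbers are emulated with (x, y) pairs in the spirit of `torch_complex`:

```python
import math
import torch

def preshape(A):
    """Center triangles (..., 3, 2) and scale each to unit Frobenius norm."""
    A = A - A.mean(dim=-2, keepdim=True)
    return A / A.flatten(-2).norm(dim=-1)[..., None, None]

def kendall_arc(A, B, eps=1e-7):
    """arccos |<a, b>|, with complex numbers emulated as (x, y) pairs."""
    a, b = preshape(A), preshape(B)
    re = (a * b).sum(dim=(-2, -1))                                    # Re <a, b>
    im = (a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]).sum(dim=-1)  # Im <a, b>
    # Clamping slightly below 1 keeps acos and its gradient finite.
    return torch.sqrt(re**2 + im**2).clamp(max=1 - eps).acos()

def frechet_fit(targets, guess, p=2, steps=2000, lr=1e-2):
    """Minimize the mean p-th power of kendall_arc with Adam (hypothetical stand-in for Model.fit)."""
    x = torch.nn.Parameter(guess.clone())
    optimizer = torch.optim.Adam([x], lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        cost = (kendall_arc(targets, x) ** p).mean()
        cost.backward()
        optimizer.step()
    return x.detach()

# Usage: 50 randomly rotated, scaled and translated copies of one triangle T.
torch.manual_seed(0)
T = torch.tensor([[0., 0.], [1., 0.], [0., 3.]], dtype=torch.float64)
theta = 2 * math.pi * torch.rand(50, 1, dtype=torch.float64)
c, s = torch.cos(theta), torch.sin(theta)
rotated = torch.stack((c * T[:, 0] - s * T[:, 1],
                       s * T[:, 0] + c * T[:, 1]), dim=-1)   # (50, 3, 2)
targets = (1 + torch.rand(50, 1, 1, dtype=torch.float64)) * rotated \
          + torch.randn(50, 1, 2, dtype=torch.float64)

A0 = torch.tensor([[-1., 0.], [1., 0.], [0., 1.]], dtype=torch.float64)
mean = frechet_fit(targets, A0, p=2)
print(kendall_arc(mean, T).item())   # small: the mean recovers the shape of T
```

The cost is invariant to similarities, so only the shape of the returned triangle is meaningful. Its position, size and orientation depend on the optimizer's path.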
Using PyTorch, computing a Fréchet mean is that simple:
# Initial guess: any non-degenerate triangle should be ok
A0 = tensor([ [-1.,0.], [1.,0.], [0.,1.] ])
# Frechet means of order 1 and 2
FMedian = FrechetMean( A0, p=1 )
FMean = FrechetMean( A0, p=2 )
FMedian.fit( A_k )
FMean.fit( A_k )
# For a clean display, let's normalize everybody
median = normalize( FMedian())
mean = normalize( FMean() )
a_k = normalize( A_k )
Without first aligning the shapes onto a common reference such as the mean, the population is hard to read in the plane:
plt.figure()
plot( a_k[:10] , "green")
plot( mean, "red")
plot( median, "blue")
plt.axis("equal")
plt.show()
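One way to perform such an alignment is a rotation fit. Take centered, unit-norm pre-shapes $z$ (vertices as complex numbers) and a reference $m$. Then $\|e^{i\theta}z - m\|^2 = 2 - 2\,\mathrm{Re}\big(e^{-i\theta}\langle z, m\rangle\big)$ with $\langle z, m\rangle = \sum_j \overline{z_j}\, m_j$, which is minimized by $\theta = \arg\langle z, m\rangle$. A NumPy sketch with hypothetical helpers, not part of `display` or `torch_complex`:

```python
import numpy as np

def preshape(z):
    """Center complex vertices (..., 3) and scale each triangle to unit norm."""
    z = z - z.mean(axis=-1, keepdims=True)
    return z / np.linalg.norm(z, axis=-1, keepdims=True)

def align_to(z, ref):
    """Rotate each pre-shape z onto ref (rotation only, no reflection)."""
    h = np.sum(np.conj(z) * ref, axis=-1, keepdims=True)  # <z, ref>
    return z * h / np.abs(h)                               # multiply by exp(i arg <z, ref>)

rng = np.random.default_rng(0)
ref = preshape(np.array([-1., 1., 1j]))
# A noisy population, each member randomly rotated about its centroid
z = preshape(ref + 0.05 * (rng.normal(size=(10, 3)) + 1j * rng.normal(size=(10, 3))))
z = z * np.exp(2j * np.pi * rng.random((10, 1)))
aligned = align_to(z, ref)
print(np.abs(aligned - ref).max())  # small once the rotations are undone
```

After alignment, each row sits in the rotation closest to the reference. Its real and imaginary parts can then be plotted as (x, y) vertices on top of the mean.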
Thankfully, the Kendall sphere lets us visually check that our estimates are sensible:
M = KendallTriangles()
M.add_markers(a_k, "green", "A_k", .5)
M.add_markers(mean, "red", "mean")
M.add_markers(median, "blue", "median")
M.generate_spherical_sampled_data()
M.show()
M.show_glyphs()
M.show_pieces()
M.iplot('Mean and Median of our population')