From Wavelet Transforms to Convolutional Neural Networks - part 2

Author: Jean Feydy

In the second part of the workshop, we focus on a more realistic classification problem in which the FashionMNIST pieces of clothing are not perfectly centered. This is relevant because, in practice, correctly segmenting/extracting image parts is at least as difficult as classifying normalized and centered images...

In [1]:
# PyTorch import
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torch.autograd import Variable
import itertools

# Performance monitoring
from time import process_time
import matplotlib.pyplot as plt
import numpy as np
%matplotlib nbagg

# Disable warnings from the Scattering transform...
import warnings
warnings.filterwarnings("ignore")

# Train and visualize the performances of our models
from model_utils import AttrDict, show, generate_image, train, test, evaluate_model, display_classified_images
import model_utils as MU

MU.display_parameters()
# MU.use_cuda = False
# MU.args.epochs = 1 # If you just can't wait...
Optimizations will be made using the following arguments:
 {'epochs': 10, 'log_interval': 10, 'test_batch_size': 250, 'batch_size': 250, 'momentum': 0.5, 'learning_rate': 0.01}
Using CUDA ?  Yes
Random seed :  1
Remember that you can change those values dynamically!
(MU.args.batch_size = ...)

Loading the dataset - and applying random translations

Notice how we take advantage of the nice PyTorch syntax to apply randomized transformations every time an image is loaded.

In [2]:
# Load the Fashion_MNIST dataset

DATASET = "./fashion_MNIST/" # Storage location
MU.class_names = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 
                  'Sandal',  'Shirt',   'Sneaker',  'Bag',   'Ankle boot' ]
datasets.MNIST.urls = [
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz'
]
kwargs = {'num_workers': 2, 'pin_memory': True} if MU.use_cuda else {}

MU.imgsize = (56,56)
MU.args.batch_size      = 250
MU.args.test_batch_size = 250

MU.train_dataset = datasets.MNIST(DATASET, train=True, download=True,  # Use the training-MNIST dataset
                       transform=transforms.Compose([
                       transforms.RandomCrop(MU.imgsize, padding=28),
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,)) # with normalized values
                   ]))
MU.test_dataset  = datasets.MNIST(DATASET, train=False, transform=transforms.Compose([
                       transforms.RandomCrop(MU.imgsize, padding=28),
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
# Iterator for the training pass:
MU.train_loader = torch.utils.data.DataLoader( MU.train_dataset,
    batch_size=MU.args.batch_size, shuffle=True, # Data is drawn as randomized minibatches
    **kwargs)                                 # Practical settings for parallel loading

# Iterator for the testing pass, with the same settings:
MU.test_loader  = torch.utils.data.DataLoader( MU.test_dataset,
    batch_size=MU.args.test_batch_size, shuffle=True, 
    **kwargs)
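
As a quick sanity check of the claim above, the cell below (not part of the original session, purely illustrative) verifies that the random crop is re-drawn every time a sample is accessed: asking twice for the same index returns the same garment, but almost surely under two different translations.

In [ ]:
# Sanity check (illustrative): RandomCrop is applied lazily, at loading time,
# so two accesses to the same sample yield two different random translations.
img_a, label_a = MU.train_dataset[0]
img_b, label_b = MU.train_dataset[0]
print(label_a, label_b)            # same class label...
print(torch.equal(img_a, img_b))   # ...but (almost surely) a different crop -> False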
In [3]:
# Display a few samples from the dataset.
# N.B.: since we've normalized the samples, this looks pretty ugly...

nrow = 4
images = [[] for i in range(10)]
for i in range(10000) :
    images[MU.train_dataset[i][1]].append( MU.train_dataset[i][0] )
    full = [len(l) >= nrow for l in images]
    if all(full) :
        break
show(utils.make_grid( list(itertools.chain(*[[ images[i][j] for j in range(nrow)] for i in range(10)])), nrow=nrow, normalize=True))

First try: a good old two-layer perceptron

In [4]:
class TwoFullNet(nn.Module) :
    """
    Implements a simplistic perceptron with 3 layers :
    - one input, of size 56x56 (our translated FashionMNIST images)
    - one hidden, of size N
    - one output, of size 10 (number of classes)
    There is no built-in regularization, and we model the two
    transformations input->hidden and hidden->output as
    Linear->ReLu and Linear->SoftMax operators,
    i.e. as Fully connected computational graphs.
    The trainable parameters are the weights of the matrices
    (+ biases) involved in the "Linear" (Affine, really) operators.
    """
    def __init__(self, N) :
        "Defines the parameters of the model."
        super(TwoFullNet, self).__init__()
        # Linear (i.e. fully connected) layer, a matrix of size (56*56)xN
        self.fc1        = nn.Linear(MU.imgsize[0]*MU.imgsize[1], N)
        # Linear (i.e. fully connected) layer, a matrix of size Nx10 (10 classes as output)
        self.fc2        = nn.Linear( N, 10)

    def forward(self, x) :
        """
        Apply the model to some input data x.
        You can think of x as an image of size 56x56, but it is
        actually an Mx56x56 tensor, where M is the size of the
        mini-batch.
        """
        x = x.view(-1, MU.imgsize[0]*MU.imgsize[1]) # Turns our image into a vector
        x = self.fc1( x )     # Linear transformation
        x = F.relu(   x )     # Non-linearity (Relu = "positive part", a typical choice)
        x = self.fc2( x )     # Second linear transformation
        # Really, the softmax output is the vector of class probabilities, but for
        # numerical stability, all computations are made in the log domain
        return F.log_softmax(x) 

two_fc_classifier = TwoFullNet(100)
if MU.use_cuda : two_fc_classifier.cuda()
In [5]:
evaluate_model(two_fc_classifier)
Time Elapsed:    1.13s, Average test loss: 2.3130, Test accuracy: 1040/10000 (10%)
Time Elapsed:   60.22s, Average test loss: 0.9905, Test accuracy: 6368/10000 (64%)
Time Elapsed:  122.86s, Average test loss: 0.8792, Test accuracy: 6739/10000 (67%)
Time Elapsed:  183.52s, Average test loss: 0.8570, Test accuracy: 6716/10000 (67%)
Time Elapsed:  246.59s, Average test loss: 0.8228, Test accuracy: 6899/10000 (69%)
Time Elapsed:  307.37s, Average test loss: 0.8063, Test accuracy: 6926/10000 (69%)
Time Elapsed:  372.35s, Average test loss: 0.7800, Test accuracy: 7042/10000 (70%)
Time Elapsed:  437.89s, Average test loss: 0.7711, Test accuracy: 6974/10000 (70%)
Time Elapsed:  504.42s, Average test loss: 0.7779, Test accuracy: 7020/10000 (70%)
Time Elapsed:  576.92s, Average test loss: 0.7579, Test accuracy: 7102/10000 (71%)
Time Elapsed:  642.29s, Average test loss: 0.7361, Test accuracy: 7247/10000 (72%)
Confusion matrix, without normalization

Performance is considerably lower than in the centered case. Was this to be expected?

In [6]:
synt_img = generate_image(two_fc_classifier, mode = "class", target_class = 2, seed = 'zero')

fig = plt.figure() ; plt.imshow(synt_img, interpolation='nearest', cmap='gray') ; 
plt.tight_layout() ; fig.canvas.draw()
Loss -0.00
Probas per class :  0%, 0%, 100%, 0%, 0%, 0%, 0%, 0%, 0%, 0%
In [7]:
nrow = 10
images = [torch.Tensor(generate_image(two_fc_classifier, mode = "neuron", 
                          target_operator = two_fc_classifier.fc1, target_neuron = (i,), 
                          seed = 'zero', verbose=False)).expand(3,MU.imgsize[0],MU.imgsize[1])
          for i in range(100)]
show(utils.make_grid( list(images), nrow=nrow, normalize=True))

Quite remarkably, the first operator of our model has converged towards a Discrete Cosine Transform-like operator, made out of stripes! This makes sense, as a linear and translation-invariant operator is necessarily diagonal in the Fourier basis: even though it relies on non-linearities, our network has converged towards a kind of "spectral", translation-invariant classifier.
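
To make the argument above concrete, here is a small numerical check (not part of the original session): a circular convolution, the prototypical linear and translation-invariant operator, is diagonalized by the discrete Fourier transform, its eigenvalues being the Fourier coefficients of the filter.

In [ ]:
# Minimal numerical check: a (circular) convolution operator is diagonal in the Fourier basis.
import numpy as np

n    = 8
filt = np.random.randn(n)                                      # an arbitrary 1D filter
C    = np.stack([np.roll(filt, k) for k in range(n)], axis=1)  # circulant matrix: C @ x = filt * x

x = np.random.randn(n)
# In the Fourier domain, applying C amounts to a pointwise multiplication by fft(filt):
print(np.allclose(np.fft.fft(C @ x), np.fft.fft(filt) * np.fft.fft(x)))  # -> True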

In [8]:
# By the way : parameters of the second layer can be accessed too
print(two_fc_classifier.fc2.weight)
print(two_fc_classifier.fc2.bias)
Parameter containing:
-0.0122  0.1190 -0.1480  ...  -0.0365 -0.0021 -0.0311
 0.1685 -0.1723  0.1454  ...  -0.0200  0.0284 -0.0725
-0.0421  0.0799 -0.2029  ...  -0.0703 -0.0107 -0.0897
          ...             ⋱             ...          
 0.0512 -0.0024  0.0640  ...   0.0071  0.0012 -0.0334
 0.1307  0.1180 -0.0262  ...  -0.1511 -0.0230  0.0321
 0.0602 -0.0110  0.1477  ...   0.1262 -0.0155 -0.0445
[torch.cuda.FloatTensor of size 10x100 (GPU 0)]

Parameter containing:
 0.0829
-0.0661
 0.0475
 0.0567
-0.0780
 0.0654
 0.1490
-0.1175
 0.0742
-0.0958
[torch.cuda.FloatTensor of size 10 (GPU 0)]

Enforcing translation invariance: Convolutional Neural Networks

Expecting a generic perceptron to give perfect results was too optimistic: as it is theoretically able to emulate any classifier, it is prone to overfit the training data. In a sense, we already encountered such a problem in the Wavelet Thresholding Numerical Tour: http://nbviewer.jupyter.org/github/gpeyre/numerical-tours/blob/master/python/denoisingwav_2_wavelet_2d.ipynb

Replacing the orthogonal wavelet transform with a translation-invariant transform (using cycle-spinning or the algorithme à trous) dramatically increased the robustness of wavelet-based denoising algorithms; just the same, enforcing translation invariance in perceptrons will be a crucial step in the design of trainable operators for image processing. Fortunately, this prior is easy to enforce: we know that a translation-invariant linear operator can necessarily be represented as a convolution operator. It is thus natural to replace the generic linear operators known as "Fully connected layers" by their translation invariant counterparts, encoded as sets of convolution filters.
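
To get a feeling for how strong this prior is, the (illustrative) cell below compares parameter counts: a generic linear layer acting on our 56x56 images versus a single 5x5 convolution filter.

In [ ]:
# Illustrative comparison: enforcing translation invariance drastically reduces
# the number of trainable parameters.
fc   = nn.Linear(56 * 56, 56 * 56)                 # generic linear map on 56x56 images
conv = nn.Conv2d(1, 1, kernel_size=5, padding=2)   # its translation-invariant counterpart

n_params = lambda m : sum(p.numel() for p in m.parameters())
print(n_params(fc), n_params(conv))   # ~9.8 million parameters vs 26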

Furthermore, on top of translation invariance, we also know (or assume...) that natural images have a multiscale structure: there are relevant features at every scale, which are built as combinations of smaller details (edges -> eyes -> face). In practice, this means that we should prefer architectures with deep cascades of small convolution filters. Just like in a wavelet transform :-)
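
To see how a cascade of small filters ends up "looking at" large portions of the image, here is a small helper (illustrative, not from the original session) that computes the receptive field of a stack of convolutions and poolings, using the standard recursion on kernel sizes and strides.

In [ ]:
# Illustrative: the receptive field of a cascade of small filters grows with depth,
# which is what gives deep CNNs their multiscale behaviour.
def receptive_field(layers) :
    "layers = list of (kernel_size, stride) pairs, from input to output."
    rf, jump = 1, 1
    for k, s in layers :
        rf   += (k - 1) * jump   # each layer widens the field by (k-1) * current step
        jump *= s                # strides compound the sampling step
    return rf

# Two 5x5 convolutions, each followed by a 2x2 max-pooling (as in the network below):
print(receptive_field([(5,1), (2,2), (5,1), (2,2)]))   # -> 16 pixels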

When designing a "neural network" (trainable transform) for image processing tasks, one thus typically restricts oneself to a cascade of:

  • Convolution operators such as nn.Conv2d.
  • Sub/up-sampling ("(un)pooling") operators such as F.max_pool2d.
  • Pointwise operations such as F.relu.

(On top of this, several utility operators (batch normalization, dropout...) have also been developed but we won't detail them here.)

In [9]:
class TwoConvTwoFullNet(nn.Module) :
    """
    Implements a trainable model which is the concatenation
    of two convolutional layers + two fully connected layers.
    The choice of architecture here was mostly made at random,
    for illustrative purposes...
    """
    def __init__(self) :
        super(TwoConvTwoFullNet, self).__init__()
        # First conv operator : 30 1x5x5-filters + 30 bias constants 
        # which map an image of size WxH to a 3D volume of size 30xWxH
        # (modulo a padding argument)
        self.conv1      = nn.Conv2d( 1, 30, kernel_size=5)
        # Second conv operator : 30 5x5x5-filters + 30 bias constants
        # (with groups=6, each filter only sees 30/6 = 5 of the input channels);
        # it maps a 3D volume of size 30xWxH to a 3D volume of size 30xWxH
        # (modulo a padding argument)
        self.conv2      = nn.Conv2d(30, 30, kernel_size=5, groups=6)
        # Dropout layer : probabilistic regularization trick
        self.conv2_drop = nn.Dropout2d()
        # Linear (i.e. fully connected) layer, a matrix of size (30*11*11)x100
        self.fc1        = nn.Linear(30*11*11, 100)
        # Linear (i.e. fully connected) layer, a matrix of size 100x10 (10 classes as output)
        self.fc2        = nn.Linear( 100, 10)

    def forward(self, x) :
        "Stacks up the network layers, with the simplistic relu nonlinearity in-between."
        x = F.max_pool2d(F.relu(                self.conv1(x)),  2)
        x = F.max_pool2d(F.relu(self.conv2_drop(self.conv2(x))), 2)
        # Up to this point, the treatment of x has been roughly translation-invariant:
        # Conv2d operators and ReLu nonlinearities are fully translation-equivariant,
        # whereas the subsampling "max_pool2d" operators are only invariant to
        # translations by multiples of their stride.
        # As we believe that the large-scale information should not be completely
        # discarded (some features such as heels just happen to always be located in the
        # bottom right corners of our images...), we end our pipeline (transform) with
        # a regular two-layer perceptron that processes the reduced image x
        # as a feature vector.
        
        # At this point, x is a 3D volume of size 30xWxH.
        # Due to convolution truncations and subsamplings
        # (56 -> 52 -> 26 -> 22 -> 11), we have W=H=11, so the following operation...
        x = x.view(-1, 30*11*11)    # Turns it into a vector
        x = F.relu(   self.fc1(x))  # 1x100 vector
        x = F.dropout(x, training=self.training) # Add a dropout pass during training only
        x = self.fc2( x)            # 1x10 vector
        # Really, the softmax output is the vector of class probabilities, but for
        # numerical stability, all computations are made in the log domain
        return F.log_softmax(x) 

two_layers_classifier = TwoConvTwoFullNet()
if MU.use_cuda : two_layers_classifier.cuda()
In [10]:
evaluate_model(two_layers_classifier)
Time Elapsed:    3.67s, Average test loss: 2.3036, Test accuracy: 1000/10000 (10%)
Time Elapsed:   87.31s, Average test loss: 0.8559, Test accuracy: 6989/10000 (70%)
Time Elapsed:  167.18s, Average test loss: 0.7529, Test accuracy: 7263/10000 (73%)
Time Elapsed:  253.18s, Average test loss: 0.6918, Test accuracy: 7332/10000 (73%)
Time Elapsed:  342.32s, Average test loss: 0.6475, Test accuracy: 7545/10000 (75%)
Time Elapsed:  428.71s, Average test loss: 0.6335, Test accuracy: 7592/10000 (76%)
Time Elapsed:  515.40s, Average test loss: 0.6093, Test accuracy: 7686/10000 (77%)
Time Elapsed:  603.44s, Average test loss: 0.5891, Test accuracy: 7788/10000 (78%)
Time Elapsed:  694.96s, Average test loss: 0.5773, Test accuracy: 7838/10000 (78%)
Time Elapsed:  789.61s, Average test loss: 0.5705, Test accuracy: 7850/10000 (78%)
Time Elapsed:  878.79s, Average test loss: 0.5526, Test accuracy: 7890/10000 (79%)
Confusion matrix, without normalization

By putting a regularizing prior into our algorithm-fitting method, we were able to achieve much better results. Please play around with the following visualization routines!

In [11]:
synt_img = generate_image(two_layers_classifier, mode = "class", target_class = 8, seed = 'zero')

fig = plt.figure() ; plt.imshow(synt_img, interpolation='nearest', cmap='gray') ; 
plt.tight_layout() ; fig.canvas.draw()
Loss -0.00
Probas per class :  0%, 0%, 0%, 0%, 0%, 0%, 0%, 0%, 100%, 0%
In [12]:
synt_img = generate_image(two_layers_classifier, mode = "neuron", 
                          target_operator = two_layers_classifier.fc1, target_neuron = (1,), 
                          seed = 'zero')

fig = plt.figure() ; plt.imshow(synt_img, interpolation='nearest', cmap='gray') ; 
plt.tight_layout() ; fig.canvas.draw()
Loss -25957.58
Probas per class :  0%, 0%, 0%, 0%, 0%, 100%, 0%, 0%, 0%, 0%
In [13]:
# Display images which excite the neurons of the first convolution operator:
# i-th output channel, at position (1,2).
nrow = 5
images = [torch.Tensor(generate_image(two_layers_classifier, mode = "neuron", 
                          target_operator = two_layers_classifier.conv1, target_neuron = (i,1,2), 
                          seed = 'zero', verbose=False)).expand(3,MU.imgsize[0],MU.imgsize[1])[:,0:10,0:10]
          for i in range(30)]
show(utils.make_grid( list(images), nrow=nrow, normalize=True))
In [14]:
# Display images which excite the neurons of the second convolution operator:
# i-th output channel, at position (1,3). As the dependence of the second layer
# with respect to the input image is not linear, we cannot really identify
# those images with "filters". But this map is as close as it gets to 
# the "collection of filters" used by the neural network to extract relevant
# features from the input images.
nrow = 5
images = [torch.Tensor(generate_image(two_layers_classifier, mode = "neuron", 
                          target_operator = two_layers_classifier.conv2, target_neuron = (i,1,3), 
                          seed = 'zero', verbose=False)).expand(3,MU.imgsize[0],MU.imgsize[1])[:,0:20,0:20]
          for i in range(30)]
# N.B.: we scale each pseudo-filter independently to get an image with full color dynamic
show(utils.make_grid( list(images), nrow=nrow, normalize=True, scale_each=True))
In [15]:
synt_img = generate_image(two_layers_classifier, mode = "neuron", 
                          target_operator = two_layers_classifier.conv2, target_neuron = (5,1,5), 
                          seed = 'zero')

fig = plt.figure() ; plt.imshow(synt_img, interpolation='nearest', cmap='gray') ; 
plt.tight_layout() ; fig.canvas.draw()
Loss -1915.26
Probas per class :  0%, 0%, 0%, 0%, 0%, 0%, 0%, 0%, 100%, 0%

Trying with a deeper net

As everyone who has actually trained these kinds of models can tell you: deeper is better. That is, stacking layers empirically increases accuracy and reduces overfitting. As of 2017, this regularizing effect of "deep" networks is not understood (some people try to explain it through statistical-physics analogies, but I don't know what they're worth)... Which hasn't prevented the GPU arms race from taking place: researchers now routinely work with tens or hundreds of stacked convolution layers in a single model.

To let you experience on your own machines this "copy-paste" philosophy, which produces excellent results in most Computer Vision tasks, a simple four-convolution architecture is presented below:

In [16]:
class FourConvTwoFullNet(nn.Module) :
    """
    Implements a trainable model which is the concatenation
    of four convolutional layers + two fully connected layers.
    """
    def __init__(self) :
        super(FourConvTwoFullNet, self).__init__()
        self.conv1      = nn.Conv2d( 1, 30, kernel_size=5, padding=2)
        self.conv2      = nn.Conv2d(30, 30, kernel_size=5, padding=2)
        self.conv3      = nn.Conv2d(30, 30, kernel_size=5, padding=2)
        self.conv4      = nn.Conv2d(30, 30, kernel_size=5, padding=2)
        self.fc1        = nn.Linear(30*3*3, 100)
        self.fc2        = nn.Linear( 100, 10)

    def forward(self, x) :
        "Stacks up the network layers, with the simplistic relu nonlinearity in-between."
        x = F.relu(F.max_pool2d( self.conv1(x),  2 ))
        x = F.relu(F.max_pool2d( self.conv2(x),  2 ))
        x = F.relu(F.max_pool2d( self.conv3(x),  2 ))
        x = F.relu(F.max_pool2d( self.conv4(x),  2 ))
        x = x.view(-1, 30*3*3)
        x = F.relu(   self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2( x)
        return F.log_softmax(x) 

four_layers_classifier = FourConvTwoFullNet()
if MU.use_cuda : four_layers_classifier.cuda()
In [17]:
evaluate_model(four_layers_classifier)
Time Elapsed:    1.73s, Average test loss: 2.3041, Test accuracy: 1018/10000 (10%)
Time Elapsed:   83.08s, Average test loss: 0.7425, Test accuracy: 7280/10000 (73%)
Time Elapsed:  163.95s, Average test loss: 0.6118, Test accuracy: 7638/10000 (76%)
Time Elapsed:  248.57s, Average test loss: 0.6106, Test accuracy: 7636/10000 (76%)
Time Elapsed:  336.07s, Average test loss: 0.5303, Test accuracy: 7941/10000 (79%)
Time Elapsed:  427.45s, Average test loss: 0.4737, Test accuracy: 8217/10000 (82%)
Time Elapsed:  514.75s, Average test loss: 0.4222, Test accuracy: 8413/10000 (84%)
Time Elapsed:  601.90s, Average test loss: 0.4199, Test accuracy: 8460/10000 (85%)
Time Elapsed:  691.13s, Average test loss: 0.3868, Test accuracy: 8585/10000 (86%)
Time Elapsed:  778.76s, Average test loss: 0.3668, Test accuracy: 8616/10000 (86%)
Time Elapsed:  864.11s, Average test loss: 0.3324, Test accuracy: 8774/10000 (88%)
Confusion matrix, without normalization
In [18]:
synt_img = generate_image(four_layers_classifier, mode = "class", target_class = 0, seed = 'random')

fig = plt.figure() ; plt.imshow(synt_img, interpolation='nearest', cmap='gray') ; 
plt.tight_layout() ; fig.canvas.draw()
Loss 0.00
Probas per class :  100%, 0%, 0%, 0%, 0%, 0%, 0%, 0%, 0%, 0%
In [19]:
# N.B.: we've got a very crude way of generating those images, which does not guarantee convergence...
nrow = 5
images = [torch.Tensor(generate_image(four_layers_classifier, mode = "neuron", 
                          target_operator = four_layers_classifier.conv2, target_neuron = (i,4,4), 
                          seed = 'random', verbose=False, lr=.5, nits=500)).expand(3,MU.imgsize[0],MU.imgsize[1])[:,0:20,0:20]
          for i in range(5)]
# N.B.: we scale each pseudo-filter independently to get an image with full color dynamic
show(utils.make_grid( list(images), nrow=nrow, normalize=True, scale_each=True))

Going further

At this point, one could be tempted to see CNNs as architectures which intelligently extract the best out of any large enough database. Indeed, fascinated by the beautiful "DeepDream" or "DeepArt" visualizations that were recently published, people tend to fall into the "Pygmalion trap" of attributing human-like qualities to good-looking results. As it is a hot media topic at the moment, here are a few papers and links I would recommend "to get back to earth":

  • CNNs do not "see" the world with human eyes: Deep neural networks are easily fooled: High confidence predictions for unrecognizable images, Nguyen, Yosinski, Clune, 2015 : http://www.evolvingai.org/fooling
  • Prior to any training, the implicit image regularizing prior encoded in the convolutional architecture + L2 gradient descent scheme is already extremely strong: Deep Image Prior, Ulyanov, Vedaldi, Lempitsky, 2017 : https://dmitryulyanov.github.io/deep_image_prior
  • Some invariance properties of CNNs can be understood using a wavelet-like model: Understanding Deep Convolutional Networks, Mallat, 2016 : https://arxiv.org/abs/1601.04920

In this course, Gabriel and I presented a brief introduction to Data Science and Deep Learning from a geometric, analytical point of view. But you shouldn't take our word for it! To help you see this fast-moving field without mathematically-tinted glasses, here are a few websites which are definitely worth reading:

  • The course notes of Andrej Karpathy from Stanford, http://cs231n.github.io/ . This is the reference introduction to the subject from the "Computer Science" point of view. His post on Recurrent Neural Networks (which are to Hidden Markov Models what CNNs are to Wavelet Transforms) can also help you to get a grasp of the expressivity (or lack thereof) of the models currently used in natural language processing: http://karpathy.github.io/2015/05/21/rnn-effectiveness/
  • The blog of Chris Olah (+ everything on Distill), http://colah.github.io/ . In my opinion, his statements and visualizations are sometimes a bit over-optimistic / brush things under the carpet, but he has put a lot of high-quality work into producing accessible material on the subject. Have a look!
  • The blog of Ferenc Huszár, http://www.inference.vc/, with a refreshing Bayesian + Information-Theoretic point of view on basically everything ;-)

What should you take back from all this?

In my opinion, here are the three main points you should remember from today's session:

The generic Neural Network problem is irrelevant. You cannot just stack up a bunch of generic linear operators and hope for the best: if a model can express everything, it can also overfit anything. To put priors into their algorithms, researchers restrict themselves to carefully chosen architectures (submanifolds in the space of programs/transforms, if you want) and use optimization techniques which empirically provide good generalization properties outside of the training set (i.e. little overfitting). Optimizing a neural network with an "isotropic" $L^2$ gradient is itself a significant modelling choice, as it implicitly assumes that the coordinates used to represent the model (i.e. the neural weights) are "decorrelated" and "of the same scale". In general, gradient descent does not converge towards the global optimum of the cost functional, or even towards a genuine local minimum. (N.B.: Heavy ball, BFGS or other order-1 minimization methods don't really change this, just like they can't automagically recover the "high frequency" dimensions that one loses when working with blurred signals.)

CNNs are great, but definitely not "intelligent" and cannot be expected to perform well on non-image/audio data. You'll get a much clearer understanding of their limits if you see them as "beefed up" wavelet transforms. As of today, neural networks have proved their worth in "only" two fields: signal processing (where CNNs succeed Wavelet Transforms) and natural language processing (where RNNs succeed Hidden Markov Models). As the "neural net" versions of the classical models share the same algorithmic structure as their predecessors, they are subject to the same kind of pitfalls. (For instance, in the case of RNNs: a structural inability to generate purposeful sentences or paragraphs.)

You can now aggressively optimize the parameters of your favorite data flow. If we leave the pseudo-physiological justifications aside, what really stands out from the current research on neural networks is the development of automatic differentiation frameworks. In the last few years, the MILA, Google and Facebook (mainly) have put in a considerable engineering effort to develop easy-to-use and scalable toolboxes such as TensorFlow and PyTorch. Now, this may seem surprising... but in my opinion, this low-level work is the most far-reaching component of the current research effort on neural networks. A genuine revolution for many applied fields, and not just Computer Vision!
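
To make this point concrete, here is a minimal sketch (written with today's PyTorch syntax rather than the Variable API used above) of what such a framework provides: the gradient of an arbitrary, hand-crafted data flow with respect to its parameters, computed automatically.

In [ ]:
# Minimal autodiff sketch (modern PyTorch syntax): gradients of an arbitrary
# parametric model come for free, ready to be fed to a gradient descent scheme.
theta = torch.randn(3, requires_grad=True)   # parameters of some hand-crafted model
x     = torch.linspace(0, 1, 100)            # input samples
y     = torch.sin(4 * x)                     # target signal

pred = theta[0] * torch.sin(theta[1] * x + theta[2])   # an arbitrary "formula"
loss = ((pred - y) ** 2).mean()

loss.backward()        # reverse-mode automatic differentiation through the graph
print(theta.grad)      # d(loss)/d(theta)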

Up to very recently, the only way to improve existing algorithms was to think about your problem very hard, try a few long-shot ideas and hopefully come up with better results. But these efficient autodiff libraries have opened up a new path: the large-scale tuning of the thousands of parameters defined by your abstract formula/theory/computational graph. This means that as researchers, we now have the tools to leave the reassuring shores of "fully understood programs" and actually venture into the wide world of real, non-mathematically formulated problems.

Understanding which parts of our "mathematical" data flows are crucial (translation invariance, multiscale priors...), and which parts can be "freely optimized" (the actual filter coefficients...), is one of the major challenges that awaits applied mathematicians in the coming years. If I had to make a far-fetched architectural analogy, I'd say that traditional mathematical theories are comparable to stone, steel and wood: highly structured materials that can produce lasting monuments, but require a skilled workforce and have their intrinsic limitations. Pure data, on the other hand, is a bit like concrete: an amorphous mass that was very hard to deploy in large-scale applications... until the development of prestressed concrete in 1928!

The development of Python+Autodiff+GPU frameworks, which provide researchers with a simple way of leveraging supervised datasets for their own specialized workflows, has the potential to be a turning point in the history of many applied maths fields. As we head towards hybrid, "mathematically structured" + "data driven" models, we may be able to engage more easily with our colleagues from other fields, on top of seeing a new class of "meta" problems arise.

Now, I may be wrong... But whatever the outcome of the journey, there are exciting times ahead!

Bonus tracks

Due to the training time of the models, I highly doubt you'll be able to read those lines before the end of the workshop session... But, just in case, please find below a few extensions to this workshop session.

When CNNs meet wavelets

Constraining a fully-connected neural network to be (quasi) translation-invariant, we ended up with a data flow that iterates convolutions with small filters, pointwise nonlinearities and subsampling operations. This very much looks like a wavelet transform (with complex norms at every scale), where filters are optimized with respect to a given task instead of being chosen for their mathematical properties. Thanks to the current hardware (GPUs) and software (efficient and easy-to-use autodiff libraries) revolution, this kind of fine-tuning of a model's parameters (say, filter coefficients) is now a realistic task. But do we really need to optimize every single filter from scratch?

Not always. Indeed, in some cases, the first filters of a typical CNN converge to wavelet-like features, which makes sense even from a physiological point of view: vision relies on edge detectors, which quotient out local illumination changes. Hence, we can "help" the training of our CNN by replacing its first layers with a hard-coded, wavelet-like operator: the scattering transform. To toy around with this idea, you can use the code below. Beware: the underlying routines have not been optimized for CPU, and can thus be veeeeeery slow if you don't own an Nvidia GPU...

In [ ]:
# Download the toolbox: https://github.com/edouardoyallon/pyscatwave
from scatwave.scattering import Scattering 
scat = Scattering(M=MU.imgsize[0]+8, N=MU.imgsize[1]+8, J=4, jit=True)
if MU.use_cuda : scat = scat.cuda()
In [ ]:
class ScatteringFullNet(nn.Module) :
    """
    Implements a trainable model which is the concatenation
    of a scattering transform + two fully connected layers.
    """
    def __init__(self) :
        super(ScatteringFullNet, self).__init__()
        self.pad        = nn.ZeroPad2d(4)
        self.conv1      = nn.Conv2d(417, 64, kernel_size=3, padding=1)
        #self.conv2      = nn.Conv2d(30,  30, kernel_size=5, padding=5)
        #self.conv3      = nn.Conv2d(30,  30, kernel_size=5, padding=5)
        self.fc1        = nn.Linear(64*2*2, 100)
        self.fc2        = nn.Linear( 100, 10)

    def forward(self, x) :
        "Stacks up the network layers, with the simplistic relu nonlinearity in-between."
        x = self.pad(x)
        x = Variable(scat(x.data).squeeze(1))
        #print(x.size())
        x = F.relu(F.max_pool2d( self.conv1(x),  2 ))
        #x = F.relu(F.max_pool2d( self.conv2(x),  2 ))
        #x = F.relu(F.max_pool2d( self.conv3(x),  2 ))
        #print(x.size())
        x = x.view(-1, 64*2*2)
        x = F.relu(   self.fc1(x) )
        x = F.dropout(x, training=self.training)
        x = self.fc2( x)
        return F.log_softmax(x) 

scattering_classifier = ScatteringFullNet()
if MU.use_cuda : scattering_classifier.cuda()
In [ ]:
evaluate_model(scattering_classifier)

Transfer learning: using a pre-trained transform

Convolutional Neural Networks are nothing but finely tuned non-linear transforms. So why should we retrain them every time? In practice, most researchers and engineers content themselves with a standard pre-trained network, and build custom applications on top of it! In this community, this re-use of neural weights is known as Transfer Learning. You can read more about it at the following address: http://cs231n.github.io/transfer-learning/

Then, if your computer is fast enough, you can try to play around in the sandbox below :-)

In [ ]:
MU.imgsize = (56,56)
MU.args.batch_size      = 16
MU.args.test_batch_size = 16

def bw_to_rgb(tens) :
    return tens.expand(3,224,224)

MU.train_dataset = datasets.MNIST(DATASET, train=True, download=True,  # Use the training-MNIST dataset
                       transform=transforms.Compose([
                       transforms.RandomCrop(MU.imgsize, padding=28),
                       transforms.Scale(224),
                       transforms.ToTensor(),
                       transforms.Lambda(bw_to_rgb),
                       transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                            std=[0.229, 0.224, 0.225])
                   ]))
MU.test_dataset  = datasets.MNIST(DATASET, train=False, transform=transforms.Compose([
                       transforms.RandomCrop(MU.imgsize, padding=28),
                       transforms.Scale(224),
                       transforms.ToTensor(),
                       transforms.Lambda(bw_to_rgb),
                       transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                            std=[0.229, 0.224, 0.225])
                   ]))
# Iterator for the training pass:
MU.train_loader = torch.utils.data.DataLoader( MU.train_dataset,
    batch_size=MU.args.batch_size, shuffle=True, # Data is drawn as randomized minibatches
    **kwargs)                                 # Practical settings for parallel loading

# Iterator for the testing pass, with the same settings:
MU.test_loader  = torch.utils.data.DataLoader( MU.test_dataset,
    batch_size=MU.args.test_batch_size, shuffle=True, 
    **kwargs)
In [ ]:
# Display a few samples from the dataset.
# N.B.: since we've normalized the samples, this looks pretty ugly...

nrow = 4
images = [[] for i in range(10)]
for i in range(10000) :
    images[MU.train_dataset[i][1]].append( MU.train_dataset[i][0] )
    full = [len(l) >= nrow for l in images]
    if all(full) :
        break
show(utils.make_grid( list(itertools.chain(*[[ images[i][j] for j in range(nrow)] for i in range(10)])), nrow=nrow, normalize=True))
In [ ]:
from torchvision import models

transfer_model    = models.resnet18(pretrained=True)
for param in transfer_model.parameters():
    param.requires_grad = False
    
num_ftrs          = transfer_model.fc.in_features
transfer_model.fc = nn.Linear(num_ftrs, 10)

print([i.requires_grad for i in transfer_model.parameters()])

if MU.use_cuda : transfer_model = transfer_model.cuda()
In [ ]:
evaluate_model(transfer_model)