Jeremy

Spatial Transformer Networks 阅读笔记

最近开始看STN[1]相关的资料,希望可以复现论文《Robust Scene Text Recognition with Automatic Rectification》[2]中STN的效果,不过从目前实验的情况来看,加了STN之后的效果还真不怎么样。

论文阅读笔记

题目: Spatial Transformer Networks

作者: Max Jaderberg (Google DeepMind)

摘要:

Convolutional Neural Networks define an exceptionally powerful class of models, but are still limited by the lack of ability to be spatially invariant to the input data in a computationally and parameter efficient manner. In this work we introduce a new learnable module, the Spatial Transformer, which explicitly allows the spatial manipulation of data within the network. This differentiable module can be inserted into existing convolutional architectures, giving neural networks the ability to actively spatially transform feature maps, conditional on the feature map itself, without any extra training supervision or modification to the optimisation process. We show that the use of spatial transformers results in models which learn invariance to translation, scale, rotation and more generic warping, resulting in state-of-the-art performance on several benchmarks, and for a number of classes of transformations.

动机

目前,基于CNN的方法在classification、localisation、semantic segmentation、action recognition 等任务中,都刷到了state-of-the-art的性能。不过,CNN其实也不是完美的。

我们都知道,在CNN中有一个叫做(max、average)pooling的layer,这个layer为CNN提供了位移不变性的功能。但是这个位移不变性其实是个不完善的功能,它只能提供有限的位移不变性。因为pooling的kernel一般不会很大(eg. 2*2 pixels),而如果要提供较大的空间位移不变性,那么就需要通过比较深的pooling和convolutions的组合来实现,但是到底来说还是受限的,因为一个model的结构是确定的,也即意味着其所能提供的空间位移不变性是有范围的。

那么有什么办法可以提供各大范围的空间位移不变性呢?

熟悉视觉的同志,这个时候就会想起来,在几何中有一个仿射变换,它通过有限的6个参数就可以提供非常灵活的平移、旋转、缩放、裁剪的功能。

其中仿射变换矩阵如下所示:

平移变换的公式如下:

缩放变换的公式如下:

旋转变换的公式如下:

其中alpha为绕原点顺时针旋转的角度,另外,这里面还有一个trick,那就是图像的坐标不是中心坐标系,所以需要对坐标进行Normalization,把坐标的范围调整到 [-1, 1], 这样旋转就变成绕图像中心旋转了。

OK,那么有什么方法可以在CNN的网络中嵌入这个功能呢?是的,有办法,Max很聪明地想到通过回归网络来得到这些仿射变换矩阵的参数,再通过仿射变换的公式就得到了一个可以克服位移变换、旋转变换、缩放变换的图像。

然后,我们下一步来具体看看,STN是怎么操作的。

方法

STN is a differentiable module which applies a spatial transformation to a feature map during a single forward pass, where the transformation is conditioned on the particular input, producing a single output feature map. For multi-channel inputs, the same warping is applied to each channel.

STN mechanism主要可以分为三个部分,如下图所示:

  • The localization network: takes the original image as an input and outputs the parameters of the transformation we want to apply.
  • The grid generator: generates a grid of coordinates in the input image corresponding to each pixel from the output image.
  • The sampler: generates the output image using the grid given by the grid generator.

Localizatoin Network :

Localization Network可以总结为以下一个公式:

其中 U 是input feature map(W x H x C),theta 的 size 根据transformation的type决定,当为仿射变换的时候,theta是一个6-dimensional的向量。而f_loc()的具体组成并不限定,只要最后一层是一个final regression layer,其他的层可以是fully-connected network 或者 convolutional network.

Parameterised Sampling Grid :

通过LocNet获得theta的值后,我们需要建立pixel之间的映射函数,其实也就是仿射变换函数,假设U中每个像素坐标(xi_s, yi_s),V中每个像素的坐标(xi_t, yi_t),那么U、V中像素的对应几何关系如下:

在这一步,可能会有同学疑惑,为什么我们是算(xi_source, yi_source),而不是算(xi_target, yi_target)。

在这里,我们需要先明确一点,那就是xi,yi是坐标值,target的V的大小是我们一开始设定好的,最经常的默认值就是和U的大小一样。而V内部像素的具体值则是由U内部部分像素值来提供,比如下图所示,V的值由U中的特定位置的像素值计算得到。

因此,我们通过theta和Target Grid得到V中每个像素对应于U中特定位置的坐标,再通过插值就得到了最后的结果。

Differentiable Image Sampling :

通过 Parameterised Sampling Grid 获得T_theta(G)后,我们就需要考虑如何生成output feature map V 了。

其中T_theta(G)其实就是一系列(xi_s,yi_s)的值。

注意, 其实上面求Vi的公式中,我们并没有对这个U的H、W都进行遍历,代码实现中只取相邻的4个点。

优势

  • 无需标注关键点,通过backprop就可以完成对LocNet的训练,从而完成Rectification;
  • 仿射变换的参数是和输入图像相关的,由输入图像决定theta应该用什么值,即由输入图像自己决定如何对自己进行仿射变换;
  • 可以方便地嵌入已有的model当中;

缺陷

  • 并不好训练,有时候加了STN反而会使得结果更糟糕,这个时候可以尝试重新初始化整个model进行训练;
  • input feature map的size必须是固定的,因为求theta的LocNet是一个fixed的model,也即里面W的维数是对应着input feature map的dimension的;

Code

Note: 本文图表公式来自以下几个参考文献。

Refs


林建民-机器视觉
Blog地址:http://www.linjm.tech/
旧博客地址:http://blog.csdn.net/linj_m