> 文档中心 > 一文理解深度学习框架中的InstanceNorm

一文理解深度学习框架中的InstanceNorm

91416067f573651e937cb82111718ce0.png

撰文|梁德澎

本文首发于公众号GiantPandaCV

本文主要推导 InstanceNorm 关于输入和参数的梯度公式,同时还会结合 PyTorch 和 MXNet 里的 InstanceNorm 代码来分析。

1

InstanceNorm 与 BatchNorm 的联系

对一个形状为 (N, C, H, W) 的张量应用 InstanceNorm[4] 操作,其实等价于先把该张量 reshape 为 (1, N * C, H, W)的张量,然后应用 BatchNorm[5] 操作。而 gamma 和 beta 参数的每个通道所对应输入张量的位置都是一致的。

而 InstanceNorm 与 BatchNorm 不同的地方在于:

  • InstanceNorm 训练与预测阶段行为一致,都是利用当前 batch 的均值和方差计算

  • BatchNorm 训练阶段利用当前 batch 的均值和方差,测试阶段则利用训练阶段通过移动平均统计的均值和方差

论文[6]中的一张示意图,就很好地解释了两者的联系:

d9d0599076cfccca2b66494d678c547e.png

https://arxiv.org/pdf/1803.08494.pdf
所以 InstanceNorm 对于输入梯度和参数求导过程与 BatchNorm 类似,下面开始进入正题。

2

梯度推导过程详解

在开始推导梯度公式之前,首先约定输入,参数,输出等符号:

  • 输入张量 , 形状为(N, C, H, W),rehape 为 (1, N * C, M) 其中 M=H*W

  • 参数 ,形状为 (1, C, 1, 1),每个通道值对应 N*M 个输入,在计算的时候首先通过在第0维 repeat N 次再 reshape 成 (1, N*C, 1, 1)

  • 参数 ,形状为 (1, C, 1, 1),每个通道值对应 N*M 个输入,在计算的时候首先通过在第0维 repeat N 次再 reshape 成 (1, N*C, 1, 1)

而输入张量 reshape 成 (1, N * C, M)之后,每个通道上是一个长度为 M 的向量,这些向量之间的计算是不像干的,每个向量计算自己的 normalize 结果。所以求导也是各自独立。因此下面的均值、方差符号约定和求导也只关注于其中一个向量,其他通道上的向量计算都是一样的。

  • 一个向量上的均值 

  • 一个向量上的方差 

  • 一个向量上一个点的 normalize 中间输出 

  • 一个向量上一个点的 normalize 最终输出 ,其中  和  表示这个向量所对应的 gamma 和 beta 参数的通道值。

  • loss 函数的符号约定为 

gamma 和 beta 参数梯度的推导

先计算简单的部分,求 loss 对  和  的偏导:

95aaef6d2174952a72a4b5e4b829103a.png

其中  表示 gamma 和 beta 参数的第  个通道参与了哪些 batch 上向量的 normalize 计算。

因为 gamma 和 beta 上的每个通道的参数都参与了 N 个 batch 上 M 个元素 normalize 的计算,所以对每个通道进行求导的时候,需要把所有涉及到的位置的梯度都累加在一起。

对于  在具体实现的时候,就是对应输出梯度的值,也就是从上一层回传回来的梯度值。

输入梯度的推导

对输入梯度的求导是最复杂的,下面的推导都是求 loss 相对于输入张量上的一个点上的梯度,而因为上文已知,每个长度是 M 的向量的计算都是独立的,所以下文也是描述其中一个向量上一个点的梯度公式。具体是计算的时候,是通过向量操作(比如 numpy)来完成所有点的梯度计算。

先看 loss 函数对于  的求导:

3dfcb3547495b2c2d74f8ffcc1221a79.png

而从上文约定的公式可知,对于 

402 Payment Required

 的计算中涉及到  的有三部分,分别是 、 和 。所以 loss 对于 的偏导可以写成以下的形式:

ffeb06e4c1591a7fd67c1443be43ec7d.png

接下来就是,分别求上面式子最后三项的梯度公式。

第一项梯度推导

在求第一项的时候,把  和  看做常量,则有:

4b259226756d60b31a205cc0484fa49b.png

然后有:

cd761b1ddd9b1e84d9582b6a21240396.png

最后可得第一项梯度公式:

0e4e2b686304bdbf7a74945c9218c195.png

第三项梯度推导

接着先看第三项梯度eb4a67ad377f179d0b21872bee50f021.png,因为第三项的推导形式简单一些。

先计算上式最后一项 ,把  看做常量:

20c6d53002f4f6a1314a07fedf1f55f2.png

然后计算0cc9929c31f071f402c66ce4901c4344.png,等价于求 。而因为每个长度是 M 的向量都会计算一个方差 ,而计算出来的方差又会参数到所有 M 个元素的 normalize 的计算,所以 loss 对于  的偏导需要把所有 M 个位置的梯度累加,所以有:

bd69381bb6f078bf0aa251dd2f740512.png

接着计算 

402 Payment Required

2938c3a07a8cc80d9346f736facf881a.png

最后可得:

e632954a494cb40e144e968875ef569f.png

第二项梯度推导

最后计算第二项的梯度d19db8a071123e72cbc80ac1ed8d478f.png,一样先计算最后一项 :

7607cdd181d70e6ac97abd34db0e0f5a.png

接着计算17588e6689245345541d9d5539b33b56.png,等价于是求 。而因为每个长度是 M 的向量都会计算一个均值 ,而计算出来的均值又会参与到所有 M 个元素的 normalize 的计算,所以 loss 对于  的偏导需要把所有 M 个位置的梯度累加,所以有:

86a923f5f2890bec03ae67ff324fc238.png

接着计算 ,

847410632990f68e9f4521bfcd41a187.png

最后可得:

6f590faabdcfa82c5ef355fe58b6ff65.png

输入梯度最终的公式

分别计算完上面三项,就能得到对于输入张量每个位置上梯度的最终公式了:

338f98e1554667488bf105386a86dee5.png

观察上式可以发现,loss 对  的求导公式包括了 loss 对  求导的公式,所以这也是为什么先计算第三项的原因,在下面代码实现上也可以体现。

而在具体实现的时候就是直接套公式计算就可以了,下面来看下在 PyTroch 和 MXNet 框架中对 InstanceNorm 的实现。

3

深度学习框架实现代码解读

PyTroch 前向传播实现

前向传播代码链接:

https://github.com/pytorch/pytorch/blob/fa153184c8f70259337777a1fd1d803c7325f758/aten%2Fsrc%2FATen%2Fnative%2FNormalization.cpp#L506

为了可读性简化了些代码:

Tensor instance_norm(    const Tensor& input,     const Tensor& weight/* optional */,     const Tensor& bias/* optional */,    const Tensor& running_mean/* optional */,     const Tensor& running_var/* optional */,    bool use_input_stats,     double momentum,     double eps,     bool cudnn_enabled) {  // ......  std::vector shape =     input.sizes().vec();  int64_t b = input.size(0);  int64_t c = input.size(1);  // shape 从 (b, c, h, w)  // 变为 (1, b*c, h, w)  shape[1] = b * c;  shape[0] = 1;  // repeat_if_defined 的解释见下文  Tensor weight_ =repeat_if_defined(weight, b);  Tensor bias_ =repeat_if_defined(bias, b);  Tensor running_mean_ =repeat_if_defined(running_mean, b);  Tensor running_var_ =repeat_if_defined(running_var, b);  // 改变输入张量的形状  auto input_reshaped =input.contiguous().view(shape);  // 计算实际调用的是 batchnorm 的实现  // 所以可以理解为什么 pytroch   // 前端 InstanceNorm2d 的接口  // 与 BatchNorm2d 的接口一样  auto out = at::batch_norm(    input_reshaped,     weight_, bias_,     running_mean_,     running_var_,    use_input_stats,     momentum,    eps, cudnn_enabled);  // ......  return out.view(input.sizes());}

repeat_if_defined 的代码:

https://github.com/pytorch/pytorch/blob/fa153184c8f70259337777a1fd1d803c7325f758/aten%2Fsrc%2FATen%2Fnative%2FNormalization.cpp#L27

static inline Tensor repeat_if_defined(  const Tensor& t,   int64_t repeat) {  if (t.defined()) {    // 把 tensor 按第0维度复制 repeat 次    return t.repeat(repeat);  }  return t;}

从 pytorch 前向传播的实现上看,验证了本文开头说的关于 InstanceNorm 与 BatchNorm 的联系。还有对于参数 gamma 与 beta 的处理方式。

MXNet 反向传播实现

因为我个人感觉 MXNet InstanceNorm 的反向传播实现很直观,所以选择解读其实现:

https://github.com/apache/incubator-mxnet/blob/4a7282f104590023d846f505527fd0d490b65509/src%2Foperator%2Finstance_norm-inl.h#L112

同样为了可读性简化了些代码:

templatevoid InstanceNormBackward(    const nnvm::NodeAttrs& attrs,    const OpContext &ctx,    const std::vector &inputs,    const std::vector &req,    const std::vector &outputs) {  using namespace mshadow;  using namespace mshadow::expr;  // ......  const InstanceNormParam& param =nnvm::get( attrs.parsed);  Stream *s =ctx.get_stream();  // 获取输入张量的形状  mxnet::TShape dshape =inputs[3].shape_;  // ......  int n = inputs[3].size(0);  int c = inputs[3].size(1);  // rest_dim 就等于上文的 M  int rest_dim =      static_cast( inputs[3].Size() / n / c);  Shape s2 = Shape2(n * c, rest_dim);  Shape s3 = Shape3(n, c, rest_dim);  // scale 就等于上文的 1/M  const real_t scale =static_cast(1) /  static_cast(rest_dim);  // 获取输入张量  Tensor data = inputs[3]   .get_with_shape(s2, s);  // 保存输入梯度  Tensor gdata = outputs[kData]   .get_with_shape(s2, s);  // 获取参数 gamma   Tensor gamma =      inputs[4].get(s);  // 保存参数 gamma 梯度计算结果  Tensor ggamma = outputs[kGamma]      .get(s);  // 保存参数 beta 梯度计算结果  Tensor gbeta = outputs[kBeta]      .get(s);  // 获取输出梯度  Tensor gout = inputs[0]      .get_with_shape( s2, s);  // 获取前向计算好的均值和方差  Tensor var =     inputs[2].FlatTo1D(s);  Tensor mean =     inputs[1].FlatTo1D(s);  // 临时空间  Tensor workspace = //.....  // 保存均值的梯度  Tensor gmean = workspace[0];  // 保存方差的梯度  Tensor gvar = workspace[1];  Tensor tmp = workspace[2];  // 计算方差的梯度,  // 对应上文输入梯度公式的第三项  // gout 对应输出梯度  gvar = sumall_except_dim(    (gout * broadcast(      reshape(repmat(gamma, n),Shape1(n * c)), data.shape_)) *      (data - broadcast( mean, data.shape_)) * -0.5f *      F( broadcast( var + param.eps, data.shape_),-1.5f)    );  // 计算均值的梯度,  // 对应上文输入梯度公式的第二项  gmean = sumall_except_dim(    gout * broadcast(      reshape(repmat(gamma, n),Shape1(n * c)), data.shape_));  gmean *=     -1.0f / F(      var + param.eps);  tmp = scale * sumall_except_dim( -2.0f * (data - broadcast(   mean, data.shape_)));  tmp *= gvar;  gmean += tmp;  // 计算 beta 的梯度  // 记得s3 = Shape3(n, c, rest_dim)  // 那么swapaxis(reshape(gout, s3))  // 就表示首先把输出梯度 reshape 成  // (n, c, rest_dim),接着交换第0和1维度  // (c, n, rest_dim),最后求除了第0维度  // 之外其他维度的和,  // 也就和 beta 的求导公式对应上了  Assign(gbeta, req[kBeta],    sumall_except_dim(swapaxis(reshape(gout, s3))));// 计算 gamma 的梯度  // swapaxis 的作用与上面 beta 一样  Assign(ggamma, req[kGamma],    sumall_except_dim(      swapaxis( reshape(gout * (data - broadcast(mean,   data.shape_))   / F(      broadcast(var + param.eps,data.shape_      )    ), s3 )      )    )  );  // 计算输入的梯度,  // 对应上文输入梯度公式三项的相加  Assign(gdata, req[kData],    (gout * broadcast(      reshape(repmat(gamma, n),Shape1(n * c)), data.shape_))      * broadcast(1.0f /F( var + param.eps), data.shape_)   + broadcast(gvar, data.shape_)* scale * 2.0f* (data - broadcast( mean, data.shape_))+ broadcast(gmean,data.shape_) * scale);}

可以看到基于 mshadow 模板库的反向传播实现,看起来很直观,基本是和公式能对应上的。

4

InstanceNorm numpy 实现

最后看下 InstanceNorm 前向计算与求输入梯度的 numpy 实现:

import numpy as npimport torcheps = 1e-05batch = 4channel = 2height = 32width = 32input = np.random.random(    size=(batch, channel, height, width)).astype(np.float32)# gamma 初始化为1# beta 初始化为0,所以忽略了gamma = np.ones((1, channel, 1, 1),     dtype=np.float32)# 随机生成输出梯度gout = np.random.random(    size=(batch, channel, height, width))\    .astype(np.float32)# 用numpy计算前向的结果mean_np = np.mean(  input, axis=(2, 3), keepdims=True)in_sub_mean = input - mean_npvar_np = np.mean(    np.square(in_sub_mean),axis=(2, 3), keepdims=True)invar_np = 1.0 / np.sqrt(var_np + eps)out_np = in_sub_mean * invar_np * gamma# 用numpy计算输入梯度scale = 1.0 / (height * width)# 对应输入梯度公式第三项gvar =   gout * gamma * in_sub_mean *   -0.5 * np.power(var_np + eps, -1.5)gvar = np.sum(gvar, axis=(2, 3),keepdims=True)# 对应输入梯度公式第二项gmean = np.sum(    gout * gamma,     axis=(2, 3), keepdims=True)gmean *= -invar_nptmp = scale * np.sum(-2.0 * in_sub_mean,axis=(2, 3), keepdims=True) gmean += tmp * gvar# 对应输入梯度公式三项之和gin_np =   gout * gamma * invar_np    + gvar * scale * 2.0 * in_sub_mean    + gmean * scale# pytorch 的实现p_input_tensor =   torch.tensor(input, requires_grad=True)trans = torch.nn.InstanceNorm2d(  channel, affine=True, eps=eps)p_output_tensor = trans(p_input_tensor)p_output_tensor.backward(  torch.Tensor(gout))# 与 pytorch 对比结果print(np.allclose(out_np,   p_output_tensor.detach().numpy(),   atol=1e-5))print(np.allclose(gin_np,   p_input_tensor.grad.numpy(),   atol=1e-5))# 命令行输出# True# True

本文对于 InstanceNorm 的梯度公式推导大部分参考了博客[1][2]的内容,然后在参考博客的基础上,按自己的理解具体推导了一遍,很多时候是从结果往回推,如果有什么疑惑或意见,欢迎交流。

参考资料

[1]https://medium.com/@drsealks/batch-normalisation-formulas-derivation-253df5b75220

[2]https://kevinzakka.github.io/2016/09/14/batch_normalization/

[3]https://www.zhihu.com/question/68730628

[4]https://arxiv.org/pdf/1607.08022.pdf

[5]https://arxiv.org/pdf/1502.03167v3.pdf

[6]https://arxiv.org/pdf/1803.08494.pdf

其他人都在看

  • 一个黑客“沦落”为搬砖的CVer

  • 岁末年初,为你打包了一份技术合订本

  • 一文轻松掌握深度学习框架中的einsum

  • 对抗软件系统复杂性:恰当分层,不多不少

  • 计算机史最疯狂一幕:“蓝色巨人”奋身一跃

  • 30年做成三家独角兽公司,AI芯片创业的底层逻辑

欢迎下载体验OneFlow新一代开源深度学习框架:https://github.com/Oneflow-Inc/oneflow/icon-default.png?t=M276https://github.com/Oneflow-Inc/oneflow/