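SkipNet starts from this assumption: it attaches a gate module to every layer (or module) of a conventional CNN, and the gate decides whether that layer's computation actually needs to run. If the answer is no, the activation feature maps are passed straight on to the next layer and the current layer's computation is skipped. This clearly reduces the time a conventional CNN needs for inference once deployed.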
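Once trained, SkipNet decides at inference time, based on the incoming feature maps, whether to execute each layer of the network. The figure below illustrates this fundamental property.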
For each layer, SkipNet can be written as x_{i+1} = G_i F_i(x_i) + (1 - G_i) x_i, where x_i and F_i(x_i) are the input and output feature maps of the i-th layer, and G_i ∈ {0, 1} is the gate function of the i-th layer.
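The per-layer rule above can be sketched in a few lines of plain Python. This is only an illustration: `skip_step`, `double`, and the list-of-floats stand-in for feature maps are hypothetical names chosen here, not part of the paper's code.

```python
from typing import Callable, List

Vector = List[float]  # stand-in for a feature map (assumption for illustration)


def skip_step(x: Vector, layer: Callable[[Vector], Vector], gate: int) -> Vector:
    """Apply x_{i+1} = G_i * F_i(x_i) + (1 - G_i) * x_i for a binary gate."""
    if gate not in (0, 1):
        raise ValueError("gate must be 0 or 1")
    if gate == 0:
        # Skipped: the layer is never evaluated, which is where the savings come from.
        return list(x)
    return layer(x)


def double(x: Vector) -> Vector:
    """Toy 'layer' F_i that doubles every activation."""
    return [2.0 * v for v in x]


x = [1.0, -2.0, 3.0]
executed = skip_step(x, double, gate=1)  # layer output is used
skipped = skip_step(x, double, gate=0)   # input is passed through unchanged
```

Because the gate is binary, the weighted sum in the formula reduces to a plain branch at inference time, so a skipped layer costs nothing beyond the gate itself.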
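The authors experimented with two ways of realizing the gate function. The SkipNet in the paper is built on ResNet. A gate can be attached to each residual block as an independent module with its own parameters (a feed-forward gate), or all residual blocks can share a single gate module (a recurrent gate). The figure below shows the difference.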
Gate module design
The paper tries three gate module designs, which trade off computation against accuracy in slightly different ways.
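FFGate-I: MaxPool(2x2) -> Conv(3x3, stride 1) -> Conv(3x3, stride 2) -> AvgPool -> FC. Its computation is about 19% of that of a residual block; in the paper it is mainly used for shallower networks (fewer than 100 layers);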
FFGate-II: Conv(3x3, stride 2) -> AvgPool -> FC. Its computation is about 12.5% of that of a residual block; it is mainly used for deeper networks (more than 100 layers);
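RNNGate: AvgPool -> Conv(1x1) -> LSTM(10 hidden units) -> FC. Its computation is about 0.04% of that of a residual block, and it is the preferred gate function in the paper. On deep networks it has a clear advantage over the feed-forward gates in both performance and classification accuracy; on shallower networks its accuracy is slightly lower, but it still has a large advantage in computational cost.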
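To get a feel for what these overheads mean, here is a rough back-of-the-envelope sketch. It assumes that the gate runs for every block and that the block itself runs only when it is not skipped; the 30% skip rate and the names `GATE_OVERHEAD` and `relative_cost` are made up for illustration and do not come from the paper.

```python
# Gate cost as a fraction of one residual block, taken from the list above.
GATE_OVERHEAD = {"FFGate-I": 0.19, "FFGate-II": 0.125, "RNNGate": 0.0004}


def relative_cost(gate_overhead: float, skip_rate: float) -> float:
    """Expected cost of one gated block relative to an ungated block.

    Assumption: the gate is always evaluated, and the block is evaluated
    with probability (1 - skip_rate).
    """
    if not 0.0 <= skip_rate <= 1.0:
        raise ValueError("skip_rate must be in [0, 1]")
    return gate_overhead + (1.0 - skip_rate)


# Hypothetical skip rate of 30% for every block.
costs = {name: relative_cost(o, skip_rate=0.30) for name, o in GATE_OVERHEAD.items()}

for name, cost in costs.items():
    print(f"{name}: {cost:.4f}x the cost of an ungated block")
```

Under this simple model a gate only pays for itself when the skip rate exceeds its own overhead, which is one way to read why the nearly free RNNGate is attractive.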
The figure below summarizes the three gate module designs.
Learning the skipping policy with hybrid RL
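The gate function from the previous section can be viewed as a policy: Π(x_i, i) = P(G_i(x_i) = g_i), where g_i ∈ {0, 1} is a discrete decision: g_i = 1 executes the i-th layer and g_i = 0 skips it.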
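Viewed this way, each gate is a Bernoulli decision. The sketch below shows the two usual ways of turning execution probabilities into decisions: sampling, as an RL policy would during training, and thresholding at 0.5. The function names and the probabilities are hypothetical, chosen only for illustration.

```python
import random
from typing import List


def sample_gates(execute_probs: List[float], rng: random.Random) -> List[int]:
    """Draw g_i ~ Bernoulli(p_i), where p_i = P(G_i(x_i) = 1)."""
    return [1 if rng.random() < p else 0 for p in execute_probs]


def greedy_gates(execute_probs: List[float], threshold: float = 0.5) -> List[int]:
    """Deterministic decision: execute only if p_i is strictly above the threshold."""
    return [1 if p > threshold else 0 for p in execute_probs]


probs = [0.9, 0.2, 0.6, 0.5]          # made-up execution probabilities
rng = random.Random(0)                # seeded so runs are repeatable
sampled = sample_gates(probs, rng)    # stochastic decisions
greedy = greedy_gates(probs)          # thresholded decisions
```

Sampling keeps the policy stochastic, which is what a policy-gradient method needs for exploration; thresholding gives a single repeatable decision sequence.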
For a CNN with N layers, the forward pass on an input x therefore has to produce a decision sequence g = [g_1, ..., g_N] ~ Π(F_θ), where F_θ = [F_θ^1, ..., F_θ^N] denotes the computation of the N layers of the network.
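In the training objective, R_i = (1 - g_i) C_i is the computation saved by each gate module, and it also serves as that gate's reward. Because the paper uses ResNet, all C_i are assumed to be equal and are set to 1. α is the coefficient that balances classification accuracy against computational savings. The objective thus accounts for both classification accuracy and computational efficiency and tries to strike a balance between them.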
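As a small worked example of the reward, the sketch below computes R_i for a decision sequence and combines it with a task loss. The combination `task_loss - (alpha / N) * sum(R_i)` is an assumed form that is consistent with the description above; the normalization by N, the function names, and the numbers are illustrative assumptions.

```python
from typing import List, Optional


def skip_rewards(gates: List[int], costs: Optional[List[float]] = None) -> List[float]:
    """R_i = (1 - g_i) * C_i; with C_i = 1 a skipped layer earns reward 1."""
    if costs is None:
        costs = [1.0] * len(gates)  # all C_i set to 1, as stated for ResNet
    return [(1 - g) * c for g, c in zip(gates, costs)]


def hybrid_objective(task_loss: float, gates: List[int], alpha: float) -> float:
    """Assumed per-sample objective: task loss minus the scaled average saving."""
    rewards = skip_rewards(gates)
    return task_loss - alpha / len(gates) * sum(rewards)


gates = [1, 0, 1, 0]                 # execute, skip, execute, skip
rewards = skip_rewards(gates)        # one reward per gate
value = hybrid_objective(task_loss=0.8, gates=gates, alpha=0.1)
```

A larger α makes skipping more attractive relative to the task loss, so it pushes the policy toward cheaper but potentially less accurate execution paths.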
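The gradient computed in practice consists of two parts. The first is the supervised loss that learns classification accuracy; the second is the skip-learning policy term, learned with RL, that reflects the computational savings.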
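As a sketch of how these two parts arise, assume the objective has the form J(θ) = E_x E_g L_θ(g, x) with L_θ(g, x) = L(ŷ(x, F_θ, g), y) - (α/N) Σ_i R_i. This form, the symbols L, ŷ, y and p_θ, and the reward-to-go r_i are assumptions introduced here for illustration rather than taken from the text. Applying the product rule and the log-derivative (REINFORCE) identity then gives:

```latex
\begin{aligned}
\nabla_\theta J(\theta)
  &= \mathbb{E}_x \, \nabla_\theta \sum_{g} p_\theta(g \mid x)\, L_\theta(g, x) \\
  &= \mathbb{E}_x \mathbb{E}_g \, \nabla_\theta L_\theta(g, x)
   + \mathbb{E}_x \mathbb{E}_g \, L_\theta(g, x)\, \nabla_\theta \log p_\theta(g \mid x) \\
  &= \underbrace{\mathbb{E}_x \mathbb{E}_g \, \nabla_\theta
       \mathcal{L}\big(\hat{y}(x, F_\theta, g),\, y\big)}_{\text{supervised part}}
   \;-\; \underbrace{\mathbb{E}_x \mathbb{E}_g \sum_{i=1}^{N}
       r_i \, \nabla_\theta \log p_\theta(g_i \mid x)}_{\text{policy (RL) part}}, \\
  r_i &= -\Big[\, \mathcal{L} - \frac{\alpha}{N} \sum_{j=i}^{N} R_j \Big].
\end{aligned}
```

The last step uses the fact that the rewards R_j depend only on the sampled decisions g, not directly on θ, and that dropping the rewards of gates before position i leaves the expectation unchanged while reducing variance.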
The figure below outlines the hybrid RL algorithm.
Below is the implementation of FFGate-I; the other gate modules are written in much the same way.
    # Feedforward-Gate (FFGate-I)
    import math

    import torch.nn as nn


    def conv3x3(in_planes, out_planes, stride=1):
        """3x3 convolution with padding 1.

        Helper not shown in the original snippet; assumed here to be the
        usual ResNet-style 3x3 convolution.
        """
        return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                         padding=1, bias=False)


    class FeedforwardGateI(nn.Module):
        """Use max pooling first, then apply two conv layers.
        The first conv has stride = 1 and the second has stride = 2."""

        def __init__(self, pool_size=5, channel=10):
            super(FeedforwardGateI, self).__init__()
            self.pool_size = pool_size
            self.channel = channel

            self.maxpool = nn.MaxPool2d(2)
            self.conv1 = conv3x3(channel, channel)
            self.bn1 = nn.BatchNorm2d(channel)
            self.relu1 = nn.ReLU(inplace=True)

            # adding another conv layer
            self.conv2 = conv3x3(channel, channel, stride=2)
            self.bn2 = nn.BatchNorm2d(channel)
            self.relu2 = nn.ReLU(inplace=True)

            pool_size = math.floor(pool_size / 2)        # for max pooling
            pool_size = math.floor(pool_size / 2 + 0.5)  # for conv stride = 2

            self.avg_layer = nn.AvgPool2d(pool_size)
            # 1x1 conv acting as the FC layer: two logits (skip / execute)
            self.linear_layer = nn.Conv2d(in_channels=channel, out_channels=2,
                                          kernel_size=1, stride=1)
            # dim=1 is given explicitly; current PyTorch warns without it
            self.prob_layer = nn.Softmax(dim=1)
            self.logprob = nn.LogSoftmax(dim=1)

        def forward(self, x):
            x = self.maxpool(x)
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu1(x)

            x = self.conv2(x)
            x = self.bn2(x)
            x = self.relu2(x)

            x = self.avg_layer(x)
            # (N, 2, 1, 1) -> (N, 2); unlike squeeze(), flatten(1) keeps the
            # batch dimension when N == 1
            x = self.linear_layer(x).flatten(1)
            softmax = self.prob_layer(x)
            logprob = self.logprob(x)

            # discretize output in forward pass.
            # use softmax gradients in backward pass
            # (the forward value is the hard 0/1 decision, while the gradient
            # flows only through softmax[:, 1])
            x = (softmax[:, 1] > 0.5).float().detach() - \
                softmax[:, 1].detach() + softmax[:, 1]

            x = x.view(x.size(0), 1, 1, 1)
            return x, logprob
The class below shows how a gate module is combined with a CNN to build the corresponding SkipNet.
    # Note: conv3x3, RLFeedforwardGateI, RLFeedforwardGateII and
    # repackage_hidden are used below but are not shown in this article.
    class ResNetFeedForwardRL(nn.Module):
        """Adding gating module on every basic block"""

        def __init__(self, block, layers, num_classes=10,
                     gate_type='ffgate1', **kwargs):
            super(ResNetFeedForwardRL, self).__init__()
            self.inplanes = 16

            self.num_layers = layers
            self.conv1 = conv3x3(3, 16)
            self.bn1 = nn.BatchNorm2d(16)
            self.relu = nn.ReLU(inplace=True)

            self.gate_instances = []
            self.gate_type = gate_type
            self._make_group(block, 16, layers[0], group_id=1,
                             gate_type=gate_type, pool_size=32)
            self._make_group(block, 32, layers[1], group_id=2,
                             gate_type=gate_type, pool_size=16)
            self._make_group(block, 64, layers[2], group_id=3,
                             gate_type=gate_type, pool_size=8)

            # remove the last gate instance, (not optimized)
            del self.gate_instances[-1]

            self.avgpool = nn.AvgPool2d(8)
            self.fc = nn.Linear(64 * block.expansion, num_classes)
            # dim=1 is given explicitly; current PyTorch warns without it
            self.softmax = nn.Softmax(dim=1)
            self.saved_actions = []
            self.rewards = []

            # weight initialization
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
                elif isinstance(m, nn.Linear):
                    n = m.weight.size(0) * m.weight.size(1)
                    m.weight.data.normal_(0, math.sqrt(2. / n))

        def _make_group(self, block, planes, layers, group_id=1,
                        gate_type='fisher', pool_size=16):
            """ Create the whole group"""
            for i in range(layers):
                # the first block of groups 2 and 3 downsamples
                if group_id > 1 and i == 0:
                    stride = 2
                else:
                    stride = 1

                meta = self._make_layer_v2(block, planes, stride=stride,
                                           gate_type=gate_type,
                                           pool_size=pool_size)

                setattr(self, 'group{}_ds{}'.format(group_id, i), meta[0])
                setattr(self, 'group{}_layer{}'.format(group_id, i), meta[1])
                setattr(self, 'group{}_gate{}'.format(group_id, i), meta[2])

                # add into gate instance collection
                self.gate_instances.append(meta[2])

        def _make_layer_v2(self, block, planes, stride=1,
                           gate_type='fisher', pool_size=16):
            """ create one block and optional a gate module """
            downsample = None
            if stride != 1 or self.inplanes != planes * block.expansion:
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes * block.expansion,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(planes * block.expansion),
                )
            layer = block(self.inplanes, planes, stride, downsample)
            self.inplanes = planes * block.expansion

            if gate_type == 'ffgate1':
                gate_layer = RLFeedforwardGateI(pool_size=pool_size,
                                                channel=planes*block.expansion)
            elif gate_type == 'ffgate2':
                gate_layer = RLFeedforwardGateII(pool_size=pool_size,
                                                 channel=planes*block.expansion)
            else:
                gate_layer = None

            if downsample:
                return downsample, layer, gate_layer
            else:
                return None, layer, gate_layer

        def repackage_vars(self):
            self.saved_actions = repackage_hidden(self.saved_actions)

        def forward(self, x, reinforce=False):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)

            masks = []
            gprobs = []
            # must pass through the first layer in first group
            x = getattr(self, 'group1_layer0')(x)
            # gate takes the output of the current layer
            mask, gprob = getattr(self, 'group1_gate0')(x)
            gprobs.append(gprob)
            masks.append(mask.squeeze())
            prev = x  # input of next layer

            for g in range(3):
                for i in range(0 + int(g == 0), self.num_layers[g]):
                    if getattr(self, 'group{}_ds{}'.format(g+1, i)) is not None:
                        prev = getattr(self, 'group{}_ds{}'.format(g+1, i))(prev)
                    x = getattr(self, 'group{}_layer{}'.format(g+1, i))(x)
                    # new mask is taking the current output
                    # (mask == 1 keeps the block output, mask == 0 keeps prev)
                    prev = x = mask.expand_as(x) * x \
                        + (1 - mask).expand_as(prev) * prev
                    mask, gprob = getattr(self, 'group{}_gate{}'.format(g+1, i))(x)
                    gprobs.append(gprob)
                    masks.append(mask.squeeze())

            # the decision of the last gate is never used
            del masks[-1]

            x = self.avgpool(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)

            # collect all actions
            for inst in self.gate_instances:
                self.saved_actions.append(inst.saved_action)

            if reinforce:  # for pure RL
                softmax = self.softmax(x)
                # num_samples is required by current PyTorch; 1 matches the
                # behaviour of the old no-argument call
                action = softmax.multinomial(num_samples=1)
                self.saved_actions.append(action)

            return x, masks, gprobs