最近在阅读论文《Network Pruning via Transformable Architecture Search》的源码,其主要实现了结构化自动裁剪神经网络的逻辑(也算是autoML的部分),由 pytorch 实现。特此记录。
1 先说几个值得学习的地方
解耦。eg,网络层解耦;声明专门的类用于神经网络层的各种定义,包括网络层(输入输出),网络的连接,前向传播逻辑,优化器的定义等等。定义各个网络层之后,在 forwords 方法中构建网络。如下所示。
1 2 3 4 5 6 7 8 9 def basic_forward (self, inputs) : if self.InShape is None : self.InShape = (inputs.size(-2 ), inputs.size(-1 )) x = inputs for i, layer in enumerate(self.layers): x = layer( x ) features = self.avgpool(x) features = features.view(features.size(0 ), -1 ) logits = self.classifier(features) return features, logits
充分使用日志。这个是之前自己做的不好的部分,往往只会在调试时才开始加入全面的日志。
其实网络层的解耦是最最基础的,通过阅读 github 源码你会发现但凡是一些知名学者的代码都是进行良好的封装,实现各个模块解耦以便于后期快速增减各种逻辑,不过对于初学者来说阅读这样的代码就会感觉项目异常庞大,因为存在许多不相关的逻辑。某种程度上来说,过高的抽象封装可能会导致可阅读性的降低。
2 架构搜索思想分析 论文提出到的 NAS 架构搜索主要是搜索网络的宽度和深度。宽度指的是各层的 channal 数量(层输出),深度指的是网络的层数。论文着重介绍了如何求解各层最优的 channal。
3 重点剖析 论文的重点在于如何实现自动裁剪网络结构,核心代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 def search_forward (self, inputs) : flop_probs = nn.functional.softmax(self.width_attentions, dim=1 ) selected_widths, selected_probs = select2withP(self.width_attentions, self.tau) with torch.no_grad(): selected_widths = selected_widths.cpu() x, last_channel_idx, expected_inC, flops = inputs, 0 , 3 , [] for i, layer in enumerate(self.layers): selected_w_index = selected_widths[last_channel_idx: last_channel_idx+layer.num_conv] selected_w_probs = selected_probs[last_channel_idx: last_channel_idx+layer.num_conv] layer_prob = flop_probs[last_channel_idx: last_channel_idx+layer.num_conv] x, expected_inC, expected_flop = layer( (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) ) last_channel_idx += layer.num_conv flops.append( expected_flop ) flops.append( expected_inC * (self.classifier.out_features*1.0 /1e6 ) ) features = self.avgpool(x) features = features.view(features.size(0 ), -1 ) logits = linear_forward(features, self.classifier) return logits, torch.stack( [sum(flops)] )
下面一行一行进行分析。
3.1 line 1 先看下self.width_attentions
是个什么东西,其相关定义在类SearchWidthCifarResNet
中:
1 2 3 4 5 6 7 8 9 10 11 12 self.Ranges = [] self.layer2indexRange = [] for i, layer in enumerate(self.layers): start_index = len(self.Ranges) self.Ranges += layer.get_range() self.layer2indexRange.append( (start_index, len(self.Ranges)) ) assert len(self.Ranges) + 1 == depth, 'invalid depth check {:} vs {:}' .format(len(self.Ranges) + 1 , depth)self.register_parameter('width_attentions' , nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None )))) nn.init.normal_(self.width_attentions, 0 , 0.01 ) self.apply(initialize_resnet)
这里,width_attentions是一个参数化的 tensor,而 tensor 的参数是(len(self.Ranges, get_choices(None) ))。从上述代码中可以看到,self.Ranges是累加了各个 layer的 get_range() ,所以 len(self.Ranges) 的值也就是网络层数。
再看下get_choices的实现,因为传入的 None,所以 tensor 的参数实际上是(网络层数, 8))。
1 2 3 4 5 6 7 8 9 10 11 from .SoftSelect import get_width_choices as get_choicesdef get_width_choices (nOut) : xsrange = [0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.9 , 1.0 ] if nOut is None : return len(xsrange) else : Xs = [int(nOut * i) for i in xsrange] Xs = sorted( list( set(Xs) ) ) return tuple(Xs)
这里我们再看下self.Ranges,其通过layer.get_range()获得,而layer在类SearchWidthCifarResNet
初始化的过程中定义:
1 2 3 4 5 6 7 8 9 10 11 self.layers = nn.ModuleList( [ ConvBNReLU(3 , 16 , 3 , 1 , 1 , False , has_avg=False , has_bn=True , has_relu=True ) ] ) self.InShape = None for stage in range(3 ): for iL in range(layer_blocks): iC = self.channels[-1 ] planes = 16 * (2 **stage) stride = 2 if stage > 0 and iL == 0 else 1 module = block(iC, planes, stride) self.channels.append( module.out_dim ) self.layers.append ( module ) self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}" .format(stage, iL, layer_blocks, len(self.layers)-1 , iC, module.out_dim, stride)
可以看到,layer实际上是类ConvBNReLU,看看类ConvBNReLU是怎么实现get_range()的。
1 2 3 4 5 6 7 8 9 class ConvBNReLU (nn.Module) : def __init__ (self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu) : super(ConvBNReLU, self).__init__() self.InShape = None self.OutShape = None self.choices = get_choices(nOut) self.register_buffer('choices_tensor' , torch.Tensor( self.choices )) def get_range (self) : return [self.choices]
可以看到,get_range()的底层是通过函数get_choices实现的,通过参数nOut进行控制。参数nOut是类ConvBNReLU初始化的时候传入的。在ConvBNReLU初始化的过程中,nOut 表示当前网络层的输出维度。当 nOut != None时,get_width_choices()函数返回的是大小为nOut的 tuple。