最近在阅读论文《Network Pruning via Transformable Architecture Search》的源码,其主要实现了结构化自动裁剪神经网络的逻辑(也算是autoML的部分),由 pytorch 实现。特此记录。

1 先说几个值得学习的地方

  • 解耦。eg,网络层解耦;声明专门的类用于神经网络层的各种定义,包括网络层(输入输出),网络的连接,前向传播逻辑,优化器的定义等等。定义各个网络层之后,在 forwords 方法中构建网络。如下所示。
1
2
3
4
5
6
7
8
9
def basic_forward(self, inputs):
if self.InShape is None: self.InShape = (inputs.size(-2), inputs.size(-1))
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits
  • 充分使用日志。这个是之前自己做的不好的部分,往往只会在调试时才开始加入全面的日志。

其实网络层的解耦是最最基础的,通过阅读 github 源码你会发现但凡是一些知名学者的代码都是进行良好的封装,实现各个模块解耦以便于后期快速增减各种逻辑,不过对于初学者来说阅读这样的代码就会感觉项目异常庞大,因为存在许多不相关的逻辑。某种程度上来说,过高的抽象封装可能会导致可阅读性的降低。

2 架构搜索思想分析

论文提出到的 NAS 架构搜索主要是搜索网络的宽度和深度。宽度指的是各层的 channal 数量(层输出),深度指的是网络的层数。论文着重介绍了如何求解各层最优的 channal。

3 重点剖析

论文的重点在于如何实现自动裁剪网络结构,核心代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def search_forward(self, inputs):
flop_probs = nn.functional.softmax(self.width_attentions, dim=1)
selected_widths, selected_probs = select2withP(self.width_attentions, self.tau)
with torch.no_grad():
selected_widths = selected_widths.cpu()

x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
for i, layer in enumerate(self.layers):
selected_w_index = selected_widths[last_channel_idx: last_channel_idx+layer.num_conv]
selected_w_probs = selected_probs[last_channel_idx: last_channel_idx+layer.num_conv]
layer_prob = flop_probs[last_channel_idx: last_channel_idx+layer.num_conv]
x, expected_inC, expected_flop = layer( (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) )
last_channel_idx += layer.num_conv
flops.append( expected_flop )
flops.append( expected_inC * (self.classifier.out_features*1.0/1e6) )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = linear_forward(features, self.classifier)
return logits, torch.stack( [sum(flops)] )

下面一行一行进行分析。

3.1 line 1

先看下self.width_attentions是个什么东西,其相关定义在类SearchWidthCifarResNet中:

1
2
3
4
5
6
7
8
9
10
11
12
# parameters for width
self.Ranges = []
self.layer2indexRange = []
for i, layer in enumerate(self.layers):
start_index = len(self.Ranges)
self.Ranges += layer.get_range()
self.layer2indexRange.append( (start_index, len(self.Ranges)) )
assert len(self.Ranges) + 1 == depth, 'invalid depth check {:} vs {:}'.format(len(self.Ranges) + 1, depth)

self.register_parameter('width_attentions', nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))))
nn.init.normal_(self.width_attentions, 0, 0.01)
self.apply(initialize_resnet)

这里,width_attentions是一个参数化的 tensor,而 tensor 的参数是(len(self.Ranges, get_choices(None) ))。从上述代码中可以看到,self.Ranges是累加了各个 layer的 get_range() ,所以 len(self.Ranges) 的值也就是网络层数。

再看下get_choices的实现,因为传入的 None,所以 tensor 的参数实际上是(网络层数, 8))。

1
2
3
4
5
6
7
8
9
10
11
from .SoftSelect import get_width_choices as get_choices
def get_width_choices(nOut):
xsrange = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
if nOut is None:
return len(xsrange)
else:
Xs = [int(nOut * i) for i in xsrange]
#xs = [ int(nOut * i // 10) for i in range(2, 11)]
#Xs = [x for i, x in enumerate(xs) if i+1 == len(xs) or xs[i+1] > x+1]
Xs = sorted( list( set(Xs) ) )
return tuple(Xs)

这里我们再看下self.Ranges,其通过layer.get_range()获得,而layer在类SearchWidthCifarResNet初始化的过程中定义:

1
2
3
4
5
6
7
8
9
10
11
self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] ) # 初始化layers
self.InShape = None
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2**stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append( module.out_dim )
self.layers.append ( module ) # 增加 layer 到 layers
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride)

可以看到,layer实际上是类ConvBNReLU,看看类ConvBNReLU是怎么实现get_range()的。

1
2
3
4
5
6
7
8
9
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
self.InShape = None
self.OutShape = None
self.choices = get_choices(nOut)
self.register_buffer('choices_tensor', torch.Tensor( self.choices ))
def get_range(self):
return [self.choices]

可以看到,get_range()的底层是通过函数get_choices实现的,通过参数nOut进行控制。参数nOut是类ConvBNReLU初始化的时候传入的。在ConvBNReLU初始化的过程中,nOut 表示当前网络层的输出维度。当 nOut != None时,get_width_choices()函数返回的是大小为nOut的 tuple。