之前没有学过 pytorch,最近在看 pytorch的代码时以 tf 的思维去看,很多 module 相关的内容看的似懂非懂,所以把 module 部分拿出来学习一下。先附上相关资源连接。

module 是 pytorch nn包下的一个类(class torch.nn.Module)。torch nn下包含了一些常见的网络结构供我们使用,其中,class torch.nn.Module是所有网络的基类。Modules也可以包含其它Modules,允许使用树结构嵌入它们,可以将子Module赋值给Module的属性。

1 torch.nn

首先看下torch.nn包里都有些什么东西。

可以看到,torch.nn 已经为我们实现了常见的 cnn 网络层、rnn 网络层,另外还实现了dropout 层、损失函数等。Parameters是一种tensor,可以被看做是 module 的参数。Parameters是tensor 的一个子类,它有一个特别的属性:当作为 module 的属性时,它会自动被加入到module 的参数列表(也就是parameters()迭代器中)。但是将tensor赋给module 属性则不会产生上述效果。这样做的原因是可能要在模型中缓存一些临时状态,如 RNN 最后一个隐状态。**如果没有 Parameters,那么这些临时变量也会注册为模型变量【这个没看懂】**。

需要注意的是,在 pytorch0.3.1版本及其之前,Parameters是 torch.autograd.Variable 的一个子类;而在高版本的 pytorch 中,Variable类已经被废弃,Autograd 自动支持 tensor 并设置 reguires_grad 为 True,下面是一些变化

  • Variable(tensor)和Variable(tensor, requires_grad)仍然可以使用,但是返回的不是 Variable不是tensor。
  • Variable.data 和 tensor.data 是相同的。
  • Variable.backward(), Variable.detach(), Variable.register_hook() 等方法可以在 tensor 上以相同的方法名使用。
  • 可以创建 tensor 并设置 requires_grad为True。例如:autograd_tensor=torch.randn((2,3,4),requires_grad=True)

再看下Parameter的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class Parameter(torch.Tensor):
"""
Arguments:
data (Tensor): parameter tensor.
requires_grad (bool, optional): if the parameter requires gradient. See
:ref:`excluding-subgraphs` for more details. Default: `True`
"""

def __new__(cls, data=None, requires_grad=True):
if data is None:
data = torch.Tensor()
return torch.Tensor._make_subclass(cls, data, requires_grad)

def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
else:
result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
memo[id(self)] = result
return result

def __repr__(self):
return 'Parameter containing:\n' + super(Parameter, self).__repr__()

def __reduce_ex__(self, proto):
# See Note [Don't serialize hooks]
return (
torch._utils._rebuild_parameter,
(self.data, self.requires_grad, OrderedDict())
)

在 pytorch0.3.1之前存在 Tensor、Variable、Parameter 等名词,如果不清楚的话可以参考博客《Pytorch 中的 Tensor , Variable & Parameter》。

2 torch.nn.Module

在官方文档中,Module 被放在Containers下,我觉得可能是考虑到Module 类是最基本的类吧。Containers下主要是一些容器类,如下图所示。

上面这些容器类讲起来比较久,下次再介绍,这里大概知道有这些就好了。

下图是 torch.nn.Module 的所有方法。

如果要实现自己的module 的话,必须重写forward(*input)方法,网络的计算逻辑(将不同的网络层连接在一起)都在该函数中实现。而网络层通常在__init__函数中实现。如下:

1
2
3
4
5
6
7
8
9
10
11
12
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)# submodule: Conv2d
self.conv2 = nn.Conv2d(20, 20, 5)

def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))

3 参考文献