pytorch Module介绍
之前没有学过 pytorch,最近在看 pytorch的代码时以 tf 的思维去看,很多 module 相关的内容看的似懂非懂,所以把 module 部分拿出来学习一下。先附上相关资源连接。
module 是 pytorch nn包下的一个类(class torch.nn.Module)。torch nn下包含了一些常见的网络结构供我们使用,其中,class torch.nn.Module是所有网络的基类。Modules
也可以包含其它Modules
,允许使用树结构嵌入它们,可以将子Module赋值给Module的属性。
1 torch.nn
首先看下torch.nn包里都有些什么东西。
可以看到,torch.nn 已经为我们实现了常见的 cnn 网络层、rnn 网络层,另外还实现了dropout 层、损失函数等。Parameters是一种tensor,可以被看做是 module 的参数。Parameters是tensor 的一个子类,它有一个特别的属性:当作为 module 的属性时,它会自动被加入到module 的参数列表(也就是parameters()迭代器中)。但是将tensor赋给module 属性则不会产生上述效果。这样做的原因是可能要在模型中缓存一些临时状态,如 RNN 最后一个隐状态。**如果没有 Parameters,那么这些临时变量也会注册为模型变量【这个没看懂】**。
需要注意的是,在 pytorch0.3.1版本及其之前,Parameters是 torch.autograd.Variable 的一个子类;而在高版本的 pytorch 中,Variable类已经被废弃,Autograd 自动支持 tensor 并设置 reguires_grad 为 True,下面是一些变化。
- Variable(tensor)和Variable(tensor, requires_grad)仍然可以使用,但是返回的不是 Variable不是tensor。
- Variable.data 和 tensor.data 是相同的。
- Variable.backward(), Variable.detach(), Variable.register_hook() 等方法可以在 tensor 上以相同的方法名使用。
- 可以创建 tensor 并设置 requires_grad为True。例如:
autograd_tensor=torch.randn((2,3,4),requires_grad=True)
。
再看下Parameter的实现:
1 | class Parameter(torch.Tensor): |
在 pytorch0.3.1之前存在 Tensor、Variable、Parameter 等名词,如果不清楚的话可以参考博客《Pytorch 中的 Tensor , Variable & Parameter》。
2 torch.nn.Module
在官方文档中,Module 被放在Containers下,我觉得可能是考虑到Module 类是最基本的类吧。Containers下主要是一些容器类,如下图所示。
上面这些容器类讲起来比较久,下次再介绍,这里大概知道有这些就好了。
下图是 torch.nn.Module 的所有方法。
如果要实现自己的module 的话,必须重写forward(*input)方法,网络的计算逻辑(将不同的网络层连接在一起)都在该函数中实现。而网络层通常在__init__
函数中实现。如下:
1 | import torch.nn as nn |