What the Module class stores
Using a module in PyTorch is very easy: you derive from Module and override one or two functions, as in the minimal sketch below. So what does Module itself actually do?
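For reference, this is the usage pattern the rest of the article dissects. A minimal sketch; the class name and layer sizes are invented for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Typical user code: subclass nn.Module, create the sub-layers in __init__,
# and describe the computation in forward(); everything else is inherited.
class TwoLayerNet(nn.Module):
    def __init__(self):
        super(TwoLayerNet, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))

y = TwoLayerNet()(torch.randn(2, 10))  # calling the module runs forward() via __call__
```

With that picture in mind, here is what Module's own constructor sets up in 0.4.1: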
```python
class Module(object):
    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True
```
The constructor creates a handful of ordered dictionaries for storing various kinds of state; we will come back to those. First, self._backend: it points to a global THNNFunctionBackend() instance that stores a collection of function pointers. THNNFunctionBackend derives from FunctionBackend:
```python
class FunctionBackend(object):
    def __init__(self):
        self.function_classes = {}

    def register_function(self, name, function_class):
        self.function_classes[name] = function_class
```
The keys of the function_classes dictionary are names and the values are the corresponding functions or classes; entries are added with register_function. Once registration finishes there are about 118 of them (the PyTorch version used in this article is 0.4.1):
```
RNN <function RNN at 0x7f4330534378>
RNNTanhCell <function RNNTanhCell at 0x7f4330530d90>
RNNReLUCell <function RNNReLUCell at 0x7f43305309d8>
LSTMCell <function LSTMCell at 0x7f4330530e18>
GRUCell <function GRUCell at 0x7f4330530ea0>
Dropout <class 'torch.nn._functions.dropout.Dropout'>
Dropout2d <class 'torch.nn._functions.dropout.FeatureDropout'>
Dropout3d <class 'torch.nn._functions.dropout.FeatureDropout'>
MarginCriterion <class 'torch.nn._functions.thnn.auto.MarginCriterion'>
MarginCriterionBackward <class 'torch.nn._functions.thnn.auto.MarginCriterionBackward'>
GatedLinear <class 'torch.nn._functions.thnn.auto.GatedLinear'>
GatedLinearBackward <class 'torch.nn._functions.thnn.auto.GatedLinearBackward'>
SpatialFullConvolutionMap <class 'torch.nn._functions.thnn.auto.SpatialFullConvolutionMap'>
SpatialFullConvolutionMapBackward <class 'torch.nn._functions.thnn.auto.SpatialFullConvolutionMapBackward'>
VolumetricFractionalMaxPooling <class 'torch.nn._functions.thnn.auto.VolumetricFractionalMaxPooling'>
VolumetricFractionalMaxPoolingBackward <class 'torch.nn._functions.thnn.auto.VolumetricFractionalMaxPoolingBackward'>
VolumetricFullDilatedConvolution <class 'torch.nn._functions.thnn.auto.VolumetricFullDilatedConvolution'>
VolumetricFullDilatedConvolutionBackward <class 'torch.nn._functions.thnn.auto.VolumetricFullDilatedConvolutionBackward'>
Col2Im <class 'torch.nn._functions.thnn.auto.Col2Im'>
Col2ImBackward <class 'torch.nn._functions.thnn.auto.Col2ImBackward'>
DilatedConv2d <class 'torch.nn._functions.thnn.auto.DilatedConv2d'>
DilatedConv2dBackward <class 'torch.nn._functions.thnn.auto.DilatedConv2dBackward'>
SpatialConvolutionLocal <class 'torch.nn._functions.thnn.auto.SpatialConvolutionLocal'>
SpatialConvolutionLocalBackward <class 'torch.nn._functions.thnn.auto.SpatialConvolutionLocalBackward'>
FeatureLPPooling <class 'torch.nn._functions.thnn.auto.FeatureLPPooling'>
FeatureLPPoolingBackward <class 'torch.nn._functions.thnn.auto.FeatureLPPoolingBackward'>
VolumetricGridSamplerBilinear <class 'torch.nn._functions.thnn.auto.VolumetricGridSamplerBilinear'>
VolumetricGridSamplerBilinearBackward <class 'torch.nn._functions.thnn.auto.VolumetricGridSamplerBilinearBackward'>
TemporalUpSamplingNearest <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingNearest'>
TemporalUpSamplingNearestBackward <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingNearestBackward'>
SpatialUpSamplingNearest <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingNearest'>
SpatialUpSamplingNearestBackward <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingNearestBackward'>
ReflectionPad1d <class 'torch.nn._functions.thnn.auto.ReflectionPad1d'>
ReflectionPad1dBackward <class 'torch.nn._functions.thnn.auto.ReflectionPad1dBackward'>
SpatialConvolutionMap <class 'torch.nn._functions.thnn.auto.SpatialConvolutionMap'>
SpatialConvolutionMapBackward <class 'torch.nn._functions.thnn.auto.SpatialConvolutionMapBackward'>
NLLLoss <class 'torch.nn._functions.thnn.auto.NLLLoss'>
NLLLossBackward <class 'torch.nn._functions.thnn.auto.NLLLossBackward'>
Softplus <class 'torch.nn._functions.thnn.auto.Softplus'>
SoftplusBackward <class 'torch.nn._functions.thnn.auto.SoftplusBackward'>
LogSigmoid <class 'torch.nn._functions.thnn.auto.LogSigmoid'>
LogSigmoidBackward <class 'torch.nn._functions.thnn.auto.LogSigmoidBackward'>
SpatialUpSamplingBilinear <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingBilinear'>
SpatialUpSamplingBilinearBackward <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingBilinearBackward'>
ReplicationPad3d <class 'torch.nn._functions.thnn.auto.ReplicationPad3d'>
ReplicationPad3dBackward <class 'torch.nn._functions.thnn.auto.ReplicationPad3dBackward'>
MultiMarginLoss <class 'torch.nn._functions.thnn.auto.MultiMarginLoss'>
MultiMarginLossBackward <class 'torch.nn._functions.thnn.auto.MultiMarginLossBackward'>
ReplicationPad1d <class 'torch.nn._functions.thnn.auto.ReplicationPad1d'>
ReplicationPad1dBackward <class 'torch.nn._functions.thnn.auto.ReplicationPad1dBackward'>
MultiLabelMarginLoss <class 'torch.nn._functions.thnn.auto.MultiLabelMarginLoss'>
MultiLabelMarginLossBackward <class 'torch.nn._functions.thnn.auto.MultiLabelMarginLossBackward'>
SpatialFullDilatedConvolution <class 'torch.nn._functions.thnn.auto.SpatialFullDilatedConvolution'>
SpatialFullDilatedConvolutionBackward <class 'torch.nn._functions.thnn.auto.SpatialFullDilatedConvolutionBackward'>
SoftMarginLoss <class 'torch.nn._functions.thnn.auto.SoftMarginLoss'>
SoftMarginLossBackward <class 'torch.nn._functions.thnn.auto.SoftMarginLossBackward'>
NLLLoss2d <class 'torch.nn._functions.thnn.auto.NLLLoss2d'>
NLLLoss2dBackward <class 'torch.nn._functions.thnn.auto.NLLLoss2dBackward'>
MSELoss <class 'torch.nn._functions.thnn.auto.MSELoss'>
MSELossBackward <class 'torch.nn._functions.thnn.auto.MSELossBackward'>
Sigmoid <class 'torch.nn._functions.thnn.auto.Sigmoid'>
SigmoidBackward <class 'torch.nn._functions.thnn.auto.SigmoidBackward'>
VolumetricUpSamplingTrilinear <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingTrilinear'>
VolumetricUpSamplingTrilinearBackward <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingTrilinearBackward'>
BCELoss <class 'torch.nn._functions.thnn.auto.BCELoss'>
BCELossBackward <class 'torch.nn._functions.thnn.auto.BCELossBackward'>
Square <class 'torch.nn._functions.thnn.auto.Square'>
SquareBackward <class 'torch.nn._functions.thnn.auto.SquareBackward'>
ReplicationPad2d <class 'torch.nn._functions.thnn.auto.ReplicationPad2d'>
ReplicationPad2dBackward <class 'torch.nn._functions.thnn.auto.ReplicationPad2dBackward'>
L1Loss <class 'torch.nn._functions.thnn.auto.L1Loss'>
L1LossBackward <class 'torch.nn._functions.thnn.auto.L1LossBackward'>
SpatialGridSamplerBilinear <class 'torch.nn._functions.thnn.auto.SpatialGridSamplerBilinear'>
SpatialGridSamplerBilinearBackward <class 'torch.nn._functions.thnn.auto.SpatialGridSamplerBilinearBackward'>
Sqrt <class 'torch.nn._functions.thnn.auto.Sqrt'>
SqrtBackward <class 'torch.nn._functions.thnn.auto.SqrtBackward'>
TemporalRowConvolution <class 'torch.nn._functions.thnn.auto.TemporalRowConvolution'>
TemporalRowConvolutionBackward <class 'torch.nn._functions.thnn.auto.TemporalRowConvolutionBackward'>
SpatialFractionalMaxPooling <class 'torch.nn._functions.thnn.auto.SpatialFractionalMaxPooling'>
SpatialFractionalMaxPoolingBackward <class 'torch.nn._functions.thnn.auto.SpatialFractionalMaxPoolingBackward'>
TemporalUpSamplingLinear <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingLinear'>
TemporalUpSamplingLinearBackward <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingLinearBackward'>
VolumetricDilatedMaxPooling <class 'torch.nn._functions.thnn.auto.VolumetricDilatedMaxPooling'>
VolumetricDilatedMaxPoolingBackward <class 'torch.nn._functions.thnn.auto.VolumetricDilatedMaxPoolingBackward'>
Threshold <class 'torch.nn._functions.thnn.auto.Threshold'>
ThresholdBackward <class 'torch.nn._functions.thnn.auto.ThresholdBackward'>
Abs <class 'torch.nn._functions.thnn.auto.Abs'>
AbsBackward <class 'torch.nn._functions.thnn.auto.AbsBackward'>
Softshrink <class 'torch.nn._functions.thnn.auto.Softshrink'>
SoftshrinkBackward <class 'torch.nn._functions.thnn.auto.SoftshrinkBackward'>
LeakyReLU <class 'torch.nn._functions.thnn.auto.LeakyReLU'>
LeakyReLUBackward <class 'torch.nn._functions.thnn.auto.LeakyReLUBackward'>
VolumetricUpSamplingNearest <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingNearest'>
VolumetricUpSamplingNearestBackward <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingNearestBackward'>
VolumetricDilatedConvolution <class 'torch.nn._functions.thnn.auto.VolumetricDilatedConvolution'>
VolumetricDilatedConvolutionBackward <class 'torch.nn._functions.thnn.auto.VolumetricDilatedConvolutionBackward'>
Tanh <class 'torch.nn._functions.thnn.auto.Tanh'>
TanhBackward <class 'torch.nn._functions.thnn.auto.TanhBackward'>
TemporalSubSampling <class 'torch.nn._functions.thnn.auto.TemporalSubSampling'>
TemporalSubSamplingBackward <class 'torch.nn._functions.thnn.auto.TemporalSubSamplingBackward'>
ELU <class 'torch.nn._functions.thnn.auto.ELU'>
ELUBackward <class 'torch.nn._functions.thnn.auto.ELUBackward'>
Hardtanh <class 'torch.nn._functions.thnn.auto.Hardtanh'>
HardtanhBackward <class 'torch.nn._functions.thnn.auto.HardtanhBackward'>
L1Cost <class 'torch.nn._functions.thnn.auto.L1Cost'>
L1CostBackward <class 'torch.nn._functions.thnn.auto.L1CostBackward'>
SpatialSubSampling <class 'torch.nn._functions.thnn.auto.SpatialSubSampling'>
SpatialSubSamplingBackward <class 'torch.nn._functions.thnn.auto.SpatialSubSamplingBackward'>
Im2Col <class 'torch.nn._functions.thnn.auto.Im2Col'>
Im2ColBackward <class 'torch.nn._functions.thnn.auto.Im2ColBackward'>
KLDivLoss <class 'torch.nn._functions.thnn.auto.KLDivLoss'>
KLDivLossBackward <class 'torch.nn._functions.thnn.auto.KLDivLossBackward'>
SmoothL1Loss <class 'torch.nn._functions.thnn.auto.SmoothL1Loss'>
SmoothL1LossBackward <class 'torch.nn._functions.thnn.auto.SmoothL1LossBackward'>
ReflectionPad2d <class 'torch.nn._functions.thnn.auto.ReflectionPad2d'>
ReflectionPad2dBackward <class 'torch.nn._functions.thnn.auto.ReflectionPad2dBackward'>
CrossMapLRN2d <class 'torch.nn._functions.thnn.normalization.CrossMapLRN2d'>
EmbeddingBag <class 'torch.nn._functions.thnn.sparse.EmbeddingBag'>
```
Without quite meaning to, that listing shows every predefined module PyTorch supports. Later on we will start walking through how these predefined modules are implemented.
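The listing above can be reproduced by iterating over the backend's registry. A minimal sketch, assuming the global instance is importable as torch.nn.backends.thnn.backend in 0.4.1:

```python
# Assumption: in PyTorch 0.4.1 the global THNNFunctionBackend instance is
# exposed as `backend` inside torch.nn.backends.thnn.
from torch.nn.backends.thnn import backend as thnn_backend

# Print every registered name with the function/class it maps to,
# then count the entries (roughly 118 in 0.4.1).
for name, fn in thnn_backend.function_classes.items():
    print(name, fn)
print(len(thnn_backend.function_classes))
```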
The other ordered dictionaries

```python
self._parameters = OrderedDict()         # the module's learnable parameters
self._buffers = OrderedDict()            # persistent non-parameter tensors (kept with the module, saved with its state)
self._backward_hooks = OrderedDict()     # dictionary of backward hooks
self._forward_hooks = OrderedDict()      # dictionary of forward hooks
self._forward_pre_hooks = OrderedDict()  # dictionary of hooks run before forward is called
self._modules = OrderedDict()            # child modules
self.training = True                     # training vs. evaluation mode
```
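These dictionaries are filled in automatically when you assign attributes on a module: Module.__setattr__ routes Parameter objects into _parameters and sub-modules into _modules, while register_buffer puts tensors into _buffers. A small sketch; the class and attribute names are invented for illustration:

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.fc = nn.Linear(4, 2)                             # ends up in self._modules['fc']
        self.scale = nn.Parameter(torch.ones(1))              # ends up in self._parameters['scale']
        self.register_buffer('running_sum', torch.zeros(2))   # ends up in self._buffers['running_sum']

net = TinyNet()
print(list(net._modules))     # ['fc']
print(list(net._parameters))  # ['scale']
print(list(net._buffers))     # ['running_sum']
```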
Module methods
The purpose of each Module method can largely be guessed from its name; they are only listed here, not discussed in detail. A short usage sketch of a few of them follows the table.
Name | Purpose |
---|---|
forward | virtual method for the forward computation |
register_buffer | register a persistent buffer |
register_parameter | register a parameter |
add_module | add a child module |
_apply | apply an operation to all parameters and buffers |
apply | apply an operation to all submodules |
cuda | move the module to the GPU |
cpu | move the module to the CPU |
type | cast all parameters to a given type |
float | cast everything to single-precision float |
double | cast everything to double-precision float |
half | cast everything to half precision (two bytes) |
to | user-facing interface for changing dtype and CPU/GPU device; internally it still calls `_apply` |
register_backward_hook | register a backward hook |
register_forward_pre_hook | register a hook that runs before forward |
register_forward_hook | register a forward hook |
_slow_forward | forward path without the fast dispatch (used when tracing) |
`__call__` | run the forward pass when the module is called with inputs |
`__setstate__` | restore the state of all the internal dictionaries at once |
`__getattr__` | get an attribute |
`__setattr__` | set an attribute |
`__delattr__` | delete an attribute |
state_dict | export the current state as a dictionary |
_load_from_state_dict | worker function that loads from a state dictionary |
load_state_dict | user-facing interface for loading state |
children | iterate over direct child modules |
modules | iterate over all modules |
train | switch to training mode |
eval | switch to evaluation mode |
zero_grad | zero the gradients of all parameters |
share_memory | move parameters into shared memory |
`__repr__` | printable representation of the module |
`__dir__` | list the module's attributes |
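A few of these methods in action. A short sketch; the network shape, the hook, and the file name are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

# apply() visits every submodule; a common use is custom weight initialization.
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        m.bias.data.zero_()
net.apply(init_weights)

# register_forward_hook: called after the hooked module's forward() has run.
def print_shape(module, inputs, output):
    print(type(module).__name__, 'output shape:', tuple(output.shape))
handle = net[0].register_forward_hook(print_shape)

out = net(torch.randn(3, 8))   # __call__ runs the hooks around forward()
handle.remove()

# state_dict()/load_state_dict(): save and restore parameters and buffers.
torch.save(net.state_dict(), 'net.pth')
net.load_state_dict(torch.load('net.pth'))

# train()/eval() flip the `training` flag on the module and all of its children.
net.eval()
print(net.training, net[0].training)  # False False
```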
Author: readilen
Link: https://www.jianshu.com/p/c5b751123ae8