
6 Module - Taking PyTorch Apart, Joint by Joint

Tags: Big Data

Module holds the functionality shared by every module class

Modules in PyTorch are very easy to use: you subclass Module and override two methods, __init__ and forward. So what does Module itself do?

class Module(object):
    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

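The constructor creates a set of ordered dictionaries that hold the module's various state; we set those aside for a moment. The first attribute, self._backend, refers to a global THNNFunctionBackend() instance that stores a collection of function references. THNNFunctionBackend is a subclass of FunctionBackend: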

class FunctionBackend(object):
    def __init__(self):
        self.function_classes = {}

    def register_function(self, name, function_class):
        self.function_classes[name] = function_class

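The function_classes dictionary maps names to functions (or Function classes), and entries are added with register_function. Once registration is complete it holds 118 entries. The PyTorch version used in this article is 0.4.1.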

RNN                                      <function RNN at 0x7f4330534378>
RNNTanhCell                              <function RNNTanhCell at 0x7f4330530d90>
RNNReLUCell                              <function RNNReLUCell at 0x7f43305309d8>
LSTMCell                                 <function LSTMCell at 0x7f4330530e18>
GRUCell                                  <function GRUCell at 0x7f4330530ea0>
Dropout                                  <class 'torch.nn._functions.dropout.Dropout'>
Dropout2d                                <class 'torch.nn._functions.dropout.FeatureDropout'>
Dropout3d                                <class 'torch.nn._functions.dropout.FeatureDropout'>
MarginCriterion                          <class 'torch.nn._functions.thnn.auto.MarginCriterion'>
MarginCriterionBackward                  <class 'torch.nn._functions.thnn.auto.MarginCriterionBackward'>
GatedLinear                              <class 'torch.nn._functions.thnn.auto.GatedLinear'>
GatedLinearBackward                      <class 'torch.nn._functions.thnn.auto.GatedLinearBackward'>
SpatialFullConvolutionMap                <class 'torch.nn._functions.thnn.auto.SpatialFullConvolutionMap'>
SpatialFullConvolutionMapBackward        <class 'torch.nn._functions.thnn.auto.SpatialFullConvolutionMapBackward'>
VolumetricFractionalMaxPooling           <class 'torch.nn._functions.thnn.auto.VolumetricFractionalMaxPooling'>
VolumetricFractionalMaxPoolingBackward   <class 'torch.nn._functions.thnn.auto.VolumetricFractionalMaxPoolingBackward'>
VolumetricFullDilatedConvolution         <class 'torch.nn._functions.thnn.auto.VolumetricFullDilatedConvolution'>
VolumetricFullDilatedConvolutionBackward <class 'torch.nn._functions.thnn.auto.VolumetricFullDilatedConvolutionBackward'>
Col2Im                                   <class 'torch.nn._functions.thnn.auto.Col2Im'>
Col2ImBackward                           <class 'torch.nn._functions.thnn.auto.Col2ImBackward'>
DilatedConv2d                            <class 'torch.nn._functions.thnn.auto.DilatedConv2d'>
DilatedConv2dBackward                    <class 'torch.nn._functions.thnn.auto.DilatedConv2dBackward'>
SpatialConvolutionLocal                  <class 'torch.nn._functions.thnn.auto.SpatialConvolutionLocal'>
SpatialConvolutionLocalBackward          <class 'torch.nn._functions.thnn.auto.SpatialConvolutionLocalBackward'>
FeatureLPPooling                         <class 'torch.nn._functions.thnn.auto.FeatureLPPooling'>
FeatureLPPoolingBackward                 <class 'torch.nn._functions.thnn.auto.FeatureLPPoolingBackward'>
VolumetricGridSamplerBilinear            <class 'torch.nn._functions.thnn.auto.VolumetricGridSamplerBilinear'>
VolumetricGridSamplerBilinearBackward    <class 'torch.nn._functions.thnn.auto.VolumetricGridSamplerBilinearBackward'>
TemporalUpSamplingNearest                <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingNearest'>
TemporalUpSamplingNearestBackward        <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingNearestBackward'>
SpatialUpSamplingNearest                 <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingNearest'>
SpatialUpSamplingNearestBackward         <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingNearestBackward'>
ReflectionPad1d                          <class 'torch.nn._functions.thnn.auto.ReflectionPad1d'>
ReflectionPad1dBackward                  <class 'torch.nn._functions.thnn.auto.ReflectionPad1dBackward'>
SpatialConvolutionMap                    <class 'torch.nn._functions.thnn.auto.SpatialConvolutionMap'>
SpatialConvolutionMapBackward            <class 'torch.nn._functions.thnn.auto.SpatialConvolutionMapBackward'>
NLLLoss                                  <class 'torch.nn._functions.thnn.auto.NLLLoss'>
NLLLossBackward                          <class 'torch.nn._functions.thnn.auto.NLLLossBackward'>
Softplus                                 <class 'torch.nn._functions.thnn.auto.Softplus'>
SoftplusBackward                         <class 'torch.nn._functions.thnn.auto.SoftplusBackward'>
LogSigmoid                               <class 'torch.nn._functions.thnn.auto.LogSigmoid'>
LogSigmoidBackward                       <class 'torch.nn._functions.thnn.auto.LogSigmoidBackward'>
SpatialUpSamplingBilinear                <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingBilinear'>
SpatialUpSamplingBilinearBackward        <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingBilinearBackward'>
ReplicationPad3d                         <class 'torch.nn._functions.thnn.auto.ReplicationPad3d'>
ReplicationPad3dBackward                 <class 'torch.nn._functions.thnn.auto.ReplicationPad3dBackward'>
MultiMarginLoss                          <class 'torch.nn._functions.thnn.auto.MultiMarginLoss'>
MultiMarginLossBackward                  <class 'torch.nn._functions.thnn.auto.MultiMarginLossBackward'>
ReplicationPad1d                         <class 'torch.nn._functions.thnn.auto.ReplicationPad1d'>
ReplicationPad1dBackward                 <class 'torch.nn._functions.thnn.auto.ReplicationPad1dBackward'>
MultiLabelMarginLoss                     <class 'torch.nn._functions.thnn.auto.MultiLabelMarginLoss'>
MultiLabelMarginLossBackward             <class 'torch.nn._functions.thnn.auto.MultiLabelMarginLossBackward'>
SpatialFullDilatedConvolution            <class 'torch.nn._functions.thnn.auto.SpatialFullDilatedConvolution'>
SpatialFullDilatedConvolutionBackward    <class 'torch.nn._functions.thnn.auto.SpatialFullDilatedConvolutionBackward'>
SoftMarginLoss                           <class 'torch.nn._functions.thnn.auto.SoftMarginLoss'>
SoftMarginLossBackward                   <class 'torch.nn._functions.thnn.auto.SoftMarginLossBackward'>
NLLLoss2d                                <class 'torch.nn._functions.thnn.auto.NLLLoss2d'>
NLLLoss2dBackward                        <class 'torch.nn._functions.thnn.auto.NLLLoss2dBackward'>
MSELoss                                  <class 'torch.nn._functions.thnn.auto.MSELoss'>
MSELossBackward                          <class 'torch.nn._functions.thnn.auto.MSELossBackward'>
Sigmoid                                  <class 'torch.nn._functions.thnn.auto.Sigmoid'>
SigmoidBackward                          <class 'torch.nn._functions.thnn.auto.SigmoidBackward'>
VolumetricUpSamplingTrilinear            <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingTrilinear'>
VolumetricUpSamplingTrilinearBackward    <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingTrilinearBackward'>
BCELoss                                  <class 'torch.nn._functions.thnn.auto.BCELoss'>
BCELossBackward                          <class 'torch.nn._functions.thnn.auto.BCELossBackward'>
Square                                   <class 'torch.nn._functions.thnn.auto.Square'>
SquareBackward                           <class 'torch.nn._functions.thnn.auto.SquareBackward'>
ReplicationPad2d                         <class 'torch.nn._functions.thnn.auto.ReplicationPad2d'>
ReplicationPad2dBackward                 <class 'torch.nn._functions.thnn.auto.ReplicationPad2dBackward'>
L1Loss                                   <class 'torch.nn._functions.thnn.auto.L1Loss'>
L1LossBackward                           <class 'torch.nn._functions.thnn.auto.L1LossBackward'>
SpatialGridSamplerBilinear               <class 'torch.nn._functions.thnn.auto.SpatialGridSamplerBilinear'>
SpatialGridSamplerBilinearBackward       <class 'torch.nn._functions.thnn.auto.SpatialGridSamplerBilinearBackward'>
Sqrt                                     <class 'torch.nn._functions.thnn.auto.Sqrt'>
SqrtBackward                             <class 'torch.nn._functions.thnn.auto.SqrtBackward'>
TemporalRowConvolution                   <class 'torch.nn._functions.thnn.auto.TemporalRowConvolution'>
TemporalRowConvolutionBackward           <class 'torch.nn._functions.thnn.auto.TemporalRowConvolutionBackward'>
SpatialFractionalMaxPooling              <class 'torch.nn._functions.thnn.auto.SpatialFractionalMaxPooling'>
SpatialFractionalMaxPoolingBackward      <class 'torch.nn._functions.thnn.auto.SpatialFractionalMaxPoolingBackward'>
TemporalUpSamplingLinear                 <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingLinear'>
TemporalUpSamplingLinearBackward         <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingLinearBackward'>
VolumetricDilatedMaxPooling              <class 'torch.nn._functions.thnn.auto.VolumetricDilatedMaxPooling'>
VolumetricDilatedMaxPoolingBackward      <class 'torch.nn._functions.thnn.auto.VolumetricDilatedMaxPoolingBackward'>
Threshold                                <class 'torch.nn._functions.thnn.auto.Threshold'>
ThresholdBackward                        <class 'torch.nn._functions.thnn.auto.ThresholdBackward'>
Abs                                      <class 'torch.nn._functions.thnn.auto.Abs'>
AbsBackward                              <class 'torch.nn._functions.thnn.auto.AbsBackward'>
Softshrink                               <class 'torch.nn._functions.thnn.auto.Softshrink'>
SoftshrinkBackward                       <class 'torch.nn._functions.thnn.auto.SoftshrinkBackward'>
LeakyReLU                                <class 'torch.nn._functions.thnn.auto.LeakyReLU'>
LeakyReLUBackward                        <class 'torch.nn._functions.thnn.auto.LeakyReLUBackward'>
VolumetricUpSamplingNearest              <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingNearest'>
VolumetricUpSamplingNearestBackward      <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingNearestBackward'>
VolumetricDilatedConvolution             <class 'torch.nn._functions.thnn.auto.VolumetricDilatedConvolution'>
VolumetricDilatedConvolutionBackward     <class 'torch.nn._functions.thnn.auto.VolumetricDilatedConvolutionBackward'>
Tanh                                     <class 'torch.nn._functions.thnn.auto.Tanh'>
TanhBackward                             <class 'torch.nn._functions.thnn.auto.TanhBackward'>
TemporalSubSampling                      <class 'torch.nn._functions.thnn.auto.TemporalSubSampling'>
TemporalSubSamplingBackward              <class 'torch.nn._functions.thnn.auto.TemporalSubSamplingBackward'>
ELU                                      <class 'torch.nn._functions.thnn.auto.ELU'>
ELUBackward                              <class 'torch.nn._functions.thnn.auto.ELUBackward'>
Hardtanh                                 <class 'torch.nn._functions.thnn.auto.Hardtanh'>
HardtanhBackward                         <class 'torch.nn._functions.thnn.auto.HardtanhBackward'>
L1Cost                                   <class 'torch.nn._functions.thnn.auto.L1Cost'>
L1CostBackward                           <class 'torch.nn._functions.thnn.auto.L1CostBackward'>
SpatialSubSampling                       <class 'torch.nn._functions.thnn.auto.SpatialSubSampling'>
SpatialSubSamplingBackward               <class 'torch.nn._functions.thnn.auto.SpatialSubSamplingBackward'>
Im2Col                                   <class 'torch.nn._functions.thnn.auto.Im2Col'>
Im2ColBackward                           <class 'torch.nn._functions.thnn.auto.Im2ColBackward'>
KLDivLoss                                <class 'torch.nn._functions.thnn.auto.KLDivLoss'>
KLDivLossBackward                        <class 'torch.nn._functions.thnn.auto.KLDivLossBackward'>
SmoothL1Loss                             <class 'torch.nn._functions.thnn.auto.SmoothL1Loss'>
SmoothL1LossBackward                     <class 'torch.nn._functions.thnn.auto.SmoothL1LossBackward'>
ReflectionPad2d                          <class 'torch.nn._functions.thnn.auto.ReflectionPad2d'>
ReflectionPad2dBackward                  <class 'torch.nn._functions.thnn.auto.ReflectionPad2dBackward'>
CrossMapLRN2d                            <class 'torch.nn._functions.thnn.normalization.CrossMapLRN2d'>
EmbeddingBag                             <class 'torch.nn._functions.thnn.sparse.EmbeddingBag'>

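Without quite meaning to, we have just printed every predefined function that PyTorch's THNN backend provides. How these predefined modules are implemented will be covered later.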
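The name-to-function registry behind FunctionBackend is easy to reproduce in a few lines. The following is a simplified sketch, not PyTorch code: MiniBackend and hard_tanh are hypothetical names. It assumes, as in the 0.4.x source where the RNN modules call self._backend.RNN(...), that registered entries are read back as attributes, so it adds a __getattr__ that looks names up in function_classes. The duplicate-name check is a design choice of this sketch.

```python
class MiniBackend(object):
    """Hypothetical, simplified stand-in for FunctionBackend (not PyTorch code)."""

    def __init__(self):
        # Maps a string name to a function or Function-like class
        self.function_classes = {}

    def register_function(self, name, function_class):
        # Refuse silent overwrites so two implementations cannot share a name
        if name in self.function_classes:
            raise RuntimeError("function already registered: " + name)
        self.function_classes[name] = function_class

    def __getattr__(self, name):
        # Only called when normal lookup fails, so backend.Hardtanh
        # resolves to the registered entry. Reading self.__dict__ directly
        # avoids infinite recursion before function_classes exists.
        registry = self.__dict__.get("function_classes", {})
        if name in registry:
            return registry[name]
        raise AttributeError(name)


def hard_tanh(x):
    # Toy scalar version of Hardtanh: clip x to [-1, 1]
    return max(-1.0, min(1.0, x))


# A single module-level instance, like the global backend object described above
backend = MiniBackend()
backend.register_function("Hardtanh", hard_tanh)
print(backend.Hardtanh(3.0))             # 1.0
print(sorted(backend.function_classes))  # ['Hardtanh']
```

Keeping one shared instance means every module can point its _backend at the same table, and swapping in a different backend only requires registering different entries under the same names.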

The other ordered dictionaries:

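        self._parameters = OrderedDict()         # the module's own parameters (Parameter objects)
        self._buffers = OrderedDict()            # buffers: non-parameter state such as BatchNorm running stats; saved in state_dict, not trained
        self._backward_hooks = OrderedDict()     # backward hooks
        self._forward_hooks = OrderedDict()      # forward hooks (run after forward)
        self._forward_pre_hooks = OrderedDict()  # forward pre-hooks (run before forward)
        self._modules = OrderedDict()            # child modules
        self.training = True                     # True in training mode, False in evaluation mode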
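To see how these dictionaries get filled, here is a toy imitation of the routing Module does: assigning a parameter or a submodule as an attribute stores it in _parameters or _modules, register_buffer fills _buffers, and state_dict walks parameters, buffers and then children. This is a sketch under simplifying assumptions, not PyTorch code: MiniModule, Param, Linearish and Net are hypothetical names, and the real Module uses torch.nn.Parameter and also handles hooks, deletion and type checks that are omitted here.

```python
from collections import OrderedDict


class Param(object):
    """Hypothetical stand-in for torch.nn.Parameter: just wraps a value."""

    def __init__(self, value):
        self.value = value


class MiniModule(object):
    """Toy imitation of how Module fills its ordered dicts (not PyTorch code)."""

    def __init__(self):
        # object.__setattr__ bypasses our own __setattr__ while the dicts don't exist yet
        object.__setattr__(self, "_parameters", OrderedDict())
        object.__setattr__(self, "_buffers", OrderedDict())
        object.__setattr__(self, "_modules", OrderedDict())
        object.__setattr__(self, "training", True)

    def __setattr__(self, name, value):
        # Route values into the matching dict based on their type
        if isinstance(value, Param):
            self._parameters[name] = value
        elif isinstance(value, MiniModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only reached when normal lookup fails: search the three dicts
        for store in ("_parameters", "_buffers", "_modules"):
            d = self.__dict__.get(store, {})
            if name in d:
                return d[name]
        raise AttributeError(name)

    def register_buffer(self, name, value):
        # Buffers are state that is saved but not trained
        self._buffers[name] = value

    def state_dict(self, prefix=""):
        # Flatten parameters and buffers of this module, then of each child
        out = OrderedDict()
        for name, p in self._parameters.items():
            out[prefix + name] = p.value
        for name, b in self._buffers.items():
            out[prefix + name] = b
        for name, child in self._modules.items():
            out.update(child.state_dict(prefix + name + "."))
        return out


class Linearish(MiniModule):
    def __init__(self):
        super(Linearish, self).__init__()
        self.weight = Param(2.0)           # lands in _parameters
        self.register_buffer("steps", 0)   # lands in _buffers


class Net(MiniModule):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = Linearish()              # lands in _modules
        self.name = "net"                  # ordinary attribute


net = Net()
print(list(net._modules))      # ['fc']
print(dict(net.state_dict()))  # {'fc.weight': 2.0, 'fc.steps': 0}
```

The dotted keys such as fc.weight come from prefixing each child's entries with its attribute name, which is why the ordered dicts matter: they fix both the naming and the order in which state is saved.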

Module methods

The purpose of each method is mostly clear from its name, so they are only listed here without further detail.

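Name                       Purpose
forward                    Forward computation; an abstract ("virtual") method every subclass must override
register_buffer            Register a buffer (persistent state that is not a parameter)
register_parameter         Register a parameter
add_module                 Add a child module
_apply                     Apply a function to every parameter and buffer, recursing into child modules
apply                      Call a function on every submodule and then on the module itself
cuda                       Move parameters and buffers to the GPU
cpu                        Move parameters and buffers to the CPU
type                       Cast all parameters and buffers to the given type
float                      Cast floating-point parameters and buffers to float
double                     Cast floating-point parameters and buffers to double
half                       Cast floating-point parameters and buffers to half precision (16-bit float)
to                         User-facing interface for changing dtype and moving between CPU and GPU; internally it also calls _apply
register_backward_hook     Register a backward hook
register_forward_pre_hook  Register a hook that runs before forward
register_forward_hook      Register a hook that runs after forward
_slow_forward              Forward path used while JIT tracing is active
__call__                   Makes the module callable: runs the pre-hooks, forward, then the forward hooks
__setstate__               Restores the attribute dictionaries when unpickling
__getattr__                Looks up names stored in _parameters, _buffers and _modules
__setattr__                Sets attributes, routing parameters and modules into their dictionaries
__delattr__                Deletes attributes, including entries in those dictionaries
state_dict                 Returns the current state (parameters and buffers) as a dictionary
_load_from_state_dict      Does the actual loading of one module's entries from a state dict
load_state_dict            User-facing interface for loading a state dict
children                   Iterator over the immediate child modules
modules                    Iterator over all modules in the network, including the module itself
train                      Switch to training mode
eval                       Switch to evaluation mode
zero_grad                  Zero the gradients of all parameters
share_memory               Move parameters and buffers to shared memory
__repr__                   Printable string representation of the module tree
__dir__                    Lists attributes, including parameters, buffers and submodules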
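Several rows of the table (__call__, the forward hooks, apply, children, modules, train and eval) fit together in one small picture. The following is a simplified sketch, not the PyTorch implementation: HookedModule, Scale and Pipeline are hypothetical names, forward takes a single argument instead of *input, hook registration uses plain integer keys instead of removable handles, and hook return values are simply ignored.

```python
from collections import OrderedDict


class HookedModule(object):
    """Hypothetical toy loosely modelled on torch.nn.Module (not copied from it)."""

    def __init__(self):
        self._forward_pre_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

    def forward(self, x):
        # "Virtual" method: every subclass must override it
        raise NotImplementedError

    def register_forward_pre_hook(self, hook):
        # Simplification: integer keys instead of removable handles
        self._forward_pre_hooks[len(self._forward_pre_hooks)] = hook

    def register_forward_hook(self, hook):
        self._forward_hooks[len(self._forward_hooks)] = hook

    def __call__(self, x):
        # 1) pre-hooks see the input before forward runs
        for hook in self._forward_pre_hooks.values():
            hook(self, x)
        # 2) the actual computation
        result = self.forward(x)
        # 3) forward hooks see both input and output (return values ignored here)
        for hook in self._forward_hooks.values():
            hook(self, x, result)
        return result

    def children(self):
        # Immediate submodules only
        return iter(self._modules.values())

    def modules(self):
        # The module itself first, then every descendant, depth-first
        yield self
        for child in self.children():
            for m in child.modules():
                yield m

    def apply(self, fn):
        # Children first, then the module itself
        for child in self.children():
            child.apply(fn)
        fn(self)
        return self

    def train(self, mode=True):
        # The flag is propagated to every child recursively
        self.training = mode
        for child in self.children():
            child.train(mode)
        return self

    def eval(self):
        return self.train(False)


class Scale(HookedModule):
    def __init__(self, k):
        super(Scale, self).__init__()
        self.k = k

    def forward(self, x):
        return self.k * x


class Pipeline(HookedModule):
    def __init__(self):
        super(Pipeline, self).__init__()
        self._modules["double"] = Scale(2)
        self._modules["triple"] = Scale(3)

    def forward(self, x):
        for m in self.children():
            x = m(x)
        return x


log = []
net = Pipeline()
net.register_forward_pre_hook(lambda mod, x: log.append(("pre", x)))
net.register_forward_hook(lambda mod, x, out: log.append(("post", out)))
print(net(5))                               # 30
print(log)                                  # [('pre', 5), ('post', 30)]
names = []
net.apply(lambda m: names.append(type(m).__name__))
print(names)                                # ['Scale', 'Scale', 'Pipeline']
net.eval()
print([m.training for m in net.modules()])  # [False, False, False]
```

Because train, eval and apply all recurse through children, calling them once on the top-level module is enough to reach every layer in the network.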



Author: readilen
Link: https://www.jianshu.com/p/c5b751123ae8

