NNVM Frontend组件主要负责将多种深度学习框架训练出来的模型转换成如下内容:
nnvm.Graph对象:用于存储模型网络描述
tvm.nd.Array对象:用于存储模型权重参数
NNVM Frontend组件将不同深度学习框架的模型格式统一转换成nnvm.Graph和tvm.nd.array的组合。
本篇文档暂时先只关注nnvm.Graph对象和mxnet模型转换。
相关代码位于:
python/nnvm/frontend/common.py
python/nnvm/frontend/mxnet.py
mxnet模型加载与转换的接口函数为nnvm.frontend.from_mxnet
在介绍转换接口函数前先了解一下nnvm.Graph
这个数据结构,nnvm.Graph
定义于python/nnvm/graph.py
:
nnvm.Graph
用来表示一个graph对象,这个对象可以被用于应用优化pass。它包含了额外的一些计算图级别专用的属性。
class Graph(object): def json_attr(self, key) # 获取属性字符串 def _set_json_attr(self, key, value, type_name=None) # 设置属性 def json(self) # 获取计算图的json表示 def _tvm_graph_json(self) # 获取TVM计算图的json表示 def ir(self, join_entry_attrs=None, join_node_attrs=None) # 获取计算图IR的文本形式 def apply(self, passes) # 针对某个graph应用pass
Graph对象比较重要的一个函数是apply,具体是通过调用NNGraphApplyPasses
来实现。
接下来介绍一下mxnet模型的具体转换过程。
python/nnvm/frontend/mxnet.py
def _convert_symbol(op_name, inputs, attrs, identity_list=None, convert_map=None): identity_list = identity_list if identity_list else _identity_list convert_map = convert_map if convert_map else _convert_map if op_name in identity_list: op = _get_nnvm_op(op_name) sym = op(*inputs, **attrs) elif op_name in convert_map: sym = convert_map[op_name](inputs, attrs) else: _raise_not_supported('Operator: ' + op_name) return sym
针对单个运算符的转换过程主要由_convert_symbol
函数完成,其中涉及到两个运算符列表
_identity_list:表示mxnet运算符名称和nnvm一致,并且运算符附带的参数名称也必须一致。
_convert_map:表示mxnet运算符名称或者参数名称和nnvm不一致,必须转换运算符名称或参数名称。
python/nnvm/frontend/mxnet.py
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__', '__div_symbol__', '__mul_scalar__', '__mul_symbol__', '__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__', '__rsub_scalar__', '__sub_scalar__', '__sub_symbol__', 'broadcast_add', 'broadcast_div', 'broadcast_mul', 'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add', 'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp', 'flatten', 'log', 'log_softmax', 'max', 'min', 'negative', 'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose'] # _convert_map列表较长,只列出部分运算符 _convert_map = { 'Activation' : _activations, 'BatchNorm' : _batch_norm, 'BatchNorm_v1' : _batch_norm, 'Cast' : _rename('cast'), 'Concat' : _concat, 'Convolution' : _conv2d, 'Convolution_v1': _conv2d, 'Deconvolution' : _conv2d_transpose, 'Dropout' : _dropout, }
获取到运算符名称之后可以通过_get_nnvm_op
函数来获取nnvm运算符
python/nnvm/frontend/mxnet.py
from .. import symbol as _sym def _get_nnvm_op(op_name): op = getattr(_sym, op_name) if not op: raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name)) return op
_get_nnvm_op
的主要功能是通过getattr
内建函数来获取nnvm op对象,这个函数能获取到所有通过NNVM_REGISTER_OP
注册的运算符
共同学习,写下你的评论
评论加载中...
作者其他优质文章