我们提供了多种工具来支持模型转换
PyTorch 神经网络是用 python 编写的,可以简化算法的开发。但与此同时 Python 的流程控制和第三方库会使得网络导出为中间语言的过程变得困难。为此我们提供了一个“MonKey path”工具将不支持的功能重写为另一个可支持中间语言导出的功能。下述是一个具体的使用例子:
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.repeat', backend='tensorrt')
def repeat_static(input, *size):
ctx = FUNCTION_REWRITER.get_context()
origin_func = ctx.origin_func
if input.dim() == 1 and len(size) == 1:
return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
else:
return origin_func(input, *size)
使用函数重写器是十分容易的,只需添加一个带参数的装饰器即可:
func_name
是需要被重载的函数,它可以是其他PyTorch 的函数或者是自定义的函数。模块中的方法也可以通过工具进行重载。backend
是推理引擎。当模型被导入到引擎的时候,函数会被重载。如果没有给出,重载默认的参数就是重载的参数。如果后端的重载的参数不存在,将会按照预设的默认模式进行重载。 当参数与原始的参数相同时,除了把上下文信息ctx
作为第一的参数外,上下文也提供了一些有用的信息,例如:部署的配置ctx.cfg
和原始的函数(已经被重载)ctx.origin_func
。
可参照这些样例代码。
如果您想用另一个模块替换整个模块,我们还有另一个重载器,如下所示:
@MODULE_REWRITER.register_rewrite_module(
'mmagic.models.backbones.sr_backbones.SRCNN', backend='tensorrt')
class SRCNNWrapper(nn.Module):
def __init__(self,
module,
cfg,
channels=(3, 64, 32, 3),
kernel_sizes=(9, 1, 5),
upscale_factor=4):
super(SRCNNWrapper, self).__init__()
self._module = module
module.img_upsampler = nn.Upsample(
scale_factor=module.upscale_factor,
mode='bilinear',
align_corners=False)
def forward(self, *args, **kwargs):
"""Run forward."""
return self._module(*args, **kwargs)
def init_weights(self, *args, **kwargs):
"""Initialize weights."""
return self._module.init_weights(*args, **kwargs)
就像函数重载器一样,可添加一个带参数的装饰器:
module_type
要重载的模块类。backend
是推理引擎。当模型被导入到引擎的时候,函数会被重载。如果没有给出,重载默认的参数就是重载的参数。如果后端的重载的参数不存在,将会按照预设的默认模式进行重载。
网络中模块的所有实例都将替换为这个新类的实例。原始模块和部署配置将作为前两个参数进行传递。
PyTorch 和 ONNX 之间的映射是通过 PyTorch 中的符号函数进行定义的。自定义符号函数可以帮助我们绕过一些推理引擎不支持的 ONNX 节点。
@SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True)
def squeeze_default(g, self, dim=None):
if dim is None:
dims = []
for i, size in enumerate(self.type().sizes()):
if size == 1:
dims.append(i)
else:
dims = [sym_help._get_const(dim, 'i', 'dim')]
return g.op('Squeeze', self, axes_i=dims)
装饰器的参数
func_name
要添加符号的函数名称。如果是自定义的,请使用完整路径torch.autograd.Function
。或者如果它是 PyTorch 内置函数,则只用写一个名称即可。backend
是推理引擎。当模型被导入到引擎的时候,函数会被重载。如果没有给出,重载默认的参数就是重载的参数。如果后端的重载的参数不存在,将会按照预设的默认模式进行重载。- 如果函数是 PyTorch 内置函数,则为True。
arg_descriptors
符号函数参数的描述符,将被传递给torch.onnx.symbolic_helper._parse_arg
。
就像函数重载器的ctx
一样,第一个参数会提供上下文信息。上下文中了一些有用的信息,例如部署配置ctx.cfg和原始功能(已被重载)ctx.origin_func
。请注意, ctx.origin_func
只能在is_pytorch==False
时使用。
这里有很多实现可参考。