Skip to content

Latest commit

 

History

History
91 lines (71 loc) · 4.74 KB

support_new_model.md

File metadata and controls

91 lines (71 loc) · 4.74 KB

如何支持新的模型

我们提供了多种工具来支持模型转换

函数的重写器

PyTorch 神经网络是用 python 编写的,可以简化算法的开发。但与此同时 Python 的流程控制和第三方库会使得网络导出为中间语言的过程变得困难。为此我们提供了一个“MonKey path”工具将不支持的功能重写为另一个可支持中间语言导出的功能。下述是一个具体的使用例子:

from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
    func_name='torch.Tensor.repeat', backend='tensorrt')
def repeat_static(input, *size):
    ctx = FUNCTION_REWRITER.get_context()
    origin_func = ctx.origin_func
    if input.dim() == 1 and len(size) == 1:
        return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
    else:
        return origin_func(input, *size)

使用函数重写器是十分容易的,只需添加一个带参数的装饰器即可:

  • func_name是需要被重载的函数,它可以是其他PyTorch 的函数或者是自定义的函数。模块中的方法也可以通过工具进行重载。
  • backend是推理引擎。当模型被导入到引擎的时候,函数会被重载。如果没有给出,重载默认的参数就是重载的参数。如果后端的重载的参数不存在,将会按照预设的默认模式进行重载。 当参数与原始的参数相同时,除了把上下文信息ctx 作为第一的参数外,上下文也提供了一些有用的信息,例如:部署的配置ctx.cfg 和原始的函数(已经被重载)ctx.origin_func

可参照这些样例代码

模型重载器

如果您想用另一个模块替换整个模块,我们还有另一个重载器,如下所示:

@MODULE_REWRITER.register_rewrite_module(
    'mmagic.models.backbones.sr_backbones.SRCNN', backend='tensorrt')
class SRCNNWrapper(nn.Module):
    def __init__(self,
                 module,
                 cfg,
                 channels=(3, 64, 32, 3),
                 kernel_sizes=(9, 1, 5),
                 upscale_factor=4):
        super(SRCNNWrapper, self).__init__()
        self._module = module
        module.img_upsampler = nn.Upsample(
            scale_factor=module.upscale_factor,
            mode='bilinear',
            align_corners=False)
    def forward(self, *args, **kwargs):
        """Run forward."""
        return self._module(*args, **kwargs)
    def init_weights(self, *args, **kwargs):
        """Initialize weights."""
        return self._module.init_weights(*args, **kwargs)

就像函数重载器一样,可添加一个带参数的装饰器:

  • module_type 要重载的模块类。
  • backend 是推理引擎。当模型被导入到引擎的时候,函数会被重载。如果没有给出,重载默认的参数就是重载的参数。如果后端的重载的参数不存在,将会按照预设的默认模式进行重载。

网络中模块的所有实例都将替换为这个新类的实例。原始模块和部署配置将作为前两个参数进行传递。

符号函数重写

PyTorch 和 ONNX 之间的映射是通过 PyTorch 中的符号函数进行定义的。自定义符号函数可以帮助我们绕过一些推理引擎不支持的 ONNX 节点。

@SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True)
def squeeze_default(g, self, dim=None):
    if dim is None:
        dims = []
        for i, size in enumerate(self.type().sizes()):
            if size == 1:
                dims.append(i)
    else:
        dims = [sym_help._get_const(dim, 'i', 'dim')]
    return g.op('Squeeze', self, axes_i=dims)

装饰器的参数

  • func_name要添加符号的函数名称。如果是自定义的,请使用完整路径torch.autograd.Function。或者如果它是 PyTorch 内置函数,则只用写一个名称即可。
  • backend是推理引擎。当模型被导入到引擎的时候,函数会被重载。如果没有给出,重载默认的参数就是重载的参数。如果后端的重载的参数不存在,将会按照预设的默认模式进行重载。
  • 如果函数是 PyTorch 内置函数,则为True。
  • arg_descriptors 符号函数参数的描述符,将被传递给torch.onnx.symbolic_helper._parse_arg

就像函数重载器的ctx一样,第一个参数会提供上下文信息。上下文中了一些有用的信息,例如部署配置ctx.cfg和原始功能(已被重载)ctx.origin_func。请注意, ctx.origin_func只能在is_pytorch==False时使用。

这里有很多实现可参考。