From adb792ec00629d69aae309d896f6611c464dca64 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sat, 30 Sep 2023 00:32:54 +0800 Subject: [PATCH 01/20] impl basic cache (only tensor inputs) --- .gitignore | 6 ++++ examples/simple_compiler.py | 21 ++++++++------ src/paddlefx/cache_manager.py | 16 +++++++++++ src/paddlefx/compiler.py | 11 ++++++-- src/paddlefx/convert_frame.py | 29 ++++++++++++-------- src/paddlefx/output_graph.py | 32 +++++++++++++++++++++- src/paddlefx/pyeval.py | 50 +++++++++++++++++++++++++++++++++- src/paddlefx/source.py | 7 +++++ src/paddlefx/variables/base.py | 32 ++++++++++++++++++++++ 9 files changed, 179 insertions(+), 25 deletions(-) create mode 100644 src/paddlefx/cache_manager.py diff --git a/.gitignore b/.gitignore index ca9a3b9..57995ad 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,9 @@ src/paddlefx/_version.py .cache/ *.so tmp/ + +# viztracer +result.json + +# mlir +*.mlir diff --git a/examples/simple_compiler.py b/examples/simple_compiler.py index fd22c76..67f2fd0 100644 --- a/examples/simple_compiler.py +++ b/examples/simple_compiler.py @@ -1,7 +1,5 @@ from __future__ import annotations -import logging - import numpy as np import paddle import paddle.nn @@ -13,7 +11,7 @@ paddle.seed(0) -logging.getLogger().setLevel(logging.DEBUG) +# logging.getLogger().setLevel(logging.DEBUG) def inner_func(x, y): @@ -25,14 +23,19 @@ def inner_func(x, y): def func(a, b): d = inner_func(a, b) + d = inner_func(a, d) + d = inner_func(d, a) + d = inner_func(a, d) return d -optimized_net = paddlefx.optimize(func, backend=TVMCompiler(print_tabular=True)) +optimized_func = paddlefx.optimize(func, backend=TVMCompiler(print_tabular=True)) -x = paddle.rand([1, 224]) -y = paddle.rand([1, 224]) -out = func(x, y) -res = optimized_net(x, y) +x = paddle.rand([4, 6, 1]) +y = paddle.rand([4, 6, 224]) +for _ in range(10): + out = func(y, x) + res = optimized_func(x, y) + res = optimized_func(y, x) -np.testing.assert_equal(res.numpy(), out.numpy()) + np.testing.assert_equal(res.numpy(), out.numpy()) diff --git a/src/paddlefx/cache_manager.py b/src/paddlefx/cache_manager.py new file mode 100644 index 0000000..a0f36fa --- /dev/null +++ b/src/paddlefx/cache_manager.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +import dataclasses +import types + +from typing import TYPE_CHECKING, Callable + +if TYPE_CHECKING: + GuardFunction = Callable[[types.FrameType], bool] + GuardedCodes = list["GuardedCode"] + + +@dataclasses.dataclass +class GuardedCode: + code: types.CodeType + guard_fn: GuardFunction diff --git a/src/paddlefx/compiler.py b/src/paddlefx/compiler.py index a3a4c33..2107cbb 100644 --- a/src/paddlefx/compiler.py +++ b/src/paddlefx/compiler.py @@ -45,18 +45,22 @@ def compile(self, gl: paddlefx.GraphLayer, example_inputs: list) -> Callable: for node in gl.graph.nodes: getattr(self, f"compile_{node.op}")(node, symbol_table, example_inputs) self.input_index = 0 - return self.gen_compiled_func(symbol_table, dummy_outputs) + return self.gen_compiled_func(symbol_table, example_inputs, dummy_outputs) except (AttributeError, NotImplementedError) as e: print(f"AttributeError when compiling graph: {e}") self.input_index = 0 return gl.forward - def gen_compiled_func(self, symbol_table: dict[str, Any], dummy_outputs: Any): + def gen_compiled_func( + self, symbol_table: dict[str, Any], dummy_inputs: list, dummy_outputs: Any + ): raise NotImplementedError("CompilerBase is a abstract class") class TVMCompiler(CompilerBase): - def gen_compiled_func(self, symbol_table: 
dict[str, te.Tensor], dummy_outputs: Any): + def gen_compiled_func( + self, symbol_table: dict[str, te.Tensor], dummy_inputs: list, dummy_outputs: Any + ): import tvm from tvm import te @@ -84,6 +88,7 @@ def compiled_func(*args): output = paddle.to_tensor(output.asnumpy()) return (output,) + compiled_func(*dummy_inputs) return compiled_func def compile_placeholder( diff --git a/src/paddlefx/convert_frame.py b/src/paddlefx/convert_frame.py index 150015a..2d4a642 100644 --- a/src/paddlefx/convert_frame.py +++ b/src/paddlefx/convert_frame.py @@ -1,20 +1,18 @@ from __future__ import annotations -import dataclasses import logging import types -from typing import Callable +from typing import TYPE_CHECKING, Callable from .bytecode_transformation import Instruction, transform_code_object +from .cache_manager import GuardedCode from .paddle_utils import Tensor, skip_paddle_filename, skip_paddle_frame -from .pyeval import PyEval +from .pyeval import CodeCacheManager, PyEval from .utils import log_bytecode, log_code - -@dataclasses.dataclass -class GuardedCode: - code: types.CodeType +if TYPE_CHECKING: + pass def skip_frame(frame: types.FrameType) -> bool: @@ -36,22 +34,31 @@ def convert_frame(frame: types.FrameType, compiler_fn: Callable) -> GuardedCode logging.debug(f"skip_frame: {frame}") return None + # TODO: guard_fn is not declared in this scope + guard_fn = None + def transform(instructions: list[Instruction], code_options: dict): tracer = PyEval(instructions, frame, code_options, compiler_fn) tracer.run() + nonlocal guard_fn + guard_fn = tracer.output.guard_fn code_options.update(tracer.output.code_options) instructions[:] = tracer.output.instructions logging.info(f"convert_frame: {frame}") code = frame.f_code - log_code(code, "RAW_BYTECODE") + log_code(code, "ORIGINAL_BYTECODE") + + if (cached_code := CodeCacheManager.get_cache(frame)) is not None: + logging.info(f"cached_code: {cached_code}") + return cached_code # TODO: rm torch code dependency out_code = transform_code_object(code, transform) log_bytecode( "NEW_BYTECODE", code.co_name, code.co_filename, code.co_firstlineno, out_code ) - - g = GuardedCode(out_code) - return g + new_code = GuardedCode(out_code, guard_fn) + CodeCacheManager.add_cache(code, new_code) + return new_code diff --git a/src/paddlefx/output_graph.py b/src/paddlefx/output_graph.py index 9d7a8f0..3f96a83 100644 --- a/src/paddlefx/output_graph.py +++ b/src/paddlefx/output_graph.py @@ -13,11 +13,13 @@ from .graph import Graph from .graph_layer import GraphLayer from .node import Node -from .source import LocalSource +from .source import GlobalSource, LocalSource from .utils import format_instruction, log_code, log_instructions +from .variables.base import TensorVariable, find_traceable_vars from .variables.builder import GraphArg if TYPE_CHECKING: + from .cache_manager import GuardFunction from .pyeval import PyEval, PyEvalBase from .variables.base import VariableBase @@ -35,6 +37,7 @@ def __init__( root_tx: PyEval, ): self.instructions: list[Instruction] = [] + self.input_variables: list[VariableBase] = [] self.code_options = code_options self.compiler_fn = compiler_fn self.root_tx = root_tx @@ -57,6 +60,33 @@ def placeholders(self) -> list[Node]: def graphargs(self) -> list[GraphArg]: return [node.meta["grapharg"] for node in self.placeholders] + @property + def guard_fn(self) -> GuardFunction: + str_guards: list[str] = [] + + for variable in find_traceable_vars(self.input_variables): + # TODO: add global_guarded_variables + # TODO: define make_guard in VariableBase + 
if isinstance(variable, TensorVariable): + assert variable.source is not None + if isinstance(variable.source, LocalSource): + var_name = f"frame.f_locals['{variable.source.local_name}']" + elif isinstance(variable.source, GlobalSource): + var_name = f"frame.f_globals['{variable.source.global_name}']" + else: + raise ValueError(f"Unsupported source: {variable.source}") + + str_guards.extend( + [ + f"str({var_name}.shape) == '{variable.var.shape}'", + f"str({var_name}.dtype) == '{variable.var.dtype}'", + ] + ) + if len(str_guards) == 0: + return lambda frame: True + guard_string = f"lambda frame: {' and '.join(str_guards)}" + return eval(guard_string) + def add_output_instructions(self, insts: list[Instruction]) -> None: self.instructions.extend(insts) self.should_exit = True diff --git a/src/paddlefx/pyeval.py b/src/paddlefx/pyeval.py index 68320e0..3b624e1 100644 --- a/src/paddlefx/pyeval.py +++ b/src/paddlefx/pyeval.py @@ -21,6 +21,7 @@ transform_code_object, unique_id, ) +from .cache_manager import GuardedCode from .codegen import PyCodegen from .output_graph import OutputGraph from .paddle_utils import TensorType @@ -33,7 +34,51 @@ if TYPE_CHECKING: # import opcode - pass + from .cache_manager import GuardedCodes + + +class CodeCacheManager: + cache_dict: dict[types.CodeType, GuardedCodes] = {} + + @classmethod + def add_cache(cls, code: types.CodeType, guarded_code: GuardedCode): + cls.cache_dict.setdefault(code, []) + cls.cache_dict[code].append(guarded_code) + + @classmethod + def get_cache(cls, frame: types.FrameType) -> GuardedCode | None: + code: types.CodeType = frame.f_code + if code not in cls.cache_dict: + print(f"Firstly call {code}\n") + return None + return cls.lookup(frame, cls.cache_dict[code]) + + @classmethod + def clear_cache(cls): + cls.cache_dict.clear() + + @classmethod + def lookup( + cls, frame: types.FrameType, guarded_codes: GuardedCodes + ) -> GuardedCode | None: + for guarded_code in guarded_codes: + try: + guard_fn = guarded_code.guard_fn + if guard_fn(frame): + print( + f"[Cache]: Cache hit, GuardFunction is {guard_fn}\n", + ) + return guarded_code + else: + print( + f"[Cache]: Cache miss, GuardFunction is {guard_fn}\n", + ) + except Exception as e: + print(f"[Cache]: GuardFunction function error: {e}\n") + continue + + print("[Cache]: all guards missed\n") + return None def tos_op_wrapper(fn: Callable): @@ -687,6 +732,9 @@ def __init__( ): node = self.output.graph.placeholder(var.source.local_name) node.meta["grapharg"] = GraphArg(example=var.var) + # TODO: self.output.global_variables + # TODO: add those in PyEvalInline + self.output.input_variables.append(var) def create_call_resume_at(self, inst: Instruction | None) -> list[Instruction]: assert inst is not None diff --git a/src/paddlefx/source.py b/src/paddlefx/source.py index bc8dcea..ef46b54 100644 --- a/src/paddlefx/source.py +++ b/src/paddlefx/source.py @@ -7,6 +7,13 @@ class Source: def name(self) -> str: raise NotImplementedError() + def is_traceable(self) -> bool: + raise NotImplementedError() + + def need_guard(self) -> bool: + # TODO(zrr1999): implement is_traceable + return True + @dataclasses.dataclass(frozen=True) class LocalSource(Source): diff --git a/src/paddlefx/variables/base.py b/src/paddlefx/variables/base.py index 91f19ef..5a4e208 100644 --- a/src/paddlefx/variables/base.py +++ b/src/paddlefx/variables/base.py @@ -2,6 +2,7 @@ import itertools +from queue import Queue from typing import TYPE_CHECKING, Any from ..source import LocalSource, Source @@ -12,6 +13,37 @@ from ..pyeval import 
PyEvalBase +def find_traceable_vars( + root_vars: list[VariableBase], +) -> list[VariableBase]: + """This function is used to find all traceable variables in the given list of variables. + + Args: + root_vars (list[VariableBase]): A list of root variables from which the ordering starts. + + Returns: + list[VariableBase]: A list of variables that are traceable. + """ + results: list[VariableBase] = [] + visited: set[VariableBase] = set() + queue: Queue[VariableBase] = Queue() + + for root in root_vars: + queue.put(root) + + while not queue.empty(): + var = queue.get() + if var in visited: + continue + + visited.add(var) + if var.source is not None and var.source.need_guard(): + results.append(var) + continue + + return results + + class VariableBase: def __init__( self, From ed9d458eb6dda2ba75db6f010422515159470cd9 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sat, 30 Sep 2023 13:03:20 +0800 Subject: [PATCH 02/20] suuport simple_dynamo --- examples/resnet_trace.py | 2 ++ examples/simple_compiler.py | 2 +- examples/{TODO => }/simple_dynamo.py | 26 ++++++++++---------------- src/paddlefx/compiler.py | 20 +++++++++++++++----- src/paddlefx/variables/callable.py | 6 +++--- 5 files changed, 31 insertions(+), 25 deletions(-) rename examples/{TODO => }/simple_dynamo.py (72%) diff --git a/examples/resnet_trace.py b/examples/resnet_trace.py index 4e108c9..84ff910 100644 --- a/examples/resnet_trace.py +++ b/examples/resnet_trace.py @@ -7,6 +7,8 @@ from paddlefx import symbolic_trace +paddle.seed(0) + net = resnet18() traced_layer = symbolic_trace(net) diff --git a/examples/simple_compiler.py b/examples/simple_compiler.py index 67f2fd0..e0a40e8 100644 --- a/examples/simple_compiler.py +++ b/examples/simple_compiler.py @@ -34,8 +34,8 @@ def func(a, b): x = paddle.rand([4, 6, 1]) y = paddle.rand([4, 6, 224]) for _ in range(10): - out = func(y, x) res = optimized_func(x, y) res = optimized_func(y, x) + out = func(y, x) np.testing.assert_equal(res.numpy(), out.numpy()) diff --git a/examples/TODO/simple_dynamo.py b/examples/simple_dynamo.py similarity index 72% rename from examples/TODO/simple_dynamo.py rename to examples/simple_dynamo.py index 6c5b195..d51fba6 100644 --- a/examples/TODO/simple_dynamo.py +++ b/examples/simple_dynamo.py @@ -1,31 +1,24 @@ from __future__ import annotations -import logging - import numpy as np import paddle import paddle.nn import paddlefx -logging.getLogger().setLevel(logging.DEBUG) - +from paddlefx.compiler import TVMCompiler -def my_compiler(gl: paddlefx.GraphLayer, example_inputs: list[paddle.Tensor] = None): - print("my_compiler() called with FX graph:") - gl.graph.print_tabular() - print(gl.get_source()) - return gl.forward +# logging.getLogger().setLevel(logging.DEBUG) -@paddlefx.optimize(backend=my_compiler) +@paddlefx.optimize(backend=TVMCompiler(print_tabular=True)) def add(a, b): print('\tcall add') c = a + b return c -@paddlefx.optimize(backend=my_compiler) +@paddlefx.optimize(backend=TVMCompiler(print_tabular=True)) def func(a, b): print('\tcall func') c = add(a, b) @@ -56,7 +49,7 @@ def foo(a, b): return l -optimized_foo = paddlefx.optimize(backend=my_compiler)(foo) +optimized_foo = paddlefx.optimize(backend=TVMCompiler(print_tabular=True))(foo) original_res = foo(in_a, in_b) optimized_res = optimized_foo(in_a, in_b) @@ -64,8 +57,8 @@ def foo(a, b): np.testing.assert_equal(original_res.numpy(), optimized_res.numpy()) dtype = 'float32' -in_a = paddle.to_tensor([1], dtype=dtype) -in_b = paddle.to_tensor([0], dtype=dtype) +in_a = 
paddle.to_tensor([1, 2], dtype=dtype) +in_b = paddle.to_tensor([0, 1], dtype=dtype) def inplace(a, b): @@ -79,7 +72,7 @@ def inplace(a, b): return a -optimized_foo = paddlefx.optimize(backend=my_compiler)(inplace) +optimized_foo = paddlefx.optimize(backend=TVMCompiler(print_tabular=True))(inplace) original_res = inplace(in_a, in_b) optimized_res = optimized_foo(in_a, in_b) @@ -100,9 +93,10 @@ def forward(self, a, b): net = ExampleNet() -optimized_func = paddlefx.optimize(backend=my_compiler)(net) +optimized_func = paddlefx.optimize(backend=TVMCompiler(print_tabular=True))(net) original_res = net(in_a, in_b) optimized_res = optimized_func(in_a, in_b) +optimized_res = optimized_func(in_a, in_b) # TODO(zrr1999): `optimized_res` is the result of running the converted bytecode in the future. np.testing.assert_equal(original_res.numpy(), optimized_res.numpy()) diff --git a/src/paddlefx/compiler.py b/src/paddlefx/compiler.py index 2107cbb..9ab3ffc 100644 --- a/src/paddlefx/compiler.py +++ b/src/paddlefx/compiler.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Callable import paddle +import paddle.device import paddlefx @@ -65,7 +66,13 @@ def gen_compiled_func( from tvm import te - tgt = tvm.target.Target(target="llvm", host="llvm") + device = paddle.device.get_device() + if device == "cpu": + target = tvm.target.Target(target="llvm", host="llvm") + elif device == "gpu": + target = tvm.target.Target(target="cuda", host="llvm") + else: + raise ValueError(f"Unsupported device in tvm backend: {device}") schedule = te.create_schedule(symbol_table["output"].op) tvm_func = tvm.build( schedule, @@ -74,7 +81,7 @@ def gen_compiled_func( for k, v in symbol_table.items() if v.name.startswith("input") or k == "output" ], - tgt, + target, name=symbol_table["output"].name, ) @@ -116,12 +123,15 @@ def compile_call_function( "subtract": topi.subtract, "mul": topi.multiply, "truediv": topi.divide, + "gt": topi.greater, + "lt": topi.less, + "ge": topi.greater_equal, + "le": topi.less_equal, } if target_name in map_ops_to_tvm.keys(): - left = symbol_table[str(node.args[0])] - right = symbol_table[str(node.args[1])] - symbol_table[node.name] = map_ops_to_tvm[target_name](left, right) + symbol_args = [symbol_table[str(arg)] for arg in node.args] + symbol_table[node.name] = map_ops_to_tvm[target_name](*symbol_args) else: raise NotImplementedError(f"Unsupported function: {target_name}") diff --git a/src/paddlefx/variables/callable.py b/src/paddlefx/variables/callable.py index df02911..bd457c4 100644 --- a/src/paddlefx/variables/callable.py +++ b/src/paddlefx/variables/callable.py @@ -111,16 +111,16 @@ def __call__(self, tx: PyEvalBase, *args: VariableBase, **kwargs) -> VariableBas obj_cls = type(args[0]) output = graph.call_function(fn, args, kwargs, ot) return TensorVariable(None, node=output) - elif fn in [operator.gt]: + elif fn in [operator.gt, operator.lt, operator.ge, operator.le]: ot = type(args[0].var) obj_cls = type(args[0]) output = graph.call_function(fn, args, kwargs, ot) - return obj_cls(node=output) + return TensorVariable(None, node=output) elif fn in [operator.is_, operator.is_not]: ot = type(args[0].var) obj_cls = type(args[0]) output = graph.call_function(fn, args, kwargs, ot) - return obj_cls(node=output) + return TensorVariable(None, node=output) else: raise NotImplementedError(f"builtin {fn} is not supported") From 26e464b16464d46fb1e8ac62f44bc27a461ebe03 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sun, 1 Oct 2023 17:48:57 +0800 Subject: [PATCH 03/20] use dir --- 
examples/simple_compiler.py | 4 +- examples/simple_dynamo.py | 32 ++------- examples/targets/target_3_add_paddle.py | 4 +- src/paddlefx/compiler/__init__.py | 4 ++ src/paddlefx/compiler/base.py | 68 +++++++++++++++++++ src/paddlefx/{compiler.py => compiler/tvm.py} | 61 ++++------------- src/paddlefx/eval_frame.py | 4 +- src/paddlefx/node.py | 7 +- tests/test_basic.py | 27 +++++--- tests/test_broadcast.py | 27 -------- 10 files changed, 119 insertions(+), 119 deletions(-) create mode 100644 src/paddlefx/compiler/__init__.py create mode 100644 src/paddlefx/compiler/base.py rename src/paddlefx/{compiler.py => compiler/tvm.py} (62%) delete mode 100644 tests/test_broadcast.py diff --git a/examples/simple_compiler.py b/examples/simple_compiler.py index e0a40e8..427c3b9 100644 --- a/examples/simple_compiler.py +++ b/examples/simple_compiler.py @@ -16,7 +16,7 @@ def inner_func(x, y): p = paddle.add(x, y) - q = paddle._C_ops.subtract(x, y) + q = paddle._C_ops.subtract(x, y) # type: ignore z = p * q return z / y @@ -29,7 +29,7 @@ def func(a, b): return d -optimized_func = paddlefx.optimize(func, backend=TVMCompiler(print_tabular=True)) +optimized_func = paddlefx.optimize(func, backend=TVMCompiler(print_tabular_mode="rich")) x = paddle.rand([4, 6, 1]) y = paddle.rand([4, 6, 224]) diff --git a/examples/simple_dynamo.py b/examples/simple_dynamo.py index d51fba6..695998d 100644 --- a/examples/simple_dynamo.py +++ b/examples/simple_dynamo.py @@ -11,14 +11,14 @@ # logging.getLogger().setLevel(logging.DEBUG) -@paddlefx.optimize(backend=TVMCompiler(print_tabular=True)) +@paddlefx.optimize(backend=TVMCompiler(print_tabular_mode="rich")) def add(a, b): print('\tcall add') c = a + b return c -@paddlefx.optimize(backend=TVMCompiler(print_tabular=True)) +@paddlefx.optimize(backend=TVMCompiler(print_tabular_mode="rich")) def func(a, b): print('\tcall func') c = add(a, b) @@ -34,28 +34,6 @@ def func(a, b): np.testing.assert_equal(res.numpy(), out.numpy()) -def foo(a, b): - # print('\tcall foo') - c = a / b - d = a * b - e = c + d - f = e - a - g = f > e - h = g < f - i = h <= g - j = i >= i - k = j == i - l = j != k - return l - - -optimized_foo = paddlefx.optimize(backend=TVMCompiler(print_tabular=True))(foo) - -original_res = foo(in_a, in_b) -optimized_res = optimized_foo(in_a, in_b) - -np.testing.assert_equal(original_res.numpy(), optimized_res.numpy()) - dtype = 'float32' in_a = paddle.to_tensor([1, 2], dtype=dtype) in_b = paddle.to_tensor([0, 1], dtype=dtype) @@ -72,7 +50,9 @@ def inplace(a, b): return a -optimized_foo = paddlefx.optimize(backend=TVMCompiler(print_tabular=True))(inplace) +optimized_foo = paddlefx.optimize(backend=TVMCompiler(print_tabular_mode="rich"))( + inplace +) original_res = inplace(in_a, in_b) optimized_res = optimized_foo(in_a, in_b) @@ -93,7 +73,7 @@ def forward(self, a, b): net = ExampleNet() -optimized_func = paddlefx.optimize(backend=TVMCompiler(print_tabular=True))(net) +optimized_func = paddlefx.optimize(backend=TVMCompiler(print_tabular_mode="rich"))(net) original_res = net(in_a, in_b) optimized_res = optimized_func(in_a, in_b) diff --git a/examples/targets/target_3_add_paddle.py b/examples/targets/target_3_add_paddle.py index 27013eb..e07ba2a 100644 --- a/examples/targets/target_3_add_paddle.py +++ b/examples/targets/target_3_add_paddle.py @@ -20,11 +20,11 @@ def func(x, y): z = paddle.add(x, y) - o = paddle._C_ops.add(z, z) + o = paddle._C_ops.add(z, z) # type: ignore return o -@paddlefx.optimize(backend=TVMCompiler(print_tabular=True)) 
+@paddlefx.optimize(backend=TVMCompiler(print_tabular_mode="rich")) def net(a, b): c = func(a, b) return c diff --git a/src/paddlefx/compiler/__init__.py b/src/paddlefx/compiler/__init__.py new file mode 100644 index 0000000..c4b13e7 --- /dev/null +++ b/src/paddlefx/compiler/__init__.py @@ -0,0 +1,4 @@ +from __future__ import annotations + +from .base import CompilerBase, DummyCompiler # noqa: F401 +from .tvm import TVMCompiler # noqa: F401 diff --git a/src/paddlefx/compiler/base.py b/src/paddlefx/compiler/base.py new file mode 100644 index 0000000..84eecd0 --- /dev/null +++ b/src/paddlefx/compiler/base.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from typing import Any, Callable + +import paddle +import paddle.device + +import paddlefx + + +def paddle_dtype_to_str(dtype: paddle.dtype) -> str: + if dtype == paddle.float32: + return "float32" + elif dtype == paddle.float64: + return "float64" + elif dtype == paddle.float16: + return "float16" + elif dtype == paddle.int32: + return "int32" + elif dtype == paddle.int64: + return "int64" + elif dtype == paddle.bool: + return "bool" + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + +class CompilerError(Exception): + pass + + +class CompilerBase: + def __init__(self, *, full_graph=False, print_tabular_mode: str | None = None): + self.full_graph = full_graph # TODO: support full_graph + self.print_tabular_mode = print_tabular_mode + self.input_index = 0 + + def __call__(self, gl: paddlefx.GraphLayer, example_inputs: list): + self.input_index = 0 + if self.print_tabular_mode is not None: + gl.graph.print_tabular(print_mode=self.print_tabular_mode) + return self.compile(gl, example_inputs) + + def compile(self, gl: paddlefx.GraphLayer, example_inputs: list) -> Callable: + symbol_table: dict[str, Any] = {} + example_outputs = gl.forward(*example_inputs) + try: + for node in gl.graph.nodes: + getattr(self, f"compile_{node.op}")(node, symbol_table, example_inputs) + return self.gen_compiled_func(symbol_table, example_inputs, example_outputs) + except CompilerError as e: + print(f"CompilerError when compiling graph, useing default forward: {e}") + self.input_index = 0 + return gl.forward + except AttributeError as e: + raise AttributeError( + f"AttributeError when compiling graph, check if you use abstract class: {e}" + ) + + def gen_compiled_func( + self, symbol_table: dict[str, Any], dummy_inputs: list, dummy_outputs: Any + ): + raise NotImplementedError("CompilerBase is a abstract class") + + +class DummyCompiler(CompilerBase): + def compile(self, gl: paddlefx.GraphLayer, example_inputs: list) -> Callable: + return gl.forward diff --git a/src/paddlefx/compiler.py b/src/paddlefx/compiler/tvm.py similarity index 62% rename from src/paddlefx/compiler.py rename to src/paddlefx/compiler/tvm.py index 9ab3ffc..47370a9 100644 --- a/src/paddlefx/compiler.py +++ b/src/paddlefx/compiler/tvm.py @@ -1,65 +1,20 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any import paddle import paddle.device import paddlefx +from .base import CompilerBase, CompilerError, paddle_dtype_to_str + if TYPE_CHECKING: from tvm import te -def paddle_dtype_to_str(dtype: paddle.dtype) -> str: - if dtype == paddle.float32: - return "float32" - elif dtype == paddle.float64: - return "float64" - elif dtype == paddle.float16: - return "float16" - elif dtype == paddle.int32: - return "int32" - elif dtype == paddle.int64: - return "int64" - elif dtype == paddle.bool: - return "bool" - else: - 
raise ValueError(f"Unsupported dtype: {dtype}") - - -class CompilerBase: - def __init__(self, *, full_graph=False, print_tabular: bool = False): - self.full_graph = full_graph # TODO: support full_graph - self.print_tabular = print_tabular - self.input_index = 0 - - def __call__(self, gl: paddlefx.GraphLayer, example_inputs: list): - if self.print_tabular: - gl.graph.print_tabular() - return self.compile(gl, example_inputs) - - def compile(self, gl: paddlefx.GraphLayer, example_inputs: list) -> Callable: - dummy_outputs = gl.forward(*example_inputs) - symbol_table: dict[str, Any] = {} - try: - for node in gl.graph.nodes: - getattr(self, f"compile_{node.op}")(node, symbol_table, example_inputs) - self.input_index = 0 - return self.gen_compiled_func(symbol_table, example_inputs, dummy_outputs) - except (AttributeError, NotImplementedError) as e: - print(f"AttributeError when compiling graph: {e}") - self.input_index = 0 - return gl.forward - - def gen_compiled_func( - self, symbol_table: dict[str, Any], dummy_inputs: list, dummy_outputs: Any - ): - raise NotImplementedError("CompilerBase is a abstract class") - - class TVMCompiler(CompilerBase): - def gen_compiled_func( + def compile( self, symbol_table: dict[str, te.Tensor], dummy_inputs: list, dummy_outputs: Any ): import tvm @@ -110,6 +65,14 @@ def compile_placeholder( ) self.input_index += 1 + def compile_call_module( + self, node: paddlefx.Node, symbol_table: dict[str, te.Tensor], inputs: list + ): + pass + + target_name = node.target + raise CompilerError(f"Unsupported module: {target_name}") + def compile_call_function( self, node: paddlefx.Node, symbol_table: dict[str, te.Tensor], inputs: list ): diff --git a/src/paddlefx/eval_frame.py b/src/paddlefx/eval_frame.py index c7ca0cc..bb7d5ce 100644 --- a/src/paddlefx/eval_frame.py +++ b/src/paddlefx/eval_frame.py @@ -7,7 +7,7 @@ from typing import Callable from ._eval_frame import set_eval_frame -from .compiler import CompilerBase +from .compiler import DummyCompiler from .convert_frame import convert_frame @@ -46,7 +46,7 @@ def disable(fn=None): def optimize( - model: Callable | None = None, *, backend: Callable = CompilerBase() + model: Callable | None = None, *, backend: Callable = DummyCompiler() ) -> Callable: def _fn(backend: Callable): def __fn(frame: types.FrameType): diff --git a/src/paddlefx/node.py b/src/paddlefx/node.py index dcc68b3..202502a 100644 --- a/src/paddlefx/node.py +++ b/src/paddlefx/node.py @@ -2,17 +2,20 @@ import warnings -from typing import Any, Callable, Union +from typing import TYPE_CHECKING, Any, Callable, Union import paddle +if TYPE_CHECKING: + from .graph import Graph + BaseArgumentTypes = Union[str, int, float, bool, complex, paddle.dtype, paddle.Tensor] base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined] # Nodes represent a definition of a value in our graph of operators. 
class Node: - def __init__(self, graph, name, op, target, args, kwargs): + def __init__(self, graph: Graph, name, op, target, args, kwargs): self.meta = {} # for storing metadata about the node self.graph = graph self.name = name # unique name of value being created diff --git a/tests/test_basic.py b/tests/test_basic.py index a38f538..0cafdef 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -9,14 +9,23 @@ paddle.seed(0) -def add(x, y): - z = x + y - return z +def binary_operator(a, b): + c = a / b + d = a * b + e = c + d + f = e - a + g = f > e + h = g < f + i = h <= g + j = i >= i + k = j == i + l = j != k + return l def inner_func(x, y): p = paddle.add(x, y) - q = paddle._C_ops.subtract(x, y) + q = paddle._C_ops.subtract(x, y) # type: ignore z = p * q return z / y @@ -33,13 +42,13 @@ def check_func(func, *args): np.testing.assert_allclose(res, out) -def test_add(): - in_a = paddle.rand([1, 224]) - in_b = paddle.rand([1, 224]) - check_func(add, in_a, in_b) +def test_binary_operator(): + in_a = paddle.rand([1, 24]) + in_b = paddle.rand([8, 24]) + check_func(binary_operator, in_a, in_b) -def test_func_add(): +def test_func(): in_a = paddle.rand([8, 8, 16]) in_b = paddle.rand([8, 8, 16]) check_func(func, in_a, in_b) diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py deleted file mode 100644 index 42700be..0000000 --- a/tests/test_broadcast.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -import numpy as np -import paddle -import paddle.nn - -import paddlefx - -from paddlefx.compiler import TVMCompiler - - -def add(x, y): - z = x + y - return z - - -def check_func(func, *args): - comiled_func = paddlefx.optimize(func, backend=TVMCompiler(print_tabular=True)) - out = func(*args) - res = comiled_func(*args) - np.testing.assert_allclose(res, out) - - -def test_broadcast_add(): - in_a = paddle.rand([224, 224]) - in_b = paddle.rand([1, 224]) - check_func(add, in_a, in_b) From b54786ebd5596cf246d882596557be7fd423f6ba Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sun, 1 Oct 2023 17:49:34 +0800 Subject: [PATCH 04/20] fix layer --- src/paddlefx/variables/callable.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/paddlefx/variables/callable.py b/src/paddlefx/variables/callable.py index bd457c4..936da82 100644 --- a/src/paddlefx/variables/callable.py +++ b/src/paddlefx/variables/callable.py @@ -76,7 +76,9 @@ def __call__(self, tx: PyEvalBase, *args: VariableBase, **kwargs) -> VariableBas if fn is layers: target = name break - return obj_cls(node=graph.call_module(target, args, kwargs)) + return TensorVariable( + None, node=graph.call_module(target, args, kwargs) + ) elif fn.__module__.startswith("paddle"): # TODO: support multiple ouputs and containers ot = type(args[0].var) From 10a4fbcb7054af1fef7319a5df54d834626992b4 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sun, 1 Oct 2023 17:52:10 +0800 Subject: [PATCH 05/20] fix bug in tvm compiler --- src/paddlefx/compiler/tvm.py | 2 +- tests/test_compiler_tvm.py | 41 ++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 tests/test_compiler_tvm.py diff --git a/src/paddlefx/compiler/tvm.py b/src/paddlefx/compiler/tvm.py index 47370a9..1fb318b 100644 --- a/src/paddlefx/compiler/tvm.py +++ b/src/paddlefx/compiler/tvm.py @@ -14,7 +14,7 @@ class TVMCompiler(CompilerBase): - def compile( + def gen_compiled_func( self, symbol_table: dict[str, te.Tensor], dummy_inputs: list, dummy_outputs: Any ): 
import tvm diff --git a/tests/test_compiler_tvm.py b/tests/test_compiler_tvm.py new file mode 100644 index 0000000..69e99e5 --- /dev/null +++ b/tests/test_compiler_tvm.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import numpy as np +import paddle +import paddle.nn + +import paddlefx + +from paddlefx.compiler.tvm import TVMCompiler + +paddle.seed(0) + + +class SimpleNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.fc = [paddle.nn.Linear(16, 1), paddle.nn.Linear(16, 4)] + + def forward(self, a, b): + c = self.fc[0](a) + d = self.fc[1](b) + e = paddle.add(c, d) + return e + + +net = SimpleNet() + + +def check_func(func, *args): + comiled_func = paddlefx.optimize( + func, backend=TVMCompiler(print_tabular_mode="rich") + ) + out = func(*args) + res = comiled_func(*args) + np.testing.assert_allclose(res, out) + + +def test_simple_net(): + in_a = paddle.rand([8, 16]) + in_b = paddle.rand([8, 16]) + check_func(net, in_a, in_b) From c27ab22b0239016a462f0f46e4b23b7c9ea4d7ce Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sun, 1 Oct 2023 18:10:35 +0800 Subject: [PATCH 06/20] improve tests --- tests/test_basic.py | 10 +--------- tests/test_compiler_tvm.py | 21 ++++++++------------- tests/utils.py | 17 +++++++++++++++++ 3 files changed, 26 insertions(+), 22 deletions(-) create mode 100644 tests/utils.py diff --git a/tests/test_basic.py b/tests/test_basic.py index 0cafdef..1ca532b 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,10 +1,9 @@ from __future__ import annotations -import numpy as np import paddle import paddle.nn -import paddlefx +from utils import check_func paddle.seed(0) @@ -35,13 +34,6 @@ def func(a, b): return d -def check_func(func, *args): - comiled_func = paddlefx.optimize(func) - out = func(*args) - res = comiled_func(*args) - np.testing.assert_allclose(res, out) - - def test_binary_operator(): in_a = paddle.rand([1, 24]) in_b = paddle.rand([8, 24]) diff --git a/tests/test_compiler_tvm.py b/tests/test_compiler_tvm.py index 69e99e5..b59e172 100644 --- a/tests/test_compiler_tvm.py +++ b/tests/test_compiler_tvm.py @@ -1,10 +1,9 @@ from __future__ import annotations -import numpy as np import paddle import paddle.nn -import paddlefx +from utils import check_func from paddlefx.compiler.tvm import TVMCompiler @@ -26,16 +25,12 @@ def forward(self, a, b): net = SimpleNet() -def check_func(func, *args): - comiled_func = paddlefx.optimize( - func, backend=TVMCompiler(print_tabular_mode="rich") - ) - out = func(*args) - res = comiled_func(*args) - np.testing.assert_allclose(res, out) - - def test_simple_net(): in_a = paddle.rand([8, 16]) - in_b = paddle.rand([8, 16]) - check_func(net, in_a, in_b) + in_b = paddle.rand([1, 16]) + check_func( + net, + in_a, + in_b, + backend=TVMCompiler(allow_fallback=False, print_tabular_mode="rich"), + ) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..a3ab5a5 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from typing import Callable + +import numpy as np + +import paddlefx + + +def check_func(func, *args, backend: Callable | None = None): + if backend is None: + comiled_func = paddlefx.optimize(func) + else: + comiled_func = paddlefx.optimize(func, backend=backend) + out = func(*args) + res = comiled_func(*args) + np.testing.assert_allclose(res, out) From 5a549b8c150669a80db6b7f597d16ea13521fa80 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sun, 1 Oct 2023 18:11:00 +0800 Subject: [PATCH 
07/20] fix bug --- src/paddlefx/compiler/base.py | 23 +++++++++++++++++------ src/paddlefx/compiler/tvm.py | 2 +- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/paddlefx/compiler/base.py b/src/paddlefx/compiler/base.py index 84eecd0..40386ea 100644 --- a/src/paddlefx/compiler/base.py +++ b/src/paddlefx/compiler/base.py @@ -30,7 +30,14 @@ class CompilerError(Exception): class CompilerBase: - def __init__(self, *, full_graph=False, print_tabular_mode: str | None = None): + def __init__( + self, + *, + allow_fallback: bool = True, + full_graph=False, + print_tabular_mode: str | None = None, + ): + self.allow_fallback = allow_fallback self.full_graph = full_graph # TODO: support full_graph self.print_tabular_mode = print_tabular_mode self.input_index = 0 @@ -49,13 +56,17 @@ def compile(self, gl: paddlefx.GraphLayer, example_inputs: list) -> Callable: getattr(self, f"compile_{node.op}")(node, symbol_table, example_inputs) return self.gen_compiled_func(symbol_table, example_inputs, example_outputs) except CompilerError as e: - print(f"CompilerError when compiling graph, useing default forward: {e}") - self.input_index = 0 - return gl.forward + if self.allow_fallback: + print( + f"CompilerError when compiling graph, useing default forward: {e}" + ) + self.input_index = 0 + return gl.forward + raise e except AttributeError as e: raise AttributeError( - f"AttributeError when compiling graph, check if you use abstract class: {e}" - ) + f"AttributeError when compiling graph, check if you use abstract class" + ) from e def gen_compiled_func( self, symbol_table: dict[str, Any], dummy_inputs: list, dummy_outputs: Any diff --git a/src/paddlefx/compiler/tvm.py b/src/paddlefx/compiler/tvm.py index 1fb318b..774fbcb 100644 --- a/src/paddlefx/compiler/tvm.py +++ b/src/paddlefx/compiler/tvm.py @@ -27,7 +27,7 @@ def gen_compiled_func( elif device == "gpu": target = tvm.target.Target(target="cuda", host="llvm") else: - raise ValueError(f"Unsupported device in tvm backend: {device}") + raise CompilerError(f"Unsupported device in tvm backend: {device}") schedule = te.create_schedule(symbol_table["output"].op) tvm_func = tvm.build( schedule, From d12ef93a4ceb65ca29e4fe8add7c11da21b1de5b Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sun, 1 Oct 2023 19:55:30 +0800 Subject: [PATCH 08/20] impl SybolTable and support tuple outputs --- src/paddlefx/compiler/base.py | 47 ++++++++++++++++++++++++++++---- src/paddlefx/compiler/tvm.py | 51 ++++++++++++++++++----------------- 2 files changed, 69 insertions(+), 29 deletions(-) diff --git a/src/paddlefx/compiler/base.py b/src/paddlefx/compiler/base.py index 40386ea..d307682 100644 --- a/src/paddlefx/compiler/base.py +++ b/src/paddlefx/compiler/base.py @@ -1,12 +1,16 @@ from __future__ import annotations -from typing import Any, Callable +import dataclasses + +from typing import Callable, Generic, TypeVar import paddle import paddle.device import paddlefx +T = TypeVar("T") + def paddle_dtype_to_str(dtype: paddle.dtype) -> str: if dtype == paddle.float32: @@ -29,12 +33,45 @@ class CompilerError(Exception): pass +@dataclasses.dataclass +class SybolTable(Generic[T]): + def __init__(self): + self._symbol_table: dict[str, T] = {} + self._inputs: list[T] = [] + self._outputs: tuple[T, ...] 
= () + + def __getitem__(self, key: str) -> T: + return self._symbol_table[key] + + def __setitem__(self, key: str, value: T): + self._symbol_table[key] = value + + def __iter__(self): + return iter(self._symbol_table.items()) + + @property + def inputs(self) -> tuple[T, ...]: + return tuple(self._inputs) + + def add_input(self, key: str, value: T): + self._inputs.append(value) + self._symbol_table[key] = value + + @property + def outputs(self) -> tuple[T, ...]: + return self._outputs + + @outputs.setter + def outputs(self, value: tuple[T, ...]): + self._outputs = value + + class CompilerBase: def __init__( self, *, - allow_fallback: bool = True, - full_graph=False, + allow_fallback: bool = False, + full_graph: bool = False, print_tabular_mode: str | None = None, ): self.allow_fallback = allow_fallback @@ -49,7 +86,7 @@ def __call__(self, gl: paddlefx.GraphLayer, example_inputs: list): return self.compile(gl, example_inputs) def compile(self, gl: paddlefx.GraphLayer, example_inputs: list) -> Callable: - symbol_table: dict[str, Any] = {} + symbol_table: SybolTable = SybolTable() example_outputs = gl.forward(*example_inputs) try: for node in gl.graph.nodes: @@ -69,7 +106,7 @@ def compile(self, gl: paddlefx.GraphLayer, example_inputs: list) -> Callable: ) from e def gen_compiled_func( - self, symbol_table: dict[str, Any], dummy_inputs: list, dummy_outputs: Any + self, symbol_table: SybolTable, dummy_inputs: list, dummy_outputs: list ): raise NotImplementedError("CompilerBase is a abstract class") diff --git a/src/paddlefx/compiler/tvm.py b/src/paddlefx/compiler/tvm.py index 774fbcb..7bb275b 100644 --- a/src/paddlefx/compiler/tvm.py +++ b/src/paddlefx/compiler/tvm.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import paddle import paddle.device @@ -12,10 +12,15 @@ if TYPE_CHECKING: from tvm import te + from .base import SybolTable + class TVMCompiler(CompilerBase): def gen_compiled_func( - self, symbol_table: dict[str, te.Tensor], dummy_inputs: list, dummy_outputs: Any + self, + symbol_table: SybolTable[te.Tensor], + dummy_inputs: list, + dummy_outputs: list, ): import tvm @@ -31,42 +36,41 @@ def gen_compiled_func( schedule = te.create_schedule(symbol_table["output"].op) tvm_func = tvm.build( schedule, - [ - v - for k, v in symbol_table.items() - if v.name.startswith("input") or k == "output" - ], + [*symbol_table.inputs, *symbol_table.outputs], target, name=symbol_table["output"].name, ) def compiled_func(*args): inputs = [tvm.nd.array(arg.numpy()) for arg in args] - dummy_output = dummy_outputs[0] - output = tvm.nd.empty( - dummy_output.shape, paddle_dtype_to_str(dummy_output.dtype) - ) - tvm_func(*inputs, output) - output = paddle.to_tensor(output.asnumpy()) - return (output,) + + outputs = [ + tvm.nd.empty(out.shape, paddle_dtype_to_str(out.dtype)) + for out in dummy_outputs + ] + tvm_func(*inputs, *outputs) + return tuple(paddle.to_tensor(out.asnumpy()) for out in outputs) compiled_func(*dummy_inputs) return compiled_func def compile_placeholder( - self, node: paddlefx.Node, symbol_table: dict[str, te.Tensor], inputs: list + self, node: paddlefx.Node, symbol_table: SybolTable[te.Tensor], inputs: list ): from tvm import te - symbol_table[node.name] = te.placeholder( - inputs[self.input_index].shape, - paddle_dtype_to_str(inputs[self.input_index].dtype), - name=f"input_{node.name}", + symbol_table.add_input( + node.name, + te.placeholder( + inputs[self.input_index].shape, + 
paddle_dtype_to_str(inputs[self.input_index].dtype), + name=f"input_{node.name}", + ), ) self.input_index += 1 def compile_call_module( - self, node: paddlefx.Node, symbol_table: dict[str, te.Tensor], inputs: list + self, node: paddlefx.Node, symbol_table: SybolTable[te.Tensor], inputs: list ): pass @@ -74,7 +78,7 @@ def compile_call_module( raise CompilerError(f"Unsupported module: {target_name}") def compile_call_function( - self, node: paddlefx.Node, symbol_table: dict[str, te.Tensor], inputs: list + self, node: paddlefx.Node, symbol_table: SybolTable[te.Tensor], inputs: list ): from tvm import topi @@ -99,7 +103,6 @@ def compile_call_function( raise NotImplementedError(f"Unsupported function: {target_name}") def compile_output( - self, node: paddlefx.Node, symbol_table: dict[str, te.Tensor], inputs: list + self, node: paddlefx.Node, symbol_table: SybolTable[te.Tensor], inputs: list ): - ret = symbol_table[str(node.args[0][0])] - symbol_table["output"] = ret + symbol_table.outputs = tuple(symbol_table[str(arg)] for arg in node.args[0]) From 33f0832e58fbb53988435fab4cb5efa584b5c61b Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sun, 1 Oct 2023 20:37:46 +0800 Subject: [PATCH 09/20] fix bug in PyCodegen call --- src/paddlefx/codegen.py | 8 ++++---- src/paddlefx/compiler/tvm.py | 4 ++-- src/paddlefx/output_graph.py | 1 - 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/paddlefx/codegen.py b/src/paddlefx/codegen.py index 21f769c..fd04a5f 100644 --- a/src/paddlefx/codegen.py +++ b/src/paddlefx/codegen.py @@ -123,10 +123,10 @@ def make_call_generated_code(self, fn_name: str): self.extend_output(create_call_function(len(placeholders), False)) def call(self, vars: VariableStack[VariableBase]): - for var in vars: - self.call_one(var) + for i, var in enumerate(vars): + self.call_one(i, var) - def call_one(self, value: VariableBase): + def call_one(self, index: int, value: VariableBase): """Generate code such that top-of-stack (TOS) is set to value.""" output = self.instructions graph_outputs = self.graph_outputs @@ -149,7 +149,7 @@ def call_one(self, value: VariableBase): output.append(self.create_load(self.graph_output_var)) # TODO: rm hardcode - output.append(self.create_load_const(0)) + output.append(self.create_load_const(index)) output.append(create_instruction("BINARY_SUBSCR")) elif value.var == None: output.append(self.create_load_const(None)) diff --git a/src/paddlefx/compiler/tvm.py b/src/paddlefx/compiler/tvm.py index 7bb275b..a87bcb6 100644 --- a/src/paddlefx/compiler/tvm.py +++ b/src/paddlefx/compiler/tvm.py @@ -33,12 +33,12 @@ def gen_compiled_func( target = tvm.target.Target(target="cuda", host="llvm") else: raise CompilerError(f"Unsupported device in tvm backend: {device}") - schedule = te.create_schedule(symbol_table["output"].op) + schedule = te.create_schedule([out.op for out in symbol_table.outputs]) tvm_func = tvm.build( schedule, [*symbol_table.inputs, *symbol_table.outputs], target, - name=symbol_table["output"].name, + name="tvm_func", ) def compiled_func(*args): diff --git a/src/paddlefx/output_graph.py b/src/paddlefx/output_graph.py index 3f96a83..084fffb 100644 --- a/src/paddlefx/output_graph.py +++ b/src/paddlefx/output_graph.py @@ -105,7 +105,6 @@ def apply_compiler(self, tx: PyEvalBase, rv: list[VariableBase], root): gl = GraphLayer(root, self.graph) compiled_fn_name = f"__compiled_fn_{next(_compiled_fn_counter)}" - # TODO: add inputs compiled_fn = self.compiler_fn(gl, self.example_inputs()) log_code( compiled_fn.__code__, From 
a5dc37e7912705082019874c486fdd56f02efe91 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sun, 1 Oct 2023 20:45:58 +0800 Subject: [PATCH 10/20] improve tests --- examples/TODO/resnet_dynamo.py | 11 +--- examples/simple_dynamo.py | 109 ++++++++++++++++----------------- src/paddlefx/codegen.py | 3 +- tests/test_basic.py | 8 ++- tests/test_compiler_tvm.py | 2 +- 5 files changed, 63 insertions(+), 70 deletions(-) diff --git a/examples/TODO/resnet_dynamo.py b/examples/TODO/resnet_dynamo.py index 3563990..0f7b051 100644 --- a/examples/TODO/resnet_dynamo.py +++ b/examples/TODO/resnet_dynamo.py @@ -9,16 +9,11 @@ import paddlefx +from paddlefx.compiler.tvm import TVMCompiler -def my_compiler(gl: paddlefx.GraphLayer, example_inputs: list[paddle.Tensor] = None): - print("my_compiler() called with FX graph:") - print(gl.get_source()) - gl.graph.print_tabular(print_mode="rich") - return gl.forward - - +compiler = TVMCompiler(full_graph=True, print_tabular_mode="rich") net = resnet18() -optimized_net = paddlefx.optimize(backend=my_compiler)(net) +optimized_net = paddlefx.optimize(net, backend=compiler) x = paddle.rand([1, 3, 224, 224]) out = net(x) diff --git a/examples/simple_dynamo.py b/examples/simple_dynamo.py index 695998d..e4b91fe 100644 --- a/examples/simple_dynamo.py +++ b/examples/simple_dynamo.py @@ -1,82 +1,77 @@ from __future__ import annotations +import logging + import numpy as np import paddle import paddle.nn import paddlefx -from paddlefx.compiler import TVMCompiler - -# logging.getLogger().setLevel(logging.DEBUG) - - -@paddlefx.optimize(backend=TVMCompiler(print_tabular_mode="rich")) -def add(a, b): - print('\tcall add') - c = a + b - return c - - -@paddlefx.optimize(backend=TVMCompiler(print_tabular_mode="rich")) -def func(a, b): - print('\tcall func') - c = add(a, b) - d = add(c, c) - return d +from paddlefx.compiler import DummyCompiler, TVMCompiler +logging.getLogger().setLevel(logging.DEBUG) +dummy_compier = DummyCompiler(full_graph=True, print_tabular_mode="rich") +compiler = TVMCompiler(full_graph=True, print_tabular_mode="rich") -in_a = paddle.rand([3, 4]) -in_b = paddle.rand([3, 4]) -out = paddle.add(in_a, in_b) -res = add(in_a, in_b) -np.testing.assert_equal(res.numpy(), out.numpy()) +def inner_func(x, y): + p = x + y + q = paddle._C_ops.subtract(x, y) # type: ignore + u = paddle._C_ops.subtract(x, y) # type: ignore + print(1) + z = p * q * u + return z / y -dtype = 'float32' -in_a = paddle.to_tensor([1, 2], dtype=dtype) -in_b = paddle.to_tensor([0, 1], dtype=dtype) +def breakraph_func(a, b): + d = inner_func(a, b) + d = inner_func(a, b) + # print("call func") + q = inner_func(a, b) + return d, q -def inplace(a, b): - # print('\tcall inplace') - a -= b - a += b - a *= b - a /= b - a **= b - a @= b - return a +def check_func(func, *args, backend: None = None): + if backend is None: + comiled_func = paddlefx.optimize(func) + else: + comiled_func = paddlefx.optimize(func, backend=backend) + out = func(*args) + res = comiled_func(*args) + if isinstance(out, tuple): + for i in range(len(res)): + np.testing.assert_allclose(res[i], out[i]) + else: + np.testing.assert_allclose(res, out) -optimized_foo = paddlefx.optimize(backend=TVMCompiler(print_tabular_mode="rich"))( - inplace -) +in_a = paddle.rand([8, 8, 16]) +in_b = paddle.rand([8, 1, 16]) +check_func(inner_func, in_a, in_b, backend=compiler) -original_res = inplace(in_a, in_b) -optimized_res = optimized_foo(in_a, in_b) -np.testing.assert_equal(original_res.numpy(), optimized_res.numpy()) +# dtype = 
'float32' +# in_a = paddle.to_tensor([1, 2], dtype=dtype) +# in_b = paddle.to_tensor([0, 1], dtype=dtype) -class ExampleNet(paddle.nn.Layer): - def __init__(self): - super().__init__() - self.fc = [paddle.nn.Linear(1, 1), paddle.nn.Linear(1, 1)] +# def inplace(a, b): +# # print('\tcall inplace') +# a -= b +# a += b +# a *= b +# a /= b +# a **= b +# a @= b +# return a - def forward(self, a, b): - c = self.fc[0](a[0]) - d = self.fc[1](b[0]) - e = paddle.add(c, d) - return e +# optimized_foo = paddlefx.optimize(backend=compiler)( +# inplace +# ) -net = ExampleNet() -optimized_func = paddlefx.optimize(backend=TVMCompiler(print_tabular_mode="rich"))(net) +# original_res = inplace(in_a, in_b) +# optimized_res = optimized_foo(in_a, in_b) -original_res = net(in_a, in_b) -optimized_res = optimized_func(in_a, in_b) -optimized_res = optimized_func(in_a, in_b) -# TODO(zrr1999): `optimized_res` is the result of running the converted bytecode in the future. -np.testing.assert_equal(original_res.numpy(), optimized_res.numpy()) +# np.testing.assert_equal(original_res.numpy(), optimized_res.numpy()) diff --git a/src/paddlefx/codegen.py b/src/paddlefx/codegen.py index fd04a5f..4119eb1 100644 --- a/src/paddlefx/codegen.py +++ b/src/paddlefx/codegen.py @@ -124,7 +124,8 @@ def make_call_generated_code(self, fn_name: str): def call(self, vars: VariableStack[VariableBase]): for i, var in enumerate(vars): - self.call_one(i, var) + # TODO: Maybe self.call_one(i, var)? + self.call_one(0, var) def call_one(self, index: int, value: VariableBase): """Generate code such that top-of-stack (TOS) is set to value.""" diff --git a/tests/test_basic.py b/tests/test_basic.py index 1ca532b..f519984 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -24,12 +24,14 @@ def binary_operator(a, b): def inner_func(x, y): p = paddle.add(x, y) + print("call inner") q = paddle._C_ops.subtract(x, y) # type: ignore z = p * q return z / y -def func(a, b): +def breakraph_func(a, b): + print("call func") d = inner_func(a, b) return d @@ -40,7 +42,7 @@ def test_binary_operator(): check_func(binary_operator, in_a, in_b) -def test_func(): +def test_breakraph_func(): in_a = paddle.rand([8, 8, 16]) in_b = paddle.rand([8, 8, 16]) - check_func(func, in_a, in_b) + check_func(inner_func, in_a, in_b) diff --git a/tests/test_compiler_tvm.py b/tests/test_compiler_tvm.py index b59e172..25caeb8 100644 --- a/tests/test_compiler_tvm.py +++ b/tests/test_compiler_tvm.py @@ -32,5 +32,5 @@ def test_simple_net(): net, in_a, in_b, - backend=TVMCompiler(allow_fallback=False, print_tabular_mode="rich"), + backend=TVMCompiler(full_graph=True, print_tabular_mode="rich"), ) From 29965f0054ab3336451b33f190d20cfb3e624704 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sun, 1 Oct 2023 20:55:40 +0800 Subject: [PATCH 11/20] fix bug in codegen --- examples/simple_dynamo.py | 3 +-- src/paddlefx/codegen.py | 11 ++++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/simple_dynamo.py b/examples/simple_dynamo.py index e4b91fe..6b7c35c 100644 --- a/examples/simple_dynamo.py +++ b/examples/simple_dynamo.py @@ -18,9 +18,8 @@ def inner_func(x, y): p = x + y q = paddle._C_ops.subtract(x, y) # type: ignore - u = paddle._C_ops.subtract(x, y) # type: ignore print(1) - z = p * q * u + z = p * q return z / y diff --git a/src/paddlefx/codegen.py b/src/paddlefx/codegen.py index 4119eb1..3f85bf6 100644 --- a/src/paddlefx/codegen.py +++ b/src/paddlefx/codegen.py @@ -123,11 +123,11 @@ def make_call_generated_code(self, fn_name: str): 
self.extend_output(create_call_function(len(placeholders), False)) def call(self, vars: VariableStack[VariableBase]): - for i, var in enumerate(vars): - # TODO: Maybe self.call_one(i, var)? - self.call_one(0, var) + self.tensor_index = 0 # TDDO: rm this + for var in vars: + self.call_one(var) - def call_one(self, index: int, value: VariableBase): + def call_one(self, value: VariableBase): """Generate code such that top-of-stack (TOS) is set to value.""" output = self.instructions graph_outputs = self.graph_outputs @@ -150,7 +150,8 @@ def call_one(self, index: int, value: VariableBase): output.append(self.create_load(self.graph_output_var)) # TODO: rm hardcode - output.append(self.create_load_const(index)) + output.append(self.create_load_const(self.tensor_index)) + self.tensor_index += 1 output.append(create_instruction("BINARY_SUBSCR")) elif value.var == None: output.append(self.create_load_const(None)) From c5a43502f889f1698b9961657b9bd892203f0ff8 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sun, 1 Oct 2023 20:59:25 +0800 Subject: [PATCH 12/20] fix --- examples/resnet_trace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/resnet_trace.py b/examples/resnet_trace.py index 84ff910..f15ee81 100644 --- a/examples/resnet_trace.py +++ b/examples/resnet_trace.py @@ -12,7 +12,7 @@ net = resnet18() traced_layer = symbolic_trace(net) -example_input = paddle.rand([2, 3, 224, 224]) +example_input = paddle.rand([2, 3, 24, 24]) orig_output = net(example_input) traced_output = traced_layer(example_input) From ef684c94945c9d057198ea6373b4a526d99939ed Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sun, 1 Oct 2023 22:06:28 +0800 Subject: [PATCH 13/20] use 1234 seed --- examples/resnet_trace.py | 2 +- examples/simple_compiler.py | 2 +- examples/targets/target_3_add_paddle.py | 2 +- tests/test_basic.py | 2 +- tests/test_compiler_tvm.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/resnet_trace.py b/examples/resnet_trace.py index f15ee81..4614c84 100644 --- a/examples/resnet_trace.py +++ b/examples/resnet_trace.py @@ -7,7 +7,7 @@ from paddlefx import symbolic_trace -paddle.seed(0) +paddle.seed(1234) net = resnet18() traced_layer = symbolic_trace(net) diff --git a/examples/simple_compiler.py b/examples/simple_compiler.py index 427c3b9..5b039d2 100644 --- a/examples/simple_compiler.py +++ b/examples/simple_compiler.py @@ -9,7 +9,7 @@ from paddlefx.compiler import TVMCompiler -paddle.seed(0) +paddle.seed(1234) # logging.getLogger().setLevel(logging.DEBUG) diff --git a/examples/targets/target_3_add_paddle.py b/examples/targets/target_3_add_paddle.py index e07ba2a..eb7afb8 100644 --- a/examples/targets/target_3_add_paddle.py +++ b/examples/targets/target_3_add_paddle.py @@ -15,7 +15,7 @@ logging.basicConfig(level=logging.DEBUG, format="%(message)s") # logging.basicConfig(level=logging.INFO, format="%(message)s") -paddle.seed(0) +paddle.seed(1234) def func(x, y): diff --git a/tests/test_basic.py b/tests/test_basic.py index f519984..715647d 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -5,7 +5,7 @@ from utils import check_func -paddle.seed(0) +paddle.seed(1234) def binary_operator(a, b): diff --git a/tests/test_compiler_tvm.py b/tests/test_compiler_tvm.py index 25caeb8..0d1986d 100644 --- a/tests/test_compiler_tvm.py +++ b/tests/test_compiler_tvm.py @@ -7,7 +7,7 @@ from paddlefx.compiler.tvm import TVMCompiler -paddle.seed(0) +paddle.seed(1234) class SimpleNet(paddle.nn.Layer): From 
9b05c21fceecdd9aa0895445e3cab263d42f7482 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sun, 1 Oct 2023 22:16:52 +0800 Subject: [PATCH 14/20] fix typos --- src/paddlefx/compiler/base.py | 6 +++--- src/paddlefx/compiler/tvm.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/paddlefx/compiler/base.py b/src/paddlefx/compiler/base.py index d307682..0a8e258 100644 --- a/src/paddlefx/compiler/base.py +++ b/src/paddlefx/compiler/base.py @@ -34,7 +34,7 @@ class CompilerError(Exception): @dataclasses.dataclass -class SybolTable(Generic[T]): +class SymbolTable(Generic[T]): def __init__(self): self._symbol_table: dict[str, T] = {} self._inputs: list[T] = [] @@ -86,7 +86,7 @@ def __call__(self, gl: paddlefx.GraphLayer, example_inputs: list): return self.compile(gl, example_inputs) def compile(self, gl: paddlefx.GraphLayer, example_inputs: list) -> Callable: - symbol_table: SybolTable = SybolTable() + symbol_table: SymbolTable = SymbolTable() example_outputs = gl.forward(*example_inputs) try: for node in gl.graph.nodes: @@ -106,7 +106,7 @@ def compile(self, gl: paddlefx.GraphLayer, example_inputs: list) -> Callable: ) from e def gen_compiled_func( - self, symbol_table: SybolTable, dummy_inputs: list, dummy_outputs: list + self, symbol_table: SymbolTable, dummy_inputs: list, dummy_outputs: list ): raise NotImplementedError("CompilerBase is a abstract class") diff --git a/src/paddlefx/compiler/tvm.py b/src/paddlefx/compiler/tvm.py index a87bcb6..180ba9c 100644 --- a/src/paddlefx/compiler/tvm.py +++ b/src/paddlefx/compiler/tvm.py @@ -12,13 +12,13 @@ if TYPE_CHECKING: from tvm import te - from .base import SybolTable + from .base import SymbolTable class TVMCompiler(CompilerBase): def gen_compiled_func( self, - symbol_table: SybolTable[te.Tensor], + symbol_table: SymbolTable[te.Tensor], dummy_inputs: list, dummy_outputs: list, ): @@ -55,7 +55,7 @@ def compiled_func(*args): return compiled_func def compile_placeholder( - self, node: paddlefx.Node, symbol_table: SybolTable[te.Tensor], inputs: list + self, node: paddlefx.Node, symbol_table: SymbolTable[te.Tensor], inputs: list ): from tvm import te @@ -70,7 +70,7 @@ def compile_placeholder( self.input_index += 1 def compile_call_module( - self, node: paddlefx.Node, symbol_table: SybolTable[te.Tensor], inputs: list + self, node: paddlefx.Node, symbol_table: SymbolTable[te.Tensor], inputs: list ): pass @@ -78,7 +78,7 @@ def compile_call_module( raise CompilerError(f"Unsupported module: {target_name}") def compile_call_function( - self, node: paddlefx.Node, symbol_table: SybolTable[te.Tensor], inputs: list + self, node: paddlefx.Node, symbol_table: SymbolTable[te.Tensor], inputs: list ): from tvm import topi @@ -103,6 +103,6 @@ def compile_call_function( raise NotImplementedError(f"Unsupported function: {target_name}") def compile_output( - self, node: paddlefx.Node, symbol_table: SybolTable[te.Tensor], inputs: list + self, node: paddlefx.Node, symbol_table: SymbolTable[te.Tensor], inputs: list ): symbol_table.outputs = tuple(symbol_table[str(arg)] for arg in node.args[0]) From 6b79fb8b7ba9c199128ff16a84f4cd5a7bfd0954 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sun, 1 Oct 2023 23:03:21 +0800 Subject: [PATCH 15/20] fix --- examples/simple_compiler.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/simple_compiler.py b/examples/simple_compiler.py index 5b039d2..dc56d3f 100644 --- a/examples/simple_compiler.py +++ b/examples/simple_compiler.py @@ -23,9 +23,6 @@ def 
inner_func(x, y): def func(a, b): d = inner_func(a, b) - d = inner_func(a, d) - d = inner_func(d, a) - d = inner_func(a, d) return d From 904ae8f87ecc80ccc8efbb3e46ecd87cc143d8db Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Mon, 2 Oct 2023 15:09:27 +0800 Subject: [PATCH 16/20] rm dummy_outputs --- src/paddlefx/compiler/base.py | 5 ++--- src/paddlefx/compiler/tvm.py | 9 ++------- tests/test_basic.py | 20 +++++++++++++++++--- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/paddlefx/compiler/base.py b/src/paddlefx/compiler/base.py index 0a8e258..530e9fd 100644 --- a/src/paddlefx/compiler/base.py +++ b/src/paddlefx/compiler/base.py @@ -87,11 +87,10 @@ def __call__(self, gl: paddlefx.GraphLayer, example_inputs: list): def compile(self, gl: paddlefx.GraphLayer, example_inputs: list) -> Callable: symbol_table: SymbolTable = SymbolTable() - example_outputs = gl.forward(*example_inputs) try: for node in gl.graph.nodes: getattr(self, f"compile_{node.op}")(node, symbol_table, example_inputs) - return self.gen_compiled_func(symbol_table, example_inputs, example_outputs) + return self.gen_compiled_func(gl, symbol_table, example_inputs) except CompilerError as e: if self.allow_fallback: print( @@ -106,7 +105,7 @@ def compile(self, gl: paddlefx.GraphLayer, example_inputs: list) -> Callable: ) from e def gen_compiled_func( - self, symbol_table: SymbolTable, dummy_inputs: list, dummy_outputs: list + self, gl: paddlefx.GraphLayer, symbol_table: SymbolTable, dummy_inputs: list ): raise NotImplementedError("CompilerBase is a abstract class") diff --git a/src/paddlefx/compiler/tvm.py b/src/paddlefx/compiler/tvm.py index 180ba9c..ff81ce5 100644 --- a/src/paddlefx/compiler/tvm.py +++ b/src/paddlefx/compiler/tvm.py @@ -18,9 +18,9 @@ class TVMCompiler(CompilerBase): def gen_compiled_func( self, + gl: paddlefx.GraphLayer, symbol_table: SymbolTable[te.Tensor], dummy_inputs: list, - dummy_outputs: list, ): import tvm @@ -43,15 +43,12 @@ def gen_compiled_func( def compiled_func(*args): inputs = [tvm.nd.array(arg.numpy()) for arg in args] - outputs = [ - tvm.nd.empty(out.shape, paddle_dtype_to_str(out.dtype)) - for out in dummy_outputs + tvm.nd.empty(out.shape, out.dtype) for out in symbol_table.outputs ] tvm_func(*inputs, *outputs) return tuple(paddle.to_tensor(out.asnumpy()) for out in outputs) - compiled_func(*dummy_inputs) return compiled_func def compile_placeholder( @@ -72,8 +69,6 @@ def compile_placeholder( def compile_call_module( self, node: paddlefx.Node, symbol_table: SymbolTable[te.Tensor], inputs: list ): - pass - target_name = node.target raise CompilerError(f"Unsupported module: {target_name}") diff --git a/tests/test_basic.py b/tests/test_basic.py index 715647d..b70ec90 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -5,8 +5,6 @@ from utils import check_func -paddle.seed(1234) - def binary_operator(a, b): c = a / b @@ -22,6 +20,16 @@ def binary_operator(a, b): return l +def binary_inplace_operator(a, b): + a -= b + a += b + a *= b + a /= b + a **= b + a @= b.T + return a + + def inner_func(x, y): p = paddle.add(x, y) print("call inner") @@ -42,7 +50,13 @@ def test_binary_operator(): check_func(binary_operator, in_a, in_b) +def test_binary_inplace_operator(): + in_a = paddle.rand([1, 24]) + in_b = paddle.rand([8, 24]) + check_func(binary_inplace_operator, in_a, in_b) + + def test_breakraph_func(): - in_a = paddle.rand([8, 8, 16]) + in_a = paddle.rand([8, 8, 1]) in_b = paddle.rand([8, 8, 16]) check_func(inner_func, in_a, in_b) From 
843c36971abc778cf5f58ca815691c7b43fd7c05 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Mon, 2 Oct 2023 15:11:10 +0800 Subject: [PATCH 17/20] rm dymmy and gl --- src/paddlefx/compiler/base.py | 6 ++---- src/paddlefx/compiler/tvm.py | 7 +------ 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/paddlefx/compiler/base.py b/src/paddlefx/compiler/base.py index 530e9fd..6cc8b90 100644 --- a/src/paddlefx/compiler/base.py +++ b/src/paddlefx/compiler/base.py @@ -90,7 +90,7 @@ def compile(self, gl: paddlefx.GraphLayer, example_inputs: list) -> Callable: try: for node in gl.graph.nodes: getattr(self, f"compile_{node.op}")(node, symbol_table, example_inputs) - return self.gen_compiled_func(gl, symbol_table, example_inputs) + return self.gen_compiled_func(symbol_table) except CompilerError as e: if self.allow_fallback: print( @@ -104,9 +104,7 @@ def compile(self, gl: paddlefx.GraphLayer, example_inputs: list) -> Callable: f"AttributeError when compiling graph, check if you use abstract class" ) from e - def gen_compiled_func( - self, gl: paddlefx.GraphLayer, symbol_table: SymbolTable, dummy_inputs: list - ): + def gen_compiled_func(self, symbol_table: SymbolTable): raise NotImplementedError("CompilerBase is a abstract class") diff --git a/src/paddlefx/compiler/tvm.py b/src/paddlefx/compiler/tvm.py index ff81ce5..cc24e19 100644 --- a/src/paddlefx/compiler/tvm.py +++ b/src/paddlefx/compiler/tvm.py @@ -16,12 +16,7 @@ class TVMCompiler(CompilerBase): - def gen_compiled_func( - self, - gl: paddlefx.GraphLayer, - symbol_table: SymbolTable[te.Tensor], - dummy_inputs: list, - ): + def gen_compiled_func(self, symbol_table: SymbolTable[te.Tensor]): import tvm from tvm import te From 45995ed85a6ce5ca4d32900b10cbe1b0fc6b51c7 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Mon, 2 Oct 2023 19:11:31 +0800 Subject: [PATCH 18/20] support call module(Linear Conv) --- examples/simple_dynamo.py | 56 +++++----------- requirements_dev.txt | 2 +- src/paddlefx/compiler/base.py | 38 +++++++---- src/paddlefx/compiler/tvm.py | 116 +++++++++++++++++++++++++++++++--- tests/test_compiler_tvm.py | 12 ++-- tests/utils.py | 6 +- 6 files changed, 160 insertions(+), 70 deletions(-) diff --git a/examples/simple_dynamo.py b/examples/simple_dynamo.py index 6b7c35c..3937c21 100644 --- a/examples/simple_dynamo.py +++ b/examples/simple_dynamo.py @@ -15,22 +15,6 @@ compiler = TVMCompiler(full_graph=True, print_tabular_mode="rich") -def inner_func(x, y): - p = x + y - q = paddle._C_ops.subtract(x, y) # type: ignore - print(1) - z = p * q - return z / y - - -def breakraph_func(a, b): - d = inner_func(a, b) - d = inner_func(a, b) - # print("call func") - q = inner_func(a, b) - return d, q - - def check_func(func, *args, backend: None = None): if backend is None: comiled_func = paddlefx.optimize(func) @@ -42,35 +26,25 @@ def check_func(func, *args, backend: None = None): for i in range(len(res)): np.testing.assert_allclose(res[i], out[i]) else: - np.testing.assert_allclose(res, out) - - -in_a = paddle.rand([8, 8, 16]) -in_b = paddle.rand([8, 1, 16]) -check_func(inner_func, in_a, in_b, backend=compiler) - + np.testing.assert_allclose(res, out, rtol=1e-5, atol=1e-6) -# dtype = 'float32' -# in_a = paddle.to_tensor([1, 2], dtype=dtype) -# in_b = paddle.to_tensor([0, 1], dtype=dtype) +class SimpleNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.fc1 = paddle.nn.Linear(16, 4) + self.fc2 = paddle.nn.Linear(16, 1) -# def inplace(a, b): -# # print('\tcall inplace') -# a -= 
b -# a += b -# a *= b -# a /= b -# a **= b -# a @= b -# return a + def forward(self, a, b): + c = self.fc1(a) + d = self.fc2(b) + e = paddle.add(c, d) + return e -# optimized_foo = paddlefx.optimize(backend=compiler)( -# inplace -# ) +net = SimpleNet() -# original_res = inplace(in_a, in_b) -# optimized_res = optimized_foo(in_a, in_b) -# np.testing.assert_equal(original_res.numpy(), optimized_res.numpy()) +in_a = paddle.rand([8, 16]) +in_b = paddle.rand([8, 16]) +check_func(net, in_a, in_b, backend=compiler) diff --git a/requirements_dev.txt b/requirements_dev.txt index 68b641f..4ef1b5e 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -13,7 +13,7 @@ pre-commit>=3.0.0 tabulate==0.9.0 -apache-tvm>=0.11.1 +apache-tvm==0.14.dev214 # debug python with paddle opencv-python-headless diff --git a/src/paddlefx/compiler/base.py b/src/paddlefx/compiler/base.py index 6cc8b90..46cf671 100644 --- a/src/paddlefx/compiler/base.py +++ b/src/paddlefx/compiler/base.py @@ -9,7 +9,8 @@ import paddlefx -T = TypeVar("T") +PlaceholderT = TypeVar("PlaceholderT") +ValueT = TypeVar("ValueT") def paddle_dtype_to_str(dtype: paddle.dtype) -> str: @@ -34,35 +35,44 @@ class CompilerError(Exception): @dataclasses.dataclass -class SymbolTable(Generic[T]): +class SymbolTable(Generic[PlaceholderT, ValueT]): def __init__(self): - self._symbol_table: dict[str, T] = {} - self._inputs: list[T] = [] - self._outputs: tuple[T, ...] = () + self._symbol_table: dict[str, PlaceholderT] = {} + self._inputs: list[PlaceholderT] = [] + self._weights: dict[str, tuple[PlaceholderT, ValueT]] = {} + self._outputs: tuple[PlaceholderT, ...] = () - def __getitem__(self, key: str) -> T: + def __getitem__(self, key: str) -> PlaceholderT: return self._symbol_table[key] - def __setitem__(self, key: str, value: T): + def __setitem__(self, key: str, value: PlaceholderT): self._symbol_table[key] = value def __iter__(self): return iter(self._symbol_table.items()) @property - def inputs(self) -> tuple[T, ...]: - return tuple(self._inputs) + def inputs(self) -> tuple[PlaceholderT, ...]: + return tuple(self._inputs + [value[0] for value in self._weights.values()]) - def add_input(self, key: str, value: T): + def add_input(self, key: str, value: PlaceholderT): self._inputs.append(value) self._symbol_table[key] = value @property - def outputs(self) -> tuple[T, ...]: + def weights(self) -> tuple[ValueT, ...]: + return tuple(value[1] for value in self._weights.values()) + + def add_weight(self, key: str, value: tuple[PlaceholderT, ValueT]): + self._weights[key] = value + self._symbol_table[key] = value[0] + + @property + def outputs(self) -> tuple[PlaceholderT, ...]: return self._outputs @outputs.setter - def outputs(self, value: tuple[T, ...]): + def outputs(self, value: tuple[PlaceholderT, ...]): self._outputs = value @@ -89,7 +99,9 @@ def compile(self, gl: paddlefx.GraphLayer, example_inputs: list) -> Callable: symbol_table: SymbolTable = SymbolTable() try: for node in gl.graph.nodes: - getattr(self, f"compile_{node.op}")(node, symbol_table, example_inputs) + getattr(self, f"compile_{node.op}")( + gl, node, symbol_table, example_inputs + ) return self.gen_compiled_func(symbol_table) except CompilerError as e: if self.allow_fallback: diff --git a/src/paddlefx/compiler/tvm.py b/src/paddlefx/compiler/tvm.py index cc24e19..28ed8eb 100644 --- a/src/paddlefx/compiler/tvm.py +++ b/src/paddlefx/compiler/tvm.py @@ -5,6 +5,8 @@ import paddle import paddle.device +from paddle import nn + import paddlefx from .base import CompilerBase, CompilerError, 
paddle_dtype_to_str @@ -16,7 +18,7 @@ class TVMCompiler(CompilerBase): - def gen_compiled_func(self, symbol_table: SymbolTable[te.Tensor]): + def gen_compiled_func(self, symbol_table: SymbolTable[te.Tensor, paddle.Tensor]): import tvm from tvm import te @@ -36,18 +38,25 @@ def gen_compiled_func(self, symbol_table: SymbolTable[te.Tensor]): name="tvm_func", ) + weights = [tvm.nd.array(p.numpy()) for p in symbol_table.weights] + def compiled_func(*args): inputs = [tvm.nd.array(arg.numpy()) for arg in args] outputs = [ tvm.nd.empty(out.shape, out.dtype) for out in symbol_table.outputs ] - tvm_func(*inputs, *outputs) + tvm_func(*inputs, *weights, *outputs) + len(outputs) return tuple(paddle.to_tensor(out.asnumpy()) for out in outputs) return compiled_func def compile_placeholder( - self, node: paddlefx.Node, symbol_table: SymbolTable[te.Tensor], inputs: list + self, + gl: paddlefx.GraphLayer, + node: paddlefx.Node, + symbol_table: SymbolTable[te.Tensor, paddle.Tensor], + inputs: list, ): from tvm import te @@ -62,13 +71,100 @@ def compile_placeholder( self.input_index += 1 def compile_call_module( - self, node: paddlefx.Node, symbol_table: SymbolTable[te.Tensor], inputs: list + self, + gl: paddlefx.GraphLayer, + node: paddlefx.Node, + symbol_table: SymbolTable[te.Tensor, paddle.Tensor], + inputs: list, ): - target_name = node.target - raise CompilerError(f"Unsupported module: {target_name}") + from tvm import te, topi + + module = gl + names = node.target.split(".") + while len(names) > 0: + module = getattr(module, names.pop(0)) + if isinstance(module, nn.Linear): + # TODO: pre-load weight and bias + symbol_table.add_weight( + f"{node.name}_weight", + ( + te.placeholder( + module.weight.T.shape, + paddle_dtype_to_str(module.weight.dtype), + name=f"params_{node.name}_weight", + ), + module.weight.T, + ), + ) + symbol_table.add_weight( + f"{node.name}_bias", + ( + te.placeholder( + module.bias.shape, + paddle_dtype_to_str(module.bias.dtype), + name=f"params_{node.name}_bias", + ), + module.bias, + ), + ) + symbol_table[node.name] = topi.nn.dense( # type: ignore + symbol_table[str(node.args[0])], + symbol_table[f"{node.name}_weight"], + symbol_table[f"{node.name}_bias"], + ) + elif isinstance(module, nn.Conv2D): + symbol_table.add_weight( + f"{node.name}_weight", + ( + te.placeholder( + module.weight.shape, + paddle_dtype_to_str(module.weight.dtype), + name=f"params_{node.name}_weight", + ), + module.weight, + ), + ) + + if module.bias is not None: + bias = module.bias.reshape((1, -1, 1, 1)) + symbol_table.add_weight( + f"{node.name}_bias", + ( + te.placeholder( + bias.shape, + paddle_dtype_to_str(bias.dtype), + name=f"params_{node.name}_bias", + ), + bias, + ), + ) + symbol_table[node.name] = topi.add( + topi.nn.conv2d( + symbol_table[str(node.args[0])], + symbol_table[f"{node.name}_weight"], + module._stride, + module._updated_padding, + module._dilation, + ), + symbol_table[f"{node.name}_bias"], + ) + else: + symbol_table[node.name] = topi.nn.conv2d( # type: ignore + symbol_table[str(node.args[0])], + symbol_table[f"{node.name}_weight"], + module._stride, + module._updated_padding, + module._dilation, + ) + else: + raise CompilerError(f"Unsupported module: {module.__class__.__name__}") def compile_call_function( - self, node: paddlefx.Node, symbol_table: SymbolTable[te.Tensor], inputs: list + self, + gl: paddlefx.GraphLayer, + node: paddlefx.Node, + symbol_table: SymbolTable[te.Tensor, paddle.Tensor], + inputs: list, ): from tvm import topi @@ -93,6 +189,10 @@ def compile_call_function( raise 
NotImplementedError(f"Unsupported function: {target_name}") def compile_output( - self, node: paddlefx.Node, symbol_table: SymbolTable[te.Tensor], inputs: list + self, + gl: paddlefx.GraphLayer, + node: paddlefx.Node, + symbol_table: SymbolTable[te.Tensor, paddle.Tensor], + inputs: list, ): symbol_table.outputs = tuple(symbol_table[str(arg)] for arg in node.args[0]) diff --git a/tests/test_compiler_tvm.py b/tests/test_compiler_tvm.py index 0d1986d..4d71ca1 100644 --- a/tests/test_compiler_tvm.py +++ b/tests/test_compiler_tvm.py @@ -13,13 +13,13 @@ class SimpleNet(paddle.nn.Layer): def __init__(self): super().__init__() - self.fc = [paddle.nn.Linear(16, 1), paddle.nn.Linear(16, 4)] + self.fc1 = paddle.nn.Linear(16, 4) + self.conv1 = paddle.nn.Conv2D(1, 4, 3, 1) def forward(self, a, b): - c = self.fc[0](a) - d = self.fc[1](b) - e = paddle.add(c, d) - return e + c = self.fc1(a) + d = self.conv1(b) + return c, d net = SimpleNet() @@ -27,7 +27,7 @@ def forward(self, a, b): def test_simple_net(): in_a = paddle.rand([8, 16]) - in_b = paddle.rand([1, 16]) + in_b = paddle.rand([8, 1, 4, 4]) check_func( net, in_a, diff --git a/tests/utils.py b/tests/utils.py index a3ab5a5..356b7e9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,4 +14,8 @@ def check_func(func, *args, backend: Callable | None = None): comiled_func = paddlefx.optimize(func, backend=backend) out = func(*args) res = comiled_func(*args) - np.testing.assert_allclose(res, out) + if isinstance(out, tuple): + for i in range(len(res)): + np.testing.assert_allclose(res[i], out[i], rtol=1e-5, atol=1e-6) + else: + np.testing.assert_allclose(res, out, rtol=1e-5, atol=1e-6) From 93d9ce72e1a84651dafdfae16dd1c5bce4c884de Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Mon, 2 Oct 2023 20:48:03 +0800 Subject: [PATCH 19/20] mv CodeCacheManager --- src/paddlefx/cache_manager.py | 44 ++++++++++++++++++++++++++++++++ src/paddlefx/convert_frame.py | 5 ++-- src/paddlefx/pyeval.py | 47 +---------------------------------- 3 files changed, 47 insertions(+), 49 deletions(-) diff --git a/src/paddlefx/cache_manager.py b/src/paddlefx/cache_manager.py index a0f36fa..3b7818b 100644 --- a/src/paddlefx/cache_manager.py +++ b/src/paddlefx/cache_manager.py @@ -14,3 +14,47 @@ class GuardedCode: code: types.CodeType guard_fn: GuardFunction + + +class CodeCacheManager: + cache_dict: dict[types.CodeType, GuardedCodes] = {} + + @classmethod + def add_cache(cls, code: types.CodeType, guarded_code: GuardedCode): + cls.cache_dict.setdefault(code, []) + cls.cache_dict[code].append(guarded_code) + + @classmethod + def get_cache(cls, frame: types.FrameType) -> GuardedCode | None: + code: types.CodeType = frame.f_code + if code not in cls.cache_dict: + print(f"Firstly call {code}\n") + return None + return cls.lookup(frame, cls.cache_dict[code]) + + @classmethod + def clear_cache(cls): + cls.cache_dict.clear() + + @classmethod + def lookup( + cls, frame: types.FrameType, guarded_codes: GuardedCodes + ) -> GuardedCode | None: + for guarded_code in guarded_codes: + try: + guard_fn = guarded_code.guard_fn + if guard_fn(frame): + print( + f"[Cache]: Cache hit, GuardFunction is {guard_fn}\n", + ) + return guarded_code + else: + print( + f"[Cache]: Cache miss, GuardFunction is {guard_fn}\n", + ) + except Exception as e: + print(f"[Cache]: GuardFunction function error: {e}\n") + continue + + print("[Cache]: all guards missed\n") + return None diff --git a/src/paddlefx/convert_frame.py b/src/paddlefx/convert_frame.py index 2d4a642..ef3039f 100644 --- 
a/src/paddlefx/convert_frame.py +++ b/src/paddlefx/convert_frame.py @@ -6,9 +6,9 @@ from typing import TYPE_CHECKING, Callable from .bytecode_transformation import Instruction, transform_code_object -from .cache_manager import GuardedCode +from .cache_manager import CodeCacheManager, GuardedCode from .paddle_utils import Tensor, skip_paddle_filename, skip_paddle_frame -from .pyeval import CodeCacheManager, PyEval +from .pyeval import PyEval from .utils import log_bytecode, log_code if TYPE_CHECKING: @@ -33,7 +33,6 @@ def convert_frame(frame: types.FrameType, compiler_fn: Callable) -> GuardedCode if skip_frame(frame): logging.debug(f"skip_frame: {frame}") return None - # TODO: guard_fn is not declared in this scope guard_fn = None diff --git a/src/paddlefx/pyeval.py b/src/paddlefx/pyeval.py index 3b624e1..a173923 100644 --- a/src/paddlefx/pyeval.py +++ b/src/paddlefx/pyeval.py @@ -21,7 +21,6 @@ transform_code_object, unique_id, ) -from .cache_manager import GuardedCode from .codegen import PyCodegen from .output_graph import OutputGraph from .paddle_utils import TensorType @@ -34,51 +33,7 @@ if TYPE_CHECKING: # import opcode - from .cache_manager import GuardedCodes - - -class CodeCacheManager: - cache_dict: dict[types.CodeType, GuardedCodes] = {} - - @classmethod - def add_cache(cls, code: types.CodeType, guarded_code: GuardedCode): - cls.cache_dict.setdefault(code, []) - cls.cache_dict[code].append(guarded_code) - - @classmethod - def get_cache(cls, frame: types.FrameType) -> GuardedCode | None: - code: types.CodeType = frame.f_code - if code not in cls.cache_dict: - print(f"Firstly call {code}\n") - return None - return cls.lookup(frame, cls.cache_dict[code]) - - @classmethod - def clear_cache(cls): - cls.cache_dict.clear() - - @classmethod - def lookup( - cls, frame: types.FrameType, guarded_codes: GuardedCodes - ) -> GuardedCode | None: - for guarded_code in guarded_codes: - try: - guard_fn = guarded_code.guard_fn - if guard_fn(frame): - print( - f"[Cache]: Cache hit, GuardFunction is {guard_fn}\n", - ) - return guarded_code - else: - print( - f"[Cache]: Cache miss, GuardFunction is {guard_fn}\n", - ) - except Exception as e: - print(f"[Cache]: GuardFunction function error: {e}\n") - continue - - print("[Cache]: all guards missed\n") - return None + pass def tos_op_wrapper(fn: Callable): From f21b6dc5e82447b810dba86406622a576aa51e7f Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Mon, 2 Oct 2023 23:00:14 +0800 Subject: [PATCH 20/20] improve --- examples/TODO/resnet_dynamo.py | 15 +++- requirements_dev.txt | 1 + src/paddlefx/compiler/tvm.py | 129 +++++++++++++++++++++++++++++++-- 3 files changed, 137 insertions(+), 8 deletions(-) diff --git a/examples/TODO/resnet_dynamo.py b/examples/TODO/resnet_dynamo.py index 0f7b051..8c7b448 100644 --- a/examples/TODO/resnet_dynamo.py +++ b/examples/TODO/resnet_dynamo.py @@ -11,12 +11,21 @@ from paddlefx.compiler.tvm import TVMCompiler +paddle.seed(1234) + +# logging.getLogger().setLevel(logging.DEBUG) + compiler = TVMCompiler(full_graph=True, print_tabular_mode="rich") -net = resnet18() +net = resnet18(pretrained=True, num_classes=2) optimized_net = paddlefx.optimize(net, backend=compiler) -x = paddle.rand([1, 3, 224, 224]) +x = paddle.rand([1, 3, 224, 224], dtype="float32") out = net(x) res = optimized_net(x) +np.testing.assert_allclose(res.numpy(), out.numpy(), rtol=1e-5, atol=1e-6) + +for _ in range(10): + out = net(x) -np.testing.assert_equal(res.numpy(), out.numpy()) +for _ in range(10): + res = optimized_net(x) diff --git 
a/requirements_dev.txt b/requirements_dev.txt index 4ef1b5e..2119fcf 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -14,6 +14,7 @@ pre-commit>=3.0.0 tabulate==0.9.0 apache-tvm==0.14.dev214 +xgboost # debug python with paddle opencv-python-headless diff --git a/src/paddlefx/compiler/tvm.py b/src/paddlefx/compiler/tvm.py index 28ed8eb..1e58383 100644 --- a/src/paddlefx/compiler/tvm.py +++ b/src/paddlefx/compiler/tvm.py @@ -17,6 +17,24 @@ from .base import SymbolTable +def auto_scheduler(symbol_table, target): + log_file = f"{hash(tuple(out.name for out in symbol_table.outputs))}.json" + task = auto_scheduler.SearchTask( + func=auto_scheduler.register_workload(lambda: symbol_table.outputs), + target=target, + ) + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=10, # change this to 1000 to achieve the best performance + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + # verbose=2, + ) + task.tune(tune_option) + # Apply the best schedule + schedule, args = task.apply_best(log_file) + print("Lowered TIR:") + print(tvm.lower(schedule, args, simple_mode=True)) + + class TVMCompiler(CompilerBase): def gen_compiled_func(self, symbol_table: SymbolTable[te.Tensor, paddle.Tensor]): import tvm @@ -24,26 +42,35 @@ def gen_compiled_func(self, symbol_table: SymbolTable[te.Tensor, paddle.Tensor]) from tvm import te device = paddle.device.get_device() + # device = "gpu" if device == "cpu": - target = tvm.target.Target(target="llvm", host="llvm") + target = tvm.target.Target(target="llvm") + dev = tvm.cpu() elif device == "gpu": target = tvm.target.Target(target="cuda", host="llvm") + dev = tvm.cuda() else: raise CompilerError(f"Unsupported device in tvm backend: {device}") + target = tvm.target.Target(target="llvm -mtriple=x86_64-linux-gnu") + schedule = te.create_schedule([out.op for out in symbol_table.outputs]) + + print("building tvm func") tvm_func = tvm.build( schedule, [*symbol_table.inputs, *symbol_table.outputs], target, name="tvm_func", ) + print("builded tvm func") - weights = [tvm.nd.array(p.numpy()) for p in symbol_table.weights] + weights = [tvm.nd.array(p.numpy(), device=dev) for p in symbol_table.weights] def compiled_func(*args): - inputs = [tvm.nd.array(arg.numpy()) for arg in args] + inputs = [tvm.nd.array(arg.numpy(), device=dev) for arg in args] outputs = [ - tvm.nd.empty(out.shape, out.dtype) for out in symbol_table.outputs + tvm.nd.empty(out.shape, out.dtype, device=dev) + for out in symbol_table.outputs ] tvm_func(*inputs, *weights, *outputs) len(outputs) @@ -145,6 +172,7 @@ def compile_call_module( module._stride, module._updated_padding, module._dilation, + module._data_format, ), symbol_table[f"{node.name}_bias"], ) @@ -156,6 +184,86 @@ def compile_call_module( module._updated_padding, module._dilation, ) + elif isinstance(module, nn.BatchNorm2D): + symbol_table.add_weight( + f"{node.name}_weight", + ( + te.placeholder( + module.weight.shape, + paddle_dtype_to_str(module.weight.dtype), + name=f"params_{node.name}_weight", + ), + module.weight, + ), + ) + symbol_table.add_weight( + f"{node.name}_bias", + ( + te.placeholder( + module.bias.shape, + paddle_dtype_to_str(module.bias.dtype), + name=f"params_{node.name}_bias", + ), + module.bias, + ), + ) + symbol_table.add_weight( + f"{node.name}_mean", + ( + te.placeholder( + module._mean.shape, + paddle_dtype_to_str(module._mean.dtype), + name=f"params_{node.name}_mean", + ), + module._mean, + ), + ) + symbol_table.add_weight( + f"{node.name}_variance", + ( + te.placeholder( + 
module._variance.shape, + paddle_dtype_to_str(module._variance.dtype), + name=f"params_{node.name}_variance", + ), + module._variance, + ), + ) + symbol_table[node.name] = topi.nn.batch_norm( + symbol_table[str(node.args[0])], + symbol_table[f"{node.name}_weight"], + symbol_table[f"{node.name}_bias"], + symbol_table[f"{node.name}_mean"], + symbol_table[f"{node.name}_variance"], + epsilon=module._epsilon, + training=module.training, + )[0] + elif isinstance(module, nn.ReLU): + symbol_table[node.name] = topi.nn.relu(symbol_table[str(node.args[0])]) # type: ignore + elif isinstance(module, nn.MaxPool2D): + symbol_table[node.name] = topi.nn.pool2d( + symbol_table[str(node.args[0])], + [module.ksize, module.ksize] + if isinstance(module.ksize, int) + else module.ksize, + [module.stride, module.stride] + if isinstance(module.stride, int) + else module.stride, + [1, 1], + [module.padding, module.padding, module.padding, module.padding] + if isinstance(module.padding, int) + else module.padding, + "max", + module.ceil_mode, + ) + elif isinstance(module, nn.AdaptiveAvgPool2D): + symbol_table[node.name] = topi.nn.adaptive_pool( + symbol_table[str(node.args[0])], + module._output_size, + "avg", + ) + elif isinstance(module, nn.AvgPool2D): + pass else: raise CompilerError(f"Unsupported module: {module.__class__.__name__}") @@ -172,6 +280,7 @@ def compile_call_function( map_ops_to_tvm = { "add": topi.add, + "iadd": topi.add, "sub": topi.subtract, "subtract": topi.subtract, "mul": topi.multiply, @@ -185,8 +294,18 @@ def compile_call_function( if target_name in map_ops_to_tvm.keys(): symbol_args = [symbol_table[str(arg)] for arg in node.args] symbol_table[node.name] = map_ops_to_tvm[target_name](*symbol_args) + elif target_name == "flatten": + inp = symbol_table[str(node.args[0])] + batch = inp.shape[0] + from functools import reduce + + shape = [batch, reduce(lambda x, y: x * y, inp.shape[1:])] + symbol_table[node.name] = topi.reshape( + inp, + shape, + ) else: - raise NotImplementedError(f"Unsupported function: {target_name}") + raise CompilerError(f"Unsupported function: {target_name}") def compile_output( self,
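A note on the Linear lowering above: compile_call_module registers module.weight.T rather than module.weight because paddle.nn.Linear stores its weight as [in_features, out_features], while topi.nn.dense takes a weight of shape [out_features, in_features] and computes data @ weight.T + bias. A quick numpy check of that convention (a sketch only, assuming the dense semantics just described):

import numpy as np
import paddle

fc = paddle.nn.Linear(16, 4)
x = paddle.rand([8, 16])

w_t = fc.weight.T.numpy()          # what the compiler hands to topi.nn.dense
ref = fc(x).numpy()                # paddle computes x @ W + b
dense_like = x.numpy() @ w_t.T + fc.bias.numpy()
np.testing.assert_allclose(ref, dense_like, rtol=1e-5, atol=1e-6)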
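The reworked SymbolTable keeps runtime inputs and weight placeholders in separate containers but exposes them through a single inputs tuple, which is what makes the later tvm_func(*inputs, *weights, *outputs) call line up with the argument list passed to tvm.build. A small sketch of that ordering, with plain strings standing in for te.Tensor and paddle.Tensor:

from paddlefx.compiler.base import SymbolTable

st: SymbolTable[str, str] = SymbolTable()
st.add_input("x", "te_x")
st.add_input("y", "te_y")
st.add_weight("fc1_weight", ("te_fc1_w", "paddle_fc1_w"))
st["out"] = "te_out"
st.outputs = ("te_out",)

assert st.inputs == ("te_x", "te_y", "te_fc1_w")   # weight placeholders follow inputs
assert st.weights == ("paddle_fc1_w",)             # concrete values, for tvm.nd.array
assert st["fc1_weight"] == "te_fc1_w"              # lookup returns the placeholder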
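gen_compiled_func lowers the recorded te outputs with te.create_schedule, builds them with tvm.build, and shuttles data through tvm.nd arrays on the chosen device. A standalone sketch of that same build/run pattern, runnable on CPU with the apache-tvm version pinned in requirements_dev.txt (the names here are illustrative, not part of paddlefx):

import numpy as np
import tvm
from tvm import te, topi

x = te.placeholder((4, 6), "float32", name="x")
y = te.placeholder((4, 6), "float32", name="y")
out = topi.add(x, y)

schedule = te.create_schedule([out.op])
func = tvm.build(schedule, [x, y, out], target="llvm", name="tvm_func")

dev = tvm.cpu()
a = tvm.nd.array(np.random.rand(4, 6).astype("float32"), device=dev)
b = tvm.nd.array(np.random.rand(4, 6).astype("float32"), device=dev)
c = tvm.nd.empty((4, 6), "float32", device=dev)
func(a, b, c)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy(), rtol=1e-5)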
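GuardedCode and CodeCacheManager, moved into cache_manager.py earlier in this series, pair compiled bytecode with a guard function that decides per call whether the cached code is still valid for the incoming frame. A minimal usage sketch of that lookup flow; make_shape_guard is illustrative only, since the real guard_fn is generated from the traced input variables:

import inspect
import types

import paddle

from paddlefx.cache_manager import CodeCacheManager, GuardedCode


def make_shape_guard(name: str, shape: list[int]):
    # Hit the cache only while the named local keeps this exact shape.
    def guard_fn(frame: types.FrameType) -> bool:
        value = frame.f_locals.get(name)
        return getattr(value, "shape", None) == shape

    return guard_fn


def f(x):
    frame = inspect.currentframe()
    if CodeCacheManager.get_cache(frame) is None:
        # First call, or every guard missed: cache this frame's code.
        CodeCacheManager.add_cache(
            frame.f_code,
            GuardedCode(frame.f_code, make_shape_guard("x", list(x.shape))),
        )
    return x + 1


f(paddle.rand([4, 6]))  # "Firstly call ...", then cached
f(paddle.rand([4, 6]))  # same shape -> guard hits, cached entry returned
f(paddle.rand([2, 2]))  # new shape -> all guards miss, a second entry is cached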