From 3fcab2cab6ba921e762764563a8f99118abd133e Mon Sep 17 00:00:00 2001 From: xiongkun Date: Thu, 6 Jul 2023 16:15:33 +0800 Subject: [PATCH] add psdb_print and fix some bugs in detection model. (#250) --- .github/workflows/codestyle.yaml | 1 - .github/workflows/unittest.yaml | 3 -- sot/__init__.py | 2 + sot/opcode_translator/breakpoint.py | 24 +++++----- .../executor/dispatch_functions.py | 32 ++++++++++++++ .../executor/function_graph.py | 19 ++++++++ .../executor/opcode_executor.py | 18 +++++--- .../executor/variable_dispatch.py | 44 +++++++++++-------- .../executor/variables/basic.py | 34 +++++++++++++- .../executor/variables/callable.py | 11 ++++- sot/symbolic/compile_cache.py | 3 ++ sot/utils/__init__.py | 2 + sot/utils/paddle_api_config.py | 1 + sot/utils/utils.py | 7 ++- tests/test_segment_linear.py | 4 +- 15 files changed, 160 insertions(+), 45 deletions(-) create mode 100644 sot/opcode_translator/executor/dispatch_functions.py diff --git a/.github/workflows/codestyle.yaml b/.github/workflows/codestyle.yaml index 6a98747bc..306fc6f00 100644 --- a/.github/workflows/codestyle.yaml +++ b/.github/workflows/codestyle.yaml @@ -23,6 +23,5 @@ jobs: - name: Install dependencies run: | pip install pre-commit - - name: Precommit Check run : pre-commit run --all-files diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml index d30db9485..bfd74b5ac 100644 --- a/.github/workflows/unittest.yaml +++ b/.github/workflows/unittest.yaml @@ -28,17 +28,14 @@ jobs: run: | python -m pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html pip install -e ".[dev]" - - name: Run unit tests working-directory: ./tests/ run: | bash run_all.sh - - name: Run examples working-directory: ./examples/ run: | bash run_all.sh - - name: Run xdoctest working-directory: ./ run: | diff --git a/sot/__init__.py b/sot/__init__.py index 49c814b5b..2bd54d405 100644 --- a/sot/__init__.py +++ b/sot/__init__.py @@ -1,6 +1,7 @@ from .opcode_translator.breakpoint import BM, add_breakpoint, add_event from .opcode_translator.skip_files import skip_function from .translate import symbolic_translate +from .utils import psdb_print __all__ = [ "symbolic_translate", @@ -8,4 +9,5 @@ "add_event", "BM", "skip_function", + "psdb_print", ] diff --git a/sot/opcode_translator/breakpoint.py b/sot/opcode_translator/breakpoint.py index ebc25d002..f353f5dd8 100644 --- a/sot/opcode_translator/breakpoint.py +++ b/sot/opcode_translator/breakpoint.py @@ -15,9 +15,11 @@ class Breakpoint: file: str line: int + co_name: str + offset: int def __hash__(self): - return hash((self.file, self.line)) + return hash((self.file, self.line, self.co_name, self.offset)) @Singleton @@ -37,9 +39,9 @@ def add_event(self, event): """ self.record_event.append(event) - def add(self, file, line): + def add(self, file, line, coname=None, offset=None): log(1, f"add breakpoint at {file}:{line}\n") - self.breakpoints.add(Breakpoint(file, line)) + self.breakpoints.add(Breakpoint(file, line, coname, offset)) def addn(self, *lines): """ @@ -54,9 +56,10 @@ def addn(self, *lines): def clear(self): self.breakpoints.clear() - def hit(self, file, line): - _breakpoint = Breakpoint(file, line) - if _breakpoint in self.breakpoints: + def hit(self, file, line, co_name, offset): + if Breakpoint(file, line, None, None) in self.breakpoints: + return True + if Breakpoint(file, line, co_name, offset) in self.breakpoints: return True return False @@ -124,11 +127,12 @@ def dis(self, range=5): """ print("displaying debug info...") cur_exe = self.cur_exe - print(f"{cur_exe._code}") + print(self._dis_source_code()) + + print(f"\n{cur_exe._code}") lasti = cur_exe._lasti lines = instrs_info(cur_exe._instructions, lasti - 1, range) print("\n".join(lines)) - print(self._dis_source_code()) @property def cur_exe(self): @@ -150,8 +154,8 @@ def pe(self, e): print("".join(lines)) -def add_breakpoint(file, line): - BM.add(file, line) +def add_breakpoint(file, line, co_name=None, offset=None): + BM.add(file, line, co_name, offset) def add_event(event): diff --git a/sot/opcode_translator/executor/dispatch_functions.py b/sot/opcode_translator/executor/dispatch_functions.py new file mode 100644 index 000000000..3886dbe69 --- /dev/null +++ b/sot/opcode_translator/executor/dispatch_functions.py @@ -0,0 +1,32 @@ +# This file stores the customed function that will be called by the dispatch mechanism. + +from ...utils import BreakGraphError, NotImplementException + + +def raise_break_graph_fn(*args, **kwarg): + raise BreakGraphError("raise by raise_break_graph_fn.") + + +def raise_not_implement_fn(*args, **kwarg): + raise NotImplementException("raise by raise_break_graph_fn.") + + +# just a function for operator.in +def operator_in(left, right): + return left in right + + +def operator_not_in(left, right): + return left not in right + + +def operator_exception_match(left, right): + pass + + +def operator_BAD(left, right): + pass + + +def tensor_numel(x): + pass diff --git a/sot/opcode_translator/executor/function_graph.py b/sot/opcode_translator/executor/function_graph.py index 1c401c591..a665286c6 100644 --- a/sot/opcode_translator/executor/function_graph.py +++ b/sot/opcode_translator/executor/function_graph.py @@ -68,6 +68,7 @@ class FunctionGraph: "stmt_ir", "global_guards", "side_effects_state", + "print_variables", ], ) @@ -79,8 +80,12 @@ def __init__(self, frame, **kwargs): self.side_effects = SideEffects() self.py_frame = frame self._global_guarded_variables: list[VariableBase] = [] + self._print_variables = [] self.build_strategy = kwargs.get('build_strategy', None) + def add_print_variables(self, variable): + self._print_variables.append(variable) + def need_add_input(self, var): if var.id in self.inner_out: return False @@ -102,6 +107,7 @@ def save_memo(self): stmt_ir=saved_stmt_ir, global_guards=list(self._global_guarded_variables), side_effects_state=self.side_effects.get_state(), + print_variables=list(self._print_variables), ) def restore_memo(self, memo): @@ -110,6 +116,7 @@ def restore_memo(self, memo): self.sir_ctx.replace_TOS(memo.stmt_ir) self._global_guarded_variables = memo.global_guards self.side_effects.restore_state(memo.side_effects_state) + self._print_variables = memo.print_variables def collect_input_variables(self, inputs: list[VariableBase]): for inp in inputs: @@ -176,6 +183,7 @@ def start_compile(self, *ret_vars: VariableBase): # deal side effect self.restore_side_effects(self.side_effects.variables) + self.restore_print_stmts(self._print_variables) tracker_output_path = show_trackers() if tracker_output_path: @@ -305,8 +313,19 @@ def _find_tensor_outputs( var, TensorVariable ): output_tensors.append(var) + # Find Tensor in print_stmts + for print_stmt in self._print_variables: + for var in print_stmt.flatten_items(): + if isinstance(var.tracker, DummyTracker) and isinstance( + var, TensorVariable + ): + output_tensors.append(var) return output_tensors + def restore_print_stmts(self, variables: list[VariableBase]): + for var in variables: + var._reconstruct(self.pycode_gen) + def restore_side_effects(self, variables: list[VariableBase]): if not variables: return diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index caa340254..b0d982a26 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -19,6 +19,12 @@ log_do, ) from ..instruction_utils import Instruction, analysis_inputs, get_instructions +from .dispatch_functions import ( + operator_BAD, + operator_exception_match, + operator_in, + operator_not_in, +) from .function_graph import FunctionGraph from .guard import Guard from .instr_flag import FORMAT_VALUE_FLAG as FV @@ -35,12 +41,6 @@ GlobalTracker, LocalTracker, ) -from .variable_dispatch import ( - operator_BAD, - operator_exception_match, - operator_in, - operator_not_in, -) from .variables import ( BuiltinVariable, CellVariable, @@ -574,9 +574,13 @@ def step(self, instr: Instruction): log(3, log_message) code_file = self._code.co_filename code_line = self._current_line + code_name = self._code.co_name + code_offset = instr.offset from ..breakpoint import BreakpointManager - if BreakpointManager().hit(code_file, code_line): + if BreakpointManager().hit( + code_file, code_line, code_name, code_offset + ): BreakpointManager().locate(self) print(log_message) breakpoint() # breakpoint for debug diff --git a/sot/opcode_translator/executor/variable_dispatch.py b/sot/opcode_translator/executor/variable_dispatch.py index 28aee7829..d3448bcbf 100644 --- a/sot/opcode_translator/executor/variable_dispatch.py +++ b/sot/opcode_translator/executor/variable_dispatch.py @@ -12,6 +12,12 @@ UNARY_OPS, magic_method_builtin_dispatch, ) +from .dispatch_functions import ( + operator_in, + operator_not_in, + raise_break_graph_fn, + tensor_numel, +) from .dispatcher import Dispatcher from .tracker import DummyTracker from .variables import VariableBase, VariableFactory @@ -24,24 +30,6 @@ TensorVariable, ) - -# just a function for operator.in -def operator_in(left, right): - return left in right - - -def operator_not_in(left, right): - return left not in right - - -def operator_exception_match(left, right): - pass - - -def operator_BAD(left, right): - pass - - # dict Dispatcher.register( operator_in, @@ -478,11 +466,29 @@ def is_not_func(var: VariableBase, other: VariableBase): ), ) # Tensor +fallback_tensor_unary_method = { + int, + bool, + operator.truth, +} + +Dispatcher.register(tensor_numel, ("TensorVariable",), {}, lambda x: x.numel()) + for unary_fn in UNARY_OPS: # Tensor doesn't support unary +, skip it # TODO(SigureMo): deal len and bool - if unary_fn in {operator.pos, len, bool, operator.truth}: + if unary_fn in {len}: + continue + + if unary_fn in fallback_tensor_unary_method: + Dispatcher.register( + unary_fn, + ("TensorVariable",), + {}, + raise_break_graph_fn, + ) continue + for magic_method in magic_method_builtin_dispatch(unary_fn): Dispatcher.register( unary_fn, diff --git a/sot/opcode_translator/executor/variables/basic.py b/sot/opcode_translator/executor/variables/basic.py index 3700c188f..37bce98dc 100644 --- a/sot/opcode_translator/executor/variables/basic.py +++ b/sot/opcode_translator/executor/variables/basic.py @@ -19,6 +19,7 @@ paddle_tensor_methods, ) from ....utils.exceptions import InnerError +from ..dispatch_functions import tensor_numel from ..guard import ( StringifyExpression, object_equal_stringify_guard, @@ -144,6 +145,29 @@ def wrap_literal(value: Any, graph: FunctionGraph) -> ConstantVariable: return ConstantVariable(value, graph, ConstTracker(value)) +class PrintStmtVariable(VariableBase): + def __init__(self, value: Any, graph: FunctionGraph): + super().__init__(DanglingTracker()) + self.args, self.kwargs = value + self.graph = graph + + def _reconstruct(self, codegen: PyCodeGen): + # do we need ? may be too strict. + for var in self.args: + self.graph.add_global_guarded_variable(var) + for var in self.kwargs.values(): + self.graph.add_global_guarded_variable(var) + # currently dont' consider kwargs + codegen.gen_load_global("print") + for var in self.args: + var.reconstruct(codegen) + codegen.gen_call_function(len(self.args)) + codegen.gen_pop_top() + + def flatten_items(self): + return self.args + + IMPLEMENTED_TENSOR_PROPERTIES = set() @@ -201,7 +225,10 @@ def __init__( super().__init__(tracker) if isinstance(tensor, paddle.Tensor): self.value = tensor - self.meta = MetaInfo.from_tensor(tensor) + try: + self.meta = MetaInfo.from_tensor(tensor) + except: + breakpoint() elif isinstance(tensor, MetaInfo): self.value = None self.meta = tensor @@ -312,6 +339,7 @@ def size(self): f"Getting size for a dynamic shape tensor causes graph break. shape = {self.meta.shape}" ) elements = reduce(operator.mul, self.meta.shape, 1) + self.graph.add_global_guarded_variable(self) return ConstantVariable.wrap_literal(elements, self.graph) @tensor_property @@ -325,6 +353,9 @@ def shape(self): self.meta.shape, self.graph, tracker=ConstTracker(self.meta.shape) ) + def numel(self): + return self.size + def is_tensor(self): return ConstantVariable.wrap_literal(True, self.graph) @@ -346,6 +377,7 @@ def is_floating_point(self): def getattr(self, name: str): method_name_to_builtin_fn = { "dim": paddle.rank, + "numel": tensor_numel, "ndimension": paddle.rank, "is_tensor": paddle.is_tensor, "is_complex": paddle.is_complex, diff --git a/sot/opcode_translator/executor/variables/callable.py b/sot/opcode_translator/executor/variables/callable.py index c6bffc294..954faaf7c 100644 --- a/sot/opcode_translator/executor/variables/callable.py +++ b/sot/opcode_translator/executor/variables/callable.py @@ -16,6 +16,7 @@ is_paddle_api, log_do, magic_method_builtin_dispatch, + psdb_print, ) from ....utils.exceptions import BreakGraphError, FallbackErrorBase from ..dispatcher import Dispatcher @@ -28,7 +29,7 @@ Tracker, ) from .base import VariableBase, VariableFactory -from .basic import ConstantVariable, ObjectVariable +from .basic import ConstantVariable, ObjectVariable, PrintStmtVariable if TYPE_CHECKING: from ..function_graph import FunctionGraph @@ -84,11 +85,17 @@ def __init__( def call_function(self, *args, **kwargs) -> VariableBase: from ..opcode_inline_executor import OpcodeInlineExecutor + # special function for inner debug. if self.value is ASSERT: # TODO: add comptime check mechanism return ConstantVariable.wrap_literal( self.value(args[0].value), self.graph ) + if self.value is psdb_print: + self.graph.add_print_variables( + PrintStmtVariable((args, kwargs), self.graph) + ) + return ConstantVariable.wrap_literal(None, self.graph) checkpoint = self.graph.save_memo() try: @@ -292,7 +299,7 @@ def call_function(self, *args, **kwargs): return fn_var(*(self, *args), **kwargs) - @VariableFactory.register_from_value() + @VariableFactory.register_from_value(successor="PaddleApiVariable") def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): if isinstance( value, paddle.nn.Layer diff --git a/sot/symbolic/compile_cache.py b/sot/symbolic/compile_cache.py index 9b2897925..7d1773c9d 100644 --- a/sot/symbolic/compile_cache.py +++ b/sot/symbolic/compile_cache.py @@ -29,6 +29,9 @@ def __call__(self, *args, **kwargs): # TODO(zmh): modify the if # TODO(xiongkun): or True is on purpose, we should remove it later after # dy2static bug is fixed. + log_do( + 2, lambda: print("[FallbackWrapper] start run SIR: \n", self.SIR) + ) if self.partial_program is None or True: outputs = self.compiled_fn(*args, **kwargs) ( diff --git a/sot/utils/__init__.py b/sot/utils/__init__.py index 49712af5d..7d39ff3bc 100644 --- a/sot/utils/__init__.py +++ b/sot/utils/__init__.py @@ -31,6 +31,7 @@ map_if, meta_str, no_eval_frame, + psdb_print, show_trackers, ) @@ -58,6 +59,7 @@ "is_strict_mode", "paddle_tensor_methods", "ASSERT", + "psdb_print", "ResumeFnNameFactory", "list_contain_by_id", "list_find_index_by_id", diff --git a/sot/utils/paddle_api_config.py b/sot/utils/paddle_api_config.py index 3b929c0fc..2c7042816 100644 --- a/sot/utils/paddle_api_config.py +++ b/sot/utils/paddle_api_config.py @@ -18,6 +18,7 @@ def get_paddle_api(): paddle.linalg, paddle.signal, paddle.fft, + paddle.vision.ops, ] special_paddle_apis = [paddle.tensor.fill_constant] non_operator_related_apis = [ diff --git a/sot/utils/utils.py b/sot/utils/utils.py index 59d2ecebb..9c7f98763 100644 --- a/sot/utils/utils.py +++ b/sot/utils/utils.py @@ -52,7 +52,8 @@ def __init__(self) -> None: self.gen = NameGenerator('__resume_fn_') def next(self): - return self.gen.next() + name = self.gen.next() + return name def log(level, *args): @@ -198,6 +199,10 @@ def ASSERT(input: bool): assert input +def psdb_print(*args, **kwargs): + print(*args, **kwargs) + + def list_find_index_by_id(li: list[Any], item: Any) -> int: return [id(it) for it in li].index(id(item)) diff --git a/tests/test_segment_linear.py b/tests/test_segment_linear.py index 79034c434..a44281d80 100644 --- a/tests/test_segment_linear.py +++ b/tests/test_segment_linear.py @@ -45,7 +45,9 @@ def test_simple(self): sot.skip_function(SimpleNet.forward) x = paddle.randn((1, 8, 8)) net = SimpleNet() - net = paddle.jit.to_static(net) + net = paddle.jit.to_static( + net + ) # dont make effect. we need fetch sot PR in paddle. loss = net(x) loss = loss.sum() loss.backward()