From 36c3518efe8fe984420b19f9a4cfe1e1d1787ae3 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 3 Jul 2023 08:22:01 +0000 Subject: [PATCH 01/12] fix renset and resnetv2 --- sot/opcode_translator/executor/opcode_executor.py | 10 ++++++++-- tests/run_all_paddle_ci.sh | 2 -- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index a2990c3f2..e5fbfb90c 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -705,7 +705,6 @@ def ROT_FOUR(self, instr: Instruction): BINARY_OR = tos_op_wrapper(operator.or_) BINARY_XOR = tos_op_wrapper(operator.xor) - @call_break_graph_decorator(push_n=1) def BINARY_SUBSCR(self, instr: Instruction): key = self.pop() container = self.pop() @@ -736,7 +735,6 @@ def BINARY_SUBSCR(self, instr: Instruction): def NOP(self, instr: Instruction): pass - @call_break_graph_decorator(push_n=1) def LOAD_ATTR(self, instr: Instruction): attr_name = instr.argval obj = self.pop() @@ -1768,6 +1766,14 @@ def CALL_FUNCTION_KW(self, instr: Instruction): def CALL_FUNCTION_EX(self, instr: Instruction): super().CALL_FUNCTION_EX(instr) + @call_break_graph_decorator(push_n=1) + def LOAD_ATTR(self, instr: Instruction): + super().LOAD_ATTR(instr) + + @call_break_graph_decorator(push_n=1) + def BINARY_SUBSCR(self, instr: Instruction): + super().BINARY_SUBSCR(instr) + def RETURN_VALUE(self, instr: Instruction): assert ( len(self._stack) == 1 diff --git a/tests/run_all_paddle_ci.sh b/tests/run_all_paddle_ci.sh index 41fa1c927..0b783fb68 100644 --- a/tests/run_all_paddle_ci.sh +++ b/tests/run_all_paddle_ci.sh @@ -10,8 +10,6 @@ disabled_tests=( ${PADDLE_TEST_BASE}/test_list.py # side effect ${PADDLE_TEST_BASE}/test_sentiment.py # disabled unitcase by paddle ${PADDLE_TEST_BASE}/test_reinforcement_learning.py # 'CartPoleEnv' object has no attribute 'seed' - ${PADDLE_TEST_BASE}/test_resnet_v2.py # segment error: oneDNN - ${PADDLE_TEST_BASE}/test_resnet.py # segment error: oneDNN # tmp = x # for i in range(x) # tmp += Linear(x) From f71ca62a23ea9b517b0e7c30937c25baa9919cc5 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 3 Jul 2023 14:51:33 +0000 Subject: [PATCH 02/12] fix some error --- .../executor/opcode_executor.py | 10 ++++ .../executor/opcode_inline_executor.py | 3 ++ .../executor/variable_dispatch.py | 46 ++++++++++++++++++- tests/test_14_operators.py | 12 +++++ 4 files changed, 69 insertions(+), 2 deletions(-) diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index e5fbfb90c..3b3d82431 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -40,6 +40,12 @@ GlobalTracker, LocalTracker, ) +from .variable_dispatch import ( + operator_BAD, + operator_exception_match, + operator_in, + operator_not_in, +) from .variables import ( BuiltinVariable, CallableVariable, @@ -83,6 +89,10 @@ "!=": operator.ne, "is not": operator.is_not, "is": operator.is_, + "in": operator_in, + "not in": operator_not_in, + "exception match": operator_exception_match, + "BAD": operator_BAD, } diff --git a/sot/opcode_translator/executor/opcode_inline_executor.py b/sot/opcode_translator/executor/opcode_inline_executor.py index e6b7fceee..25cb85650 100644 --- a/sot/opcode_translator/executor/opcode_inline_executor.py +++ b/sot/opcode_translator/executor/opcode_inline_executor.py @@ -173,6 +173,9 @@ def inline_call(self): return self.return_value def RETURN_VALUE(self, instr): + assert ( + len(self._stack) == 1 + ), f"Stack must have one element, but get {len(self._stack)} elements." self.return_value = self.pop() return Stop() diff --git a/sot/opcode_translator/executor/variable_dispatch.py b/sot/opcode_translator/executor/variable_dispatch.py index d986b4ad5..6c63c71dc 100644 --- a/sot/opcode_translator/executor/variable_dispatch.py +++ b/sot/opcode_translator/executor/variable_dispatch.py @@ -25,6 +25,47 @@ ) +# just a function for operator.in +def operator_in(left, right): + return left in right + + +def operator_not_in(left, right): + return left not in right + + +def operator_exception_match(left, right): + pass + + +def operator_BAD(left, right): + pass + + +# dict +Dispatcher.register( + operator_in, + ("VariableBase", "VariableBase"), + {}, + lambda left, right: VariableFactory.from_value( + left.get_value() in right.get_value(), + left.graph, + tracker=DummyTracker([left, right]), + ), +) + +# dict +Dispatcher.register( + operator_not_in, + ("VariableBase", "VariableBase"), + {}, + lambda left, right: VariableFactory.from_value( + left.get_value() not in right.get_value(), + left.graph, + tracker=DummyTracker([left, right]), + ), +) + # dict Dispatcher.register( dict.keys, @@ -32,6 +73,7 @@ {}, lambda var: var.keys(), ) + Dispatcher.register( dict.values, ("DictVariable",), @@ -88,13 +130,13 @@ getattr, ("VariableBase", "str", "VariableBase"), {}, - lambda var, name: var.getattr(name), + lambda var, name, default: var.getattr(name, default), ) Dispatcher.register( getattr, ("VariableBase", "ConstantVariable"), {}, - lambda var, name, default: var.getattr(name.get_value(), default), + lambda var, name: var.getattr(name.get_value()), ) Dispatcher.register( getattr, diff --git a/tests/test_14_operators.py b/tests/test_14_operators.py index c192b40fa..1dcdffc1c 100644 --- a/tests/test_14_operators.py +++ b/tests/test_14_operators.py @@ -255,6 +255,14 @@ def operator_is_(x: paddle.Tensor, y: paddle.Tensor): return (operator.is_(x, x), operator.is_(x, y)) +def operator_in_(x: int, y: list): + return x in y + + +def operator_not_in_(x: int, y: list): + return x not in y + + def operator_is_not(x: paddle.Tensor, y: paddle.Tensor): return (operator.is_not(x, x), operator.is_not(x, y)) @@ -317,6 +325,10 @@ def test_operator_simple(self): operator_is_not, paddle.to_tensor(2), paddle.to_tensor(3) ) self.assert_results(operator_pos, 1) + self.assert_results(operator_in_, 12, [1, 2, 12]) + self.assert_results(operator_in_, 12, [1, 2, 3]) + self.assert_results(operator_not_in_, 12, [1, 2, 3]) + self.assert_results(operator_not_in_, 12, [1, 2, 3]) def test_operator_list(self): self.assert_results(list_getitem, 1, paddle.to_tensor(2)) From 771b04ba9b6201263ca13f5404af0cc3b0daf3c3 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 3 Jul 2023 15:09:23 +0000 Subject: [PATCH 03/12] handle dictcomp and genexpr --- .../executor/opcode_inline_executor.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/sot/opcode_translator/executor/opcode_inline_executor.py b/sot/opcode_translator/executor/opcode_inline_executor.py index 2653cd6e3..f2b4cf99a 100644 --- a/sot/opcode_translator/executor/opcode_inline_executor.py +++ b/sot/opcode_translator/executor/opcode_inline_executor.py @@ -91,6 +91,18 @@ def __init__(self, fn_variable, *args, **kwargs): self._prepare_closure() # TODO: consider generator. + def _handle_comps(self): + is_comp = any( + x in self._fn_value.__name__ + for x in ['', '', ''] + ) + if not is_comp: + return + pattern = r'implicit\d+' + for name in list(self._locals.keys()): + if re.match(pattern, name): + self._locals[name.replace('implicit', '.')] = self._locals[name] + def _prepare_locals(self, *args, **kwargs): from .variables import VariableBase, VariableFactory @@ -116,13 +128,7 @@ def _prepare_locals(self, *args, **kwargs): value = VariableFactory.from_value(value, self._graph, tracker) self._locals[name] = value - if '' in self._fn_value.__name__: - pattern = r'implicit\d+' - for name in list(self._locals.keys()): - if re.match(pattern, name): - self._locals[name.replace('implicit', '.')] = self._locals[ - name - ] + self._handle_comps() log( 5, f"[INLINE CALL] {self._code.co_name} with locals: ", self._locals From 65d294d1242388adad96bfe64fea2f18d5574ba0 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 3 Jul 2023 16:13:03 +0000 Subject: [PATCH 04/12] fix object call in call_function --- .../executor/opcode_executor.py | 11 --- .../executor/variables/base.py | 3 + tests/test_call_object.py | 69 +++++++++++++++++++ 3 files changed, 72 insertions(+), 11 deletions(-) create mode 100644 tests/test_call_object.py diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 3b3d82431..85cfad3b9 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -48,7 +48,6 @@ ) from .variables import ( BuiltinVariable, - CallableVariable, CellVariable, ConstantVariable, ContainerVariable, @@ -998,8 +997,6 @@ def CALL_FUNCTION(self, instr: Instruction): args = self.pop_n(n_args) kwargs = {} fn = self.pop() - if not isinstance(fn, CallableVariable): - raise NotImplementException(f"CALL_FUNCTION: {fn} is not callable") ret = fn(*args, **kwargs) self.push(ret) @@ -1022,10 +1019,6 @@ def CALL_FUNCTION_KW(self, instr: Instruction): kwargs = dict(zip(kwargs_keys, kwargs_values)) fn = self.pop() - if not isinstance(fn, CallableVariable): - raise NotImplementException( - f"CALL_FUNCTION_KW: {fn} is not callable." - ) ret = fn(*args, **kwargs) self.push(ret) @@ -1043,10 +1036,6 @@ def CALL_FUNCTION_EX(self, instr: Instruction): args = args_variable.get_wrapped_items() fn = self.pop() - if not isinstance(fn, CallableVariable): - raise NotImplementException( - f"CALL_FUNCTION_EX: {fn} is not callable." - ) ret = fn(*args, **kwargs) self.push(ret) diff --git a/sot/opcode_translator/executor/variables/base.py b/sot/opcode_translator/executor/variables/base.py index a5de240a5..db2b12dda 100644 --- a/sot/opcode_translator/executor/variables/base.py +++ b/sot/opcode_translator/executor/variables/base.py @@ -504,6 +504,9 @@ def __call__(self, *args, **kwargs): self.graph, GetAttrTracker(self, '__class__'), ) + # if __call__ is a method, we should add self to arguments. + if inspect.ismethod(self.get_value().__call__): + args = (self,) + args unbound_method = get_unbound_method(self.get_value(), '__call__') if hasattr(unbound_method, "__code__"): fn_var = UserDefinedFunctionVariable( diff --git a/tests/test_call_object.py b/tests/test_call_object.py new file mode 100644 index 000000000..235e6197a --- /dev/null +++ b/tests/test_call_object.py @@ -0,0 +1,69 @@ +import unittest + +from test_case_base import TestCaseBase + +import paddle + +patched = lambda self, x: x * self.a + +patched2 = lambda self, x: x * self.a + 3 + + +class A: + def __init__(self, a): + self.a = a + + def __call__(self, x): + return self.add(x) + + def add(self, x): + return x + self.a + + multi = patched + + +class B: + def __init__(self, a): + self.a = A(a) + + def __call__(self, x, func): + return getattr(self.a, func)(x) + + def self_call(self, x, func): + return getattr(self.a, func)(self.a, x) + + +def foo_1(a, x): + return a(x) + + +def foo_2(a, x): + return a.multi(x) + + +def foo_3(b, x): + return b(x, "multi") + + +def foo_4(b, x): + return b(x, "add") + + +def foo_5(b, x): + return b.self_call(x, "multi") + + +class TestExecutor(TestCaseBase): + def test_simple(self): + c = B(13) + c.a.multi = patched2 + self.assert_results(foo_1, A(13), paddle.to_tensor(2)) + self.assert_results(foo_2, A(13), paddle.to_tensor(2)) + self.assert_results(foo_3, B(13), paddle.to_tensor(2)) + self.assert_results(foo_4, B(13), paddle.to_tensor(2)) + self.assert_results(foo_5, c, paddle.to_tensor(2)) + self.assert_results(foo_4, c, paddle.to_tensor(2)) + + +if __name__ == "__main__": + unittest.main() From b1b17650b7a5311f0a30b0c6ce9f4f6ddc2c4459 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 4 Jul 2023 02:49:01 +0000 Subject: [PATCH 05/12] fix --- sot/symbolic/interpreter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sot/symbolic/interpreter.py b/sot/symbolic/interpreter.py index c84fa27fb..2247e7510 100644 --- a/sot/symbolic/interpreter.py +++ b/sot/symbolic/interpreter.py @@ -75,7 +75,6 @@ def _set(v, s): state[s.name] = v if len(to_sequence(outs)) != len(to_sequence(stmt.outputs)): - breakpoint() raise InnerError("Number output mismatch, some error happen.") map_if( From c4e895a1ca7d90eb54215d3d42ef6e6cf0a6ad89 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 4 Jul 2023 03:02:52 +0000 Subject: [PATCH 06/12] add contain op --- sot/opcode_translator/executor/opcode_executor.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 85cfad3b9..2515a871b 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -1162,6 +1162,17 @@ def JUMP_FORWARD(self, instr): def JUMP_ABSOLUTE(self, instr: Instruction): self._lasti = self.indexof(instr.jump_to) + def CONTAINS_OP(self, instr: Instruction): + # It will only be 0 or 1 + assert instr.argval == 0 or instr.argval == 1 + right, left = self.pop(), self.pop() + op = "in" if instr.argval == 0 else "not in" + self.push( + BuiltinVariable( + SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker() + )(left, right) + ) + @jump_break_graph_decorator def JUMP_IF_FALSE_OR_POP(self, instr: Instruction): pred_obj = self.peek() From 0cb83c8e737870804b57093c66a45c7a935d1513 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 4 Jul 2023 09:20:22 +0000 Subject: [PATCH 07/12] add breakpoint --- sot/__init__.py | 5 +-- sot/opcode_translator/breakpoint.py | 40 +++++++++++++++++++ .../executor/opcode_executor.py | 12 ++++-- 3 files changed, 50 insertions(+), 7 deletions(-) create mode 100644 sot/opcode_translator/breakpoint.py diff --git a/sot/__init__.py b/sot/__init__.py index 383cb3017..6e763f53e 100644 --- a/sot/__init__.py +++ b/sot/__init__.py @@ -1,5 +1,4 @@ +from .opcode_translator.breakpoint import add_breakpoint from .translate import symbolic_translate -__all__ = [ - "symbolic_translate", -] +__all__ = ["symbolic_translate", "add_breakpoint"] diff --git a/sot/opcode_translator/breakpoint.py b/sot/opcode_translator/breakpoint.py new file mode 100644 index 000000000..b74c8f8bb --- /dev/null +++ b/sot/opcode_translator/breakpoint.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass + +from ..utils import Singleton, log + +# this file is a debug utils files for quick debug +# >>> sot.add_breakpoint(file, line) +# >>> sot.remove_breakpoint(file, line) + + +@dataclass +class Breakpoint: + file: str + line: int + + def __hash__(self): + return hash((self.file, self.line)) + + +@Singleton +class BreakpointManager: + def __init__(self): + self.breakpoints = set() + + def add(self, file, line): + log(1, f"add breakpoint at {file}:{line}") + self.breakpoints.add(Breakpoint(file, line)) + + def rm(self, *args, **kwargs): + # interactive use, we use abbreviate + self.breakpoints() + + def hit(self, file, line): + _breakpoint = Breakpoint(file, line) + if _breakpoint in self.breakpoints: + return True + return False + + +def add_breakpoint(file, line): + BreakpointManager().add(file, line) diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 2515a871b..749467ab1 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -18,6 +18,7 @@ log, log_do, ) +from ..breakpoint import BreakpointManager from ..instruction_utils import ( Instruction, analysis_inputs, @@ -576,10 +577,13 @@ def step(self, instr: Instruction): raise NotImplementException( f"opcode: {instr.opname} is not supported." ) - log( - 3, - f"[Translate {self._name}]: (line {self._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {self._stack}\n", - ) + log_message = f"[Translate {self._name}]: (line {self._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {self._stack}\n" + log(3, log_message) + code_file = self._code.co_filename + code_line = self._current_line + if BreakpointManager().hit(code_file, code_line): + print(log_message) + breakpoint() # breakpoint for debug return getattr(self, instr.opname)(instr) # run single step. def indexof(self, instr: Instruction): From 18198f6c47d7d2fb1334686a13765f7790993bdc Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 4 Jul 2023 11:24:11 +0000 Subject: [PATCH 08/12] xxx --- sot/opcode_translator/breakpoint.py | 43 ++++++++++++++++++- .../executor/opcode_executor.py | 25 ++--------- .../instruction_utils/instruction_utils.py | 2 +- 3 files changed, 47 insertions(+), 23 deletions(-) diff --git a/sot/opcode_translator/breakpoint.py b/sot/opcode_translator/breakpoint.py index b74c8f8bb..c15002b4f 100644 --- a/sot/opcode_translator/breakpoint.py +++ b/sot/opcode_translator/breakpoint.py @@ -1,6 +1,8 @@ from dataclasses import dataclass +from ..opcode_translator.instruction_utils import instrs_info from ..utils import Singleton, log +from .executor.opcode_executor import OpcodeExecutorBase # this file is a debug utils files for quick debug # >>> sot.add_breakpoint(file, line) @@ -20,6 +22,8 @@ def __hash__(self): class BreakpointManager: def __init__(self): self.breakpoints = set() + self.executors = OpcodeExecutorBase.call_stack + self.active = 0 def add(self, file, line): log(1, f"add breakpoint at {file}:{line}") @@ -35,6 +39,43 @@ def hit(self, file, line): return True return False + def locate(self, exe): + for i, _e in enumerate(self.executors): + if _e is exe: + self.activate = i + return + raise RuntimeError("Not found executor.") + + def up(self): + if self.activate == 0: + return + self.activate -= 1 + + def down(self): + if self.activate >= len(self.executors): + return + self.activate += 1 + + def where(self): + """ + display all inline calls. + """ + pass + + def dis(self, range=5): + """ + display all instruction code and source code. + """ + print("displaying debug info...") + cur_exe = self.executors[self.activate] + lines = instrs_info(cur_exe._instructions) + lasti = cur_exe._lasti + print("\n".join(lines[max(lasti - range, 0) : lasti + range + 1])) + # cur_exe._code = dis.dis() + def add_breakpoint(file, line): - BreakpointManager().add(file, line) + BM.add(file, line) + + +BM = BreakpointManager() diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 88440933d..6a5de525c 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -18,13 +18,7 @@ log, log_do, ) -from ..breakpoint import BreakpointManager -from ..instruction_utils import ( - Instruction, - analysis_inputs, - get_instructions, - instrs_info, -) +from ..instruction_utils import Instruction, analysis_inputs, get_instructions from .function_graph import FunctionGraph from .guard import Guard from .instr_flag import FORMAT_VALUE_FLAG as FV @@ -426,13 +420,6 @@ def __init__(self, code: types.CodeType, graph: FunctionGraph): self._name = "Executor" self._prepare_virtual_env() - def print_instrs(self): - """ - Prints the instructions in the executor. - - """ - print(instrs_info(self._instructions)) - def print_sir(self): """ Prints the Static Instruction Representation (SIR) in the executor. @@ -581,7 +568,10 @@ def step(self, instr: Instruction): log(3, log_message) code_file = self._code.co_filename code_line = self._current_line + from ..breakpoint import BreakpointManager + if BreakpointManager().hit(code_file, code_line): + BreakpointManager().locate(self) print(log_message) breakpoint() # breakpoint for debug return getattr(self, instr.opname)(instr) # run single step. @@ -1327,13 +1317,6 @@ def DICT_MERGE(self, instr: Instruction): self._stack[-instr.arg], dict_value ) - def LIST_APPEND(self, instr: Instruction): - list_value = self.pop() - assert instr.argval > 0 - BuiltinVariable(list.append, self._graph, tracker=DanglingTracker())( - self._stack[-instr.arg], list_value - ) - def LIST_EXTEND(self, instr: Instruction): list_value = self.pop() assert instr.argval > 0 diff --git a/sot/opcode_translator/instruction_utils/instruction_utils.py b/sot/opcode_translator/instruction_utils/instruction_utils.py index 42011e6ad..bd2d92ecd 100644 --- a/sot/opcode_translator/instruction_utils/instruction_utils.py +++ b/sot/opcode_translator/instruction_utils/instruction_utils.py @@ -274,4 +274,4 @@ def instrs_info(instrs): argval=f"({instr.argval})" if instr.argval else "", ) ) - return "\n".join(ret) + return ret From 075f6da3a362bdec70c2d69c2b963f64f80a6a52 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 4 Jul 2023 15:34:07 +0000 Subject: [PATCH 09/12] add event capture to get all exception location. --- sot/__init__.py | 4 +- sot/opcode_translator/breakpoint.py | 94 ++++++++++++++++--- .../executor/opcode_executor.py | 14 +-- .../instruction_utils/instruction_utils.py | 13 ++- sot/utils/exceptions.py | 6 +- 5 files changed, 105 insertions(+), 26 deletions(-) diff --git a/sot/__init__.py b/sot/__init__.py index 6e763f53e..1b0184292 100644 --- a/sot/__init__.py +++ b/sot/__init__.py @@ -1,4 +1,4 @@ -from .opcode_translator.breakpoint import add_breakpoint +from .opcode_translator.breakpoint import add_breakpoint, add_event from .translate import symbolic_translate -__all__ = ["symbolic_translate", "add_breakpoint"] +__all__ = ["symbolic_translate", "add_breakpoint", "add_event"] diff --git a/sot/opcode_translator/breakpoint.py b/sot/opcode_translator/breakpoint.py index c15002b4f..84b4f8c94 100644 --- a/sot/opcode_translator/breakpoint.py +++ b/sot/opcode_translator/breakpoint.py @@ -1,3 +1,4 @@ +import inspect from dataclasses import dataclass from ..opcode_translator.instruction_utils import instrs_info @@ -24,14 +25,33 @@ def __init__(self): self.breakpoints = set() self.executors = OpcodeExecutorBase.call_stack self.active = 0 + self.record_event = [] + + def clear_event(self, event): + self.record_event.clear() + + def add_event(self, event): + """ + event in ['All' ,'NotImplementException', 'BreakGraphError', 'InnerError'] + """ + self.record_event.append(event) def add(self, file, line): - log(1, f"add breakpoint at {file}:{line}") + log(1, f"add breakpoint at {file}:{line}\n") self.breakpoints.add(Breakpoint(file, line)) - def rm(self, *args, **kwargs): - # interactive use, we use abbreviate - self.breakpoints() + def addn(self, *lines): + """ + called inside a executor. add a list of line number in current file. + """ + if not isinstance(lines, (list, tuple)): + lines = [lines] + for line in lines: + file = self.cur_exe._code.co_filename + self.add(file, line) + + def clear(self): + self.breakpoints.clear() def hit(self, file, line): _breakpoint = Breakpoint(file, line) @@ -50,32 +70,84 @@ def up(self): if self.activate == 0: return self.activate -= 1 + print("current function is: ", self.cur_exe._code.co_name) def down(self): - if self.activate >= len(self.executors): + if self.activate >= len(self.executors) - 1: return self.activate += 1 + print("current function is: ", self.cur_exe._code.co_name) - def where(self): + def opcode(self, cur_exe=None): + if cur_exe is None: + cur_exe = self.cur_exe + instr = cur_exe._instructions[cur_exe._lasti - 1] + message = f"[Translate {cur_exe}]: (line {cur_exe._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {cur_exe._stack}\n" + return message + + def bt(self): """ - display all inline calls. + display all inline calls: backtrace. """ + for exe in self.executors: + lines, _ = inspect.getsourcelines(exe._code) + print( + " " + + exe._code.co_filename + + f"({exe._current_line})" + + f"{exe._code.co_name}()" + ) + print(f"-> {lines[0].strip()}") + print(f"-> {self._current_opcode(exe)}") pass + def on_event(self, event): + if "All" in self.record_event or event in self.record_event: + print("event captured.") + self.activate = len(self.executors) - 1 + breakpoint() + + def _dis_source_code(self): + cur_exe = self.executors[self.activate] + lines, start_line = inspect.getsourcelines(cur_exe._code) + cur_line = cur_exe._current_line + lines[ + cur_line - start_line + 1 : cur_line - start_line + 1 + ] = " ^^^^^ HERE \n" + print("\033[31mSource Code is: \033[0m") + print("".join(lines)) + def dis(self, range=5): """ display all instruction code and source code. """ print("displaying debug info...") - cur_exe = self.executors[self.activate] - lines = instrs_info(cur_exe._instructions) + cur_exe = self.cur_exe + print(f"{cur_exe._code}") lasti = cur_exe._lasti - print("\n".join(lines[max(lasti - range, 0) : lasti + range + 1])) - # cur_exe._code = dis.dis() + lines = instrs_info(cur_exe._instructions, lasti - 1, range) + print("\n".join(lines)) + print(self._dis_source_code()) + + @property + def cur_exe(self): + exe = self.executors[self.activate] + return exe + + def sir(self): + """ + display sir in a page. + """ + print("displaying sir...") + self.cur_exe.print_sir() def add_breakpoint(file, line): BM.add(file, line) +def add_event(event): + BM.add_event(event) + + BM = BreakpointManager() diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index a42976085..6e471ce31 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -147,10 +147,7 @@ def impl( for code, guard_fn in guarded_fns: try: if guard_fn(frame): - log( - 3, - f"[Cache]: Cache hit, Guard is {guard_fn.expr if hasattr(guard_fn, 'expr') else 'None'}\n", - ) + log(3, "[Cache]: Cache hit\n") return CustomCode(code, False) except Exception as e: log(3, f"[Cache]: Guard function error: {e}\n") @@ -484,6 +481,8 @@ def get_var(self, name: str): return self._globals[name] elif name in self._builtins.keys(): return self._builtins[name] + elif name in self._cells.keys(): # in closure + return self._cells[name].get_value() else: raise InnerError(f'Can not get var: {name}') @@ -803,9 +802,6 @@ def STORE_FAST(self, instr: Instruction): """ var = self.pop() var.debug_name = instr.argval - if instr.argval == "__breakpoint__": - print(var.value) - breakpoint() self._locals[instr.argval] = var def STORE_GLOBAL(self, instr: Instruction): @@ -1745,8 +1741,6 @@ def FOR_ITER(self, instr): "Found RETURN_VALUE in for loop body." ) - self._graph.add_global_guarded_variable(iterator) - # TODO need support TensorIterVariable.next try: if not isinstance( @@ -1756,7 +1750,7 @@ def FOR_ITER(self, instr): backup_iter_idx = iterator.idx self._inline_call_for_loop(iterator, instr) self._lasti = self.indexof(instr.jump_to) - except BreakGraphError as e: + except BreakGraphError: if backup_iter_idx: iterator.idx = backup_iter_idx self._break_graph_in_for_loop(iterator, instr) diff --git a/sot/opcode_translator/instruction_utils/instruction_utils.py b/sot/opcode_translator/instruction_utils/instruction_utils.py index 8e5850024..bd8074765 100644 --- a/sot/opcode_translator/instruction_utils/instruction_utils.py +++ b/sot/opcode_translator/instruction_utils/instruction_utils.py @@ -257,9 +257,16 @@ def replace_instr(instructions, instr, new_instr): instructions[idx, idx + 1] = new_instr -def instrs_info(instrs, mark=None): +def instrs_info(instrs, mark=None, range=None): ret = [] + start = -1 + end = 1000000 + if mark is not None and range is not None: + start = mark - range + end = mark + range + 1 for idx, instr in enumerate(instrs): + if idx < start or idx >= end: + continue if instr.starts_line is not None: ret.append("") ret.append( @@ -272,7 +279,9 @@ def instrs_info(instrs, mark=None): opname=instr.opname, arg=str(instr.arg) if instr.arg is not None else "", argval=f"({instr.argval})" if instr.argval else "", - mark=" <--- HERE" if mark == idx else "", + mark="", ) ) + if idx == mark: + ret[-1] = "\033[31m" + ret[-1] + "\033[0m" return ret diff --git a/sot/utils/exceptions.py b/sot/utils/exceptions.py index f5d55bd6a..19d8146ae 100644 --- a/sot/utils/exceptions.py +++ b/sot/utils/exceptions.py @@ -1,5 +1,9 @@ class FallbackErrorBase(Exception): - pass + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + from ..opcode_translator.breakpoint import BreakpointManager + + BreakpointManager().on_event(f"{self.__class__.__name__}") class InnerError(FallbackErrorBase): From a39c244f4739802997a480f055ad8eb6150cc688 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 4 Jul 2023 17:15:32 +0000 Subject: [PATCH 10/12] add skip function interface --- sot/__init__.py | 11 +++++++++-- sot/opcode_translator/breakpoint.py | 10 +++++++++- sot/opcode_translator/skip_files.py | 15 +++++++++++++++ sot/opcode_translator/transform.py | 4 ++-- sot/utils/exceptions.py | 7 +++++++ 5 files changed, 42 insertions(+), 5 deletions(-) diff --git a/sot/__init__.py b/sot/__init__.py index 1b0184292..49c814b5b 100644 --- a/sot/__init__.py +++ b/sot/__init__.py @@ -1,4 +1,11 @@ -from .opcode_translator.breakpoint import add_breakpoint, add_event +from .opcode_translator.breakpoint import BM, add_breakpoint, add_event +from .opcode_translator.skip_files import skip_function from .translate import symbolic_translate -__all__ = ["symbolic_translate", "add_breakpoint", "add_event"] +__all__ = [ + "symbolic_translate", + "add_breakpoint", + "add_event", + "BM", + "skip_function", +] diff --git a/sot/opcode_translator/breakpoint.py b/sot/opcode_translator/breakpoint.py index 84b4f8c94..ebc25d002 100644 --- a/sot/opcode_translator/breakpoint.py +++ b/sot/opcode_translator/breakpoint.py @@ -1,4 +1,5 @@ import inspect +import traceback from dataclasses import dataclass from ..opcode_translator.instruction_utils import instrs_info @@ -98,7 +99,7 @@ def bt(self): + f"{exe._code.co_name}()" ) print(f"-> {lines[0].strip()}") - print(f"-> {self._current_opcode(exe)}") + print(f"-> {self.opcode(exe)}") pass def on_event(self, event): @@ -141,6 +142,13 @@ def sir(self): print("displaying sir...") self.cur_exe.print_sir() + def pe(self, e): + """ + print exception. + """ + lines = traceback.format_tb(e.__traceback__) + print("".join(lines)) + def add_breakpoint(file, line): BM.add(file, line) diff --git a/sot/opcode_translator/skip_files.py b/sot/opcode_translator/skip_files.py index 6f2d7cce8..ecbac3514 100644 --- a/sot/opcode_translator/skip_files.py +++ b/sot/opcode_translator/skip_files.py @@ -38,6 +38,8 @@ import google.protobuf import numpy +from ..utils import log + def _strip_init_py(s): return re.sub(r"__init__.py$", "", s) @@ -105,6 +107,8 @@ def _module_dir(m: types.ModuleType): f"^({'|'.join(map(re.escape, skip_file_names))})" ) +customed_skip_code = set() + def need_skip_path(filepath: str) -> bool: """ @@ -119,3 +123,14 @@ def need_skip_path(filepath: str) -> bool: if not filepath.startswith("<"): filepath = os.path.abspath(filepath) return bool(skip_file_name_re.match(filepath)) + + +def skip_function(function): + customed_skip_code.add(function.__code__) + + +def need_skip(pycode): + if pycode in customed_skip_code: + log(3, f"Skip frame by code: {pycode}") + return True + return need_skip_path(pycode.co_filename) diff --git a/sot/opcode_translator/transform.py b/sot/opcode_translator/transform.py index 7f3776453..65a7aff07 100644 --- a/sot/opcode_translator/transform.py +++ b/sot/opcode_translator/transform.py @@ -2,7 +2,7 @@ from ..utils import log, log_do from .executor.opcode_executor import InstructionTranslatorCache -from .skip_files import need_skip_path +from .skip_files import need_skip def eval_frame_callback(frame, **kwargs): @@ -10,7 +10,7 @@ def eval_frame_callback(frame, **kwargs): if frame.f_code.co_flags & 0x20 > 0: return None - if not need_skip_path(frame.f_code.co_filename): + if not need_skip(frame.f_code): log( 2, "[eval_frame_callback] start to translate: " diff --git a/sot/utils/exceptions.py b/sot/utils/exceptions.py index 19d8146ae..20f660fe3 100644 --- a/sot/utils/exceptions.py +++ b/sot/utils/exceptions.py @@ -1,3 +1,6 @@ +import traceback + + class FallbackErrorBase(Exception): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -5,6 +8,10 @@ def __init__(self, *args, **kwargs): BreakpointManager().on_event(f"{self.__class__.__name__}") + def print(self): + lines = traceback.format_tb(self.__traceback__) + print("".join(lines)) + class InnerError(FallbackErrorBase): pass From eb39375eb66d6962d4e700e72665298a99a624ae Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 5 Jul 2023 07:11:30 +0000 Subject: [PATCH 11/12] fix segment segment error. --- sot/symbolic/compile_cache.py | 12 ++++---- tests/test_segment_linear.py | 55 +++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 5 deletions(-) create mode 100644 tests/test_segment_linear.py diff --git a/sot/symbolic/compile_cache.py b/sot/symbolic/compile_cache.py index 51569c5b1..9b2897925 100644 --- a/sot/symbolic/compile_cache.py +++ b/sot/symbolic/compile_cache.py @@ -10,10 +10,11 @@ def clear_eager_tensor_name(output_tensors): class FallbackWrapper: - def __init__(self, compile_sir): - self.compile_sir = compile_sir + def __init__(self, compiled_fn, SIR): + self.compiled_fn = compiled_fn self.partial_program = None self.concrete_program = None + self.SIR = SIR # for debug def __call__(self, *args, **kwargs): """TODO: we disable partial_program cache here because some bugs in ast to_static. @@ -29,11 +30,11 @@ def __call__(self, *args, **kwargs): # TODO(xiongkun): or True is on purpose, we should remove it later after # dy2static bug is fixed. if self.partial_program is None or True: - outputs = self.compile_sir(*args, **kwargs) + outputs = self.compiled_fn(*args, **kwargs) ( self.concrete_program, self.partial_program, - ) = self.compile_sir.get_concrete_program(*args, **kwargs) + ) = self.compiled_fn.get_concrete_program(*args, **kwargs) else: # Speed up Resnet from 0.0068 --> 0.0057 outputs = self.partial_program(*args, **kwargs) @@ -64,5 +65,6 @@ def value_fn(self, context, sir_name, build_strategy): compile_sir(context, sir_name), build_strategy=build_strategy, enable_fallback=False, - ) + ), + context.get_sir(sir_name), ) diff --git a/tests/test_segment_linear.py b/tests/test_segment_linear.py new file mode 100644 index 000000000..79034c434 --- /dev/null +++ b/tests/test_segment_linear.py @@ -0,0 +1,55 @@ +import unittest + +from test_case_base import TestCaseBase + +import paddle +import sot +from paddle import nn + + +class Head(nn.Layer): + def __init__(self): + super().__init__() + self.head = nn.Linear(10, 150) + + def forward(self, x, patch_embed_size): + masks = self.head(x) + # [b, (h w), c] -> [b, c, h, w] + h, w = patch_embed_size[0], patch_embed_size[1] + masks = masks.reshape((1, h, w, paddle.shape(masks)[-1])) + masks = masks.transpose((0, 3, 1, 2)) + return masks + + +class SimpleNet(nn.Layer): + def __init__(self): + super().__init__() + self.tmp = nn.Linear(1, 1024 * 10) + self.tmp2 = nn.Linear(1, 1 * 10 * 32 * 32) + self.head = Head() + + def getshape(self, x): + x = self.tmp2(x.mean().reshape([1])).reshape([1, 10, 32, 32]) + x = paddle.shape(x) + return x + + def forward(self, x): + shape = self.getshape(x) + feat = self.tmp(x.mean().reshape([1])).reshape([1, 1024, 10]) + logits = self.head(feat, shape[2:]) + return logits + + +class TestExecutor(TestCaseBase): + def test_simple(self): + sot.skip_function(SimpleNet.forward) + x = paddle.randn((1, 8, 8)) + net = SimpleNet() + net = paddle.jit.to_static(net) + loss = net(x) + loss = loss.sum() + loss.backward() + + +if __name__ == "__main__": + unittest.main() From 0e6c80cf4fc6370c2df80384164d64780f4a1a64 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 5 Jul 2023 07:27:11 +0000 Subject: [PATCH 12/12] fix --- sot/opcode_translator/executor/opcode_executor.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 6e471ce31..f9d1d70ab 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -147,7 +147,10 @@ def impl( for code, guard_fn in guarded_fns: try: if guard_fn(frame): - log(3, "[Cache]: Cache hit\n") + log( + 3, + f"[Cache]: Cache hit, Guard is {guard_fn.expr if hasattr(guard_fn, 'expr') else 'None'}\n", + ) return CustomCode(code, False) except Exception as e: log(3, f"[Cache]: Guard function error: {e}\n") @@ -1319,6 +1322,13 @@ def DICT_MERGE(self, instr: Instruction): self._stack[-instr.arg], dict_value ) + def LIST_APPEND(self, instr: Instruction): + list_value = self.pop() + assert instr.argval > 0 + BuiltinVariable(list.append, self._graph, tracker=DanglingTracker())( + self._stack[-instr.arg], list_value + ) + def LIST_EXTEND(self, instr: Instruction): list_value = self.pop() assert instr.argval > 0 @@ -1741,6 +1751,7 @@ def FOR_ITER(self, instr): "Found RETURN_VALUE in for loop body." ) + self._graph.add_global_guarded_variable(iterator) # TODO need support TensorIterVariable.next try: if not isinstance( @@ -1750,7 +1761,7 @@ def FOR_ITER(self, instr): backup_iter_idx = iterator.idx self._inline_call_for_loop(iterator, instr) self._lasti = self.indexof(instr.jump_to) - except BreakGraphError: + except BreakGraphError as e: if backup_iter_idx: iterator.idx = backup_iter_idx self._break_graph_in_for_loop(iterator, instr)