From 90da4951a8baf0aa4d9ed60531adbb4c31782392 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 4 Jul 2023 16:49:22 +0800 Subject: [PATCH] [ BugFix ] Fix PaddleDetection bugs (#235) --- .../executor/opcode_executor.py | 32 ++++++--- .../executor/opcode_inline_executor.py | 23 +++++-- .../executor/variable_dispatch.py | 44 +++++++++++- .../executor/variables/base.py | 3 + sot/symbolic/interpreter.py | 1 - tests/test_14_operators.py | 12 ++++ tests/test_call_object.py | 69 +++++++++++++++++++ 7 files changed, 164 insertions(+), 20 deletions(-) create mode 100644 tests/test_call_object.py diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index e5fbfb90c..2515a871b 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -40,9 +40,14 @@ GlobalTracker, LocalTracker, ) +from .variable_dispatch import ( + operator_BAD, + operator_exception_match, + operator_in, + operator_not_in, +) from .variables import ( BuiltinVariable, - CallableVariable, CellVariable, ConstantVariable, ContainerVariable, @@ -83,6 +88,10 @@ "!=": operator.ne, "is not": operator.is_not, "is": operator.is_, + "in": operator_in, + "not in": operator_not_in, + "exception match": operator_exception_match, + "BAD": operator_BAD, } @@ -988,8 +997,6 @@ def CALL_FUNCTION(self, instr: Instruction): args = self.pop_n(n_args) kwargs = {} fn = self.pop() - if not isinstance(fn, CallableVariable): - raise NotImplementException(f"CALL_FUNCTION: {fn} is not callable") ret = fn(*args, **kwargs) self.push(ret) @@ -1012,10 +1019,6 @@ def CALL_FUNCTION_KW(self, instr: Instruction): kwargs = dict(zip(kwargs_keys, kwargs_values)) fn = self.pop() - if not isinstance(fn, CallableVariable): - raise NotImplementException( - f"CALL_FUNCTION_KW: {fn} is not callable." - ) ret = fn(*args, **kwargs) self.push(ret) @@ -1033,10 +1036,6 @@ def CALL_FUNCTION_EX(self, instr: Instruction): args = args_variable.get_wrapped_items() fn = self.pop() - if not isinstance(fn, CallableVariable): - raise NotImplementException( - f"CALL_FUNCTION_EX: {fn} is not callable." - ) ret = fn(*args, **kwargs) self.push(ret) @@ -1163,6 +1162,17 @@ def JUMP_FORWARD(self, instr): def JUMP_ABSOLUTE(self, instr: Instruction): self._lasti = self.indexof(instr.jump_to) + def CONTAINS_OP(self, instr: Instruction): + # It will only be 0 or 1 + assert instr.argval == 0 or instr.argval == 1 + right, left = self.pop(), self.pop() + op = "in" if instr.argval == 0 else "not in" + self.push( + BuiltinVariable( + SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker() + )(left, right) + ) + @jump_break_graph_decorator def JUMP_IF_FALSE_OR_POP(self, instr: Instruction): pred_obj = self.peek() diff --git a/sot/opcode_translator/executor/opcode_inline_executor.py b/sot/opcode_translator/executor/opcode_inline_executor.py index 441b57ef8..f2b4cf99a 100644 --- a/sot/opcode_translator/executor/opcode_inline_executor.py +++ b/sot/opcode_translator/executor/opcode_inline_executor.py @@ -91,6 +91,18 @@ def __init__(self, fn_variable, *args, **kwargs): self._prepare_closure() # TODO: consider generator. + def _handle_comps(self): + is_comp = any( + x in self._fn_value.__name__ + for x in ['', '', ''] + ) + if not is_comp: + return + pattern = r'implicit\d+' + for name in list(self._locals.keys()): + if re.match(pattern, name): + self._locals[name.replace('implicit', '.')] = self._locals[name] + def _prepare_locals(self, *args, **kwargs): from .variables import VariableBase, VariableFactory @@ -116,13 +128,7 @@ def _prepare_locals(self, *args, **kwargs): value = VariableFactory.from_value(value, self._graph, tracker) self._locals[name] = value - if '' in self._fn_value.__name__: - pattern = r'implicit\d+' - for name in list(self._locals.keys()): - if re.match(pattern, name): - self._locals[name.replace('implicit', '.')] = self._locals[ - name - ] + self._handle_comps() log( 5, f"[INLINE CALL] {self._code.co_name} with locals: ", self._locals @@ -182,6 +188,9 @@ def inline_call(self): return self.return_value def RETURN_VALUE(self, instr): + assert ( + len(self._stack) == 1 + ), f"Stack must have one element, but get {len(self._stack)} elements." self.return_value = self.pop() return Stop() diff --git a/sot/opcode_translator/executor/variable_dispatch.py b/sot/opcode_translator/executor/variable_dispatch.py index 6c21031ae..6c63c71dc 100644 --- a/sot/opcode_translator/executor/variable_dispatch.py +++ b/sot/opcode_translator/executor/variable_dispatch.py @@ -25,6 +25,47 @@ ) +# just a function for operator.in +def operator_in(left, right): + return left in right + + +def operator_not_in(left, right): + return left not in right + + +def operator_exception_match(left, right): + pass + + +def operator_BAD(left, right): + pass + + +# dict +Dispatcher.register( + operator_in, + ("VariableBase", "VariableBase"), + {}, + lambda left, right: VariableFactory.from_value( + left.get_value() in right.get_value(), + left.graph, + tracker=DummyTracker([left, right]), + ), +) + +# dict +Dispatcher.register( + operator_not_in, + ("VariableBase", "VariableBase"), + {}, + lambda left, right: VariableFactory.from_value( + left.get_value() not in right.get_value(), + left.graph, + tracker=DummyTracker([left, right]), + ), +) + # dict Dispatcher.register( dict.keys, @@ -32,6 +73,7 @@ {}, lambda var: var.keys(), ) + Dispatcher.register( dict.values, ("DictVariable",), @@ -88,7 +130,7 @@ getattr, ("VariableBase", "str", "VariableBase"), {}, - lambda var, name: var.getattr(name), + lambda var, name, default: var.getattr(name, default), ) Dispatcher.register( getattr, diff --git a/sot/opcode_translator/executor/variables/base.py b/sot/opcode_translator/executor/variables/base.py index a5de240a5..db2b12dda 100644 --- a/sot/opcode_translator/executor/variables/base.py +++ b/sot/opcode_translator/executor/variables/base.py @@ -504,6 +504,9 @@ def __call__(self, *args, **kwargs): self.graph, GetAttrTracker(self, '__class__'), ) + # if __call__ is a method, we should add self to arguments. + if inspect.ismethod(self.get_value().__call__): + args = (self,) + args unbound_method = get_unbound_method(self.get_value(), '__call__') if hasattr(unbound_method, "__code__"): fn_var = UserDefinedFunctionVariable( diff --git a/sot/symbolic/interpreter.py b/sot/symbolic/interpreter.py index c84fa27fb..2247e7510 100644 --- a/sot/symbolic/interpreter.py +++ b/sot/symbolic/interpreter.py @@ -75,7 +75,6 @@ def _set(v, s): state[s.name] = v if len(to_sequence(outs)) != len(to_sequence(stmt.outputs)): - breakpoint() raise InnerError("Number output mismatch, some error happen.") map_if( diff --git a/tests/test_14_operators.py b/tests/test_14_operators.py index c192b40fa..1dcdffc1c 100644 --- a/tests/test_14_operators.py +++ b/tests/test_14_operators.py @@ -255,6 +255,14 @@ def operator_is_(x: paddle.Tensor, y: paddle.Tensor): return (operator.is_(x, x), operator.is_(x, y)) +def operator_in_(x: int, y: list): + return x in y + + +def operator_not_in_(x: int, y: list): + return x not in y + + def operator_is_not(x: paddle.Tensor, y: paddle.Tensor): return (operator.is_not(x, x), operator.is_not(x, y)) @@ -317,6 +325,10 @@ def test_operator_simple(self): operator_is_not, paddle.to_tensor(2), paddle.to_tensor(3) ) self.assert_results(operator_pos, 1) + self.assert_results(operator_in_, 12, [1, 2, 12]) + self.assert_results(operator_in_, 12, [1, 2, 3]) + self.assert_results(operator_not_in_, 12, [1, 2, 3]) + self.assert_results(operator_not_in_, 12, [1, 2, 3]) def test_operator_list(self): self.assert_results(list_getitem, 1, paddle.to_tensor(2)) diff --git a/tests/test_call_object.py b/tests/test_call_object.py new file mode 100644 index 000000000..235e6197a --- /dev/null +++ b/tests/test_call_object.py @@ -0,0 +1,69 @@ +import unittest + +from test_case_base import TestCaseBase + +import paddle + +patched = lambda self, x: x * self.a + +patched2 = lambda self, x: x * self.a + 3 + + +class A: + def __init__(self, a): + self.a = a + + def __call__(self, x): + return self.add(x) + + def add(self, x): + return x + self.a + + multi = patched + + +class B: + def __init__(self, a): + self.a = A(a) + + def __call__(self, x, func): + return getattr(self.a, func)(x) + + def self_call(self, x, func): + return getattr(self.a, func)(self.a, x) + + +def foo_1(a, x): + return a(x) + + +def foo_2(a, x): + return a.multi(x) + + +def foo_3(b, x): + return b(x, "multi") + + +def foo_4(b, x): + return b(x, "add") + + +def foo_5(b, x): + return b.self_call(x, "multi") + + +class TestExecutor(TestCaseBase): + def test_simple(self): + c = B(13) + c.a.multi = patched2 + self.assert_results(foo_1, A(13), paddle.to_tensor(2)) + self.assert_results(foo_2, A(13), paddle.to_tensor(2)) + self.assert_results(foo_3, B(13), paddle.to_tensor(2)) + self.assert_results(foo_4, B(13), paddle.to_tensor(2)) + self.assert_results(foo_5, c, paddle.to_tensor(2)) + self.assert_results(foo_4, c, paddle.to_tensor(2)) + + +if __name__ == "__main__": + unittest.main()