Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
[ BugFix ] Fix PaddleDetection bugs (#235)
Browse files Browse the repository at this point in the history
  • Loading branch information
2742195759 authored Jul 4, 2023
1 parent 019b846 commit 90da495
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 20 deletions.
32 changes: 21 additions & 11 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,14 @@
GlobalTracker,
LocalTracker,
)
from .variable_dispatch import (
operator_BAD,
operator_exception_match,
operator_in,
operator_not_in,
)
from .variables import (
BuiltinVariable,
CallableVariable,
CellVariable,
ConstantVariable,
ContainerVariable,
Expand Down Expand Up @@ -83,6 +88,10 @@
"!=": operator.ne,
"is not": operator.is_not,
"is": operator.is_,
"in": operator_in,
"not in": operator_not_in,
"exception match": operator_exception_match,
"BAD": operator_BAD,
}


Expand Down Expand Up @@ -988,8 +997,6 @@ def CALL_FUNCTION(self, instr: Instruction):
args = self.pop_n(n_args)
kwargs = {}
fn = self.pop()
if not isinstance(fn, CallableVariable):
raise NotImplementException(f"CALL_FUNCTION: {fn} is not callable")
ret = fn(*args, **kwargs)
self.push(ret)

Expand All @@ -1012,10 +1019,6 @@ def CALL_FUNCTION_KW(self, instr: Instruction):
kwargs = dict(zip(kwargs_keys, kwargs_values))

fn = self.pop()
if not isinstance(fn, CallableVariable):
raise NotImplementException(
f"CALL_FUNCTION_KW: {fn} is not callable."
)
ret = fn(*args, **kwargs)
self.push(ret)

Expand All @@ -1033,10 +1036,6 @@ def CALL_FUNCTION_EX(self, instr: Instruction):
args = args_variable.get_wrapped_items()

fn = self.pop()
if not isinstance(fn, CallableVariable):
raise NotImplementException(
f"CALL_FUNCTION_EX: {fn} is not callable."
)
ret = fn(*args, **kwargs)
self.push(ret)

Expand Down Expand Up @@ -1163,6 +1162,17 @@ def JUMP_FORWARD(self, instr):
def JUMP_ABSOLUTE(self, instr: Instruction):
self._lasti = self.indexof(instr.jump_to)

def CONTAINS_OP(self, instr: Instruction):
# It will only be 0 or 1
assert instr.argval == 0 or instr.argval == 1
right, left = self.pop(), self.pop()
op = "in" if instr.argval == 0 else "not in"
self.push(
BuiltinVariable(
SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker()
)(left, right)
)

@jump_break_graph_decorator
def JUMP_IF_FALSE_OR_POP(self, instr: Instruction):
pred_obj = self.peek()
Expand Down
23 changes: 16 additions & 7 deletions sot/opcode_translator/executor/opcode_inline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,18 @@ def __init__(self, fn_variable, *args, **kwargs):
self._prepare_closure()
# TODO: consider generator.

def _handle_comps(self):
is_comp = any(
x in self._fn_value.__name__
for x in ['<listcomp>', '<dictcomp>', '<genexpr>']
)
if not is_comp:
return
pattern = r'implicit\d+'
for name in list(self._locals.keys()):
if re.match(pattern, name):
self._locals[name.replace('implicit', '.')] = self._locals[name]

def _prepare_locals(self, *args, **kwargs):
from .variables import VariableBase, VariableFactory

Expand All @@ -116,13 +128,7 @@ def _prepare_locals(self, *args, **kwargs):
value = VariableFactory.from_value(value, self._graph, tracker)
self._locals[name] = value

if '<listcomp>' in self._fn_value.__name__:
pattern = r'implicit\d+'
for name in list(self._locals.keys()):
if re.match(pattern, name):
self._locals[name.replace('implicit', '.')] = self._locals[
name
]
self._handle_comps()

log(
5, f"[INLINE CALL] {self._code.co_name} with locals: ", self._locals
Expand Down Expand Up @@ -182,6 +188,9 @@ def inline_call(self):
return self.return_value

def RETURN_VALUE(self, instr):
assert (
len(self._stack) == 1
), f"Stack must have one element, but get {len(self._stack)} elements."
self.return_value = self.pop()
return Stop()

Expand Down
44 changes: 43 additions & 1 deletion sot/opcode_translator/executor/variable_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,55 @@
)


# just a function for operator.in
def operator_in(left, right):
return left in right


def operator_not_in(left, right):
return left not in right


def operator_exception_match(left, right):
pass


def operator_BAD(left, right):
pass


# dict
Dispatcher.register(
operator_in,
("VariableBase", "VariableBase"),
{},
lambda left, right: VariableFactory.from_value(
left.get_value() in right.get_value(),
left.graph,
tracker=DummyTracker([left, right]),
),
)

# dict
Dispatcher.register(
operator_not_in,
("VariableBase", "VariableBase"),
{},
lambda left, right: VariableFactory.from_value(
left.get_value() not in right.get_value(),
left.graph,
tracker=DummyTracker([left, right]),
),
)

# dict
Dispatcher.register(
dict.keys,
("DictVariable",),
{},
lambda var: var.keys(),
)

Dispatcher.register(
dict.values,
("DictVariable",),
Expand Down Expand Up @@ -88,7 +130,7 @@
getattr,
("VariableBase", "str", "VariableBase"),
{},
lambda var, name: var.getattr(name),
lambda var, name, default: var.getattr(name, default),
)
Dispatcher.register(
getattr,
Expand Down
3 changes: 3 additions & 0 deletions sot/opcode_translator/executor/variables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,9 @@ def __call__(self, *args, **kwargs):
self.graph,
GetAttrTracker(self, '__class__'),
)
# if __call__ is a method, we should add self to arguments.
if inspect.ismethod(self.get_value().__call__):
args = (self,) + args
unbound_method = get_unbound_method(self.get_value(), '__call__')
if hasattr(unbound_method, "__code__"):
fn_var = UserDefinedFunctionVariable(
Expand Down
1 change: 0 additions & 1 deletion sot/symbolic/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def _set(v, s):
state[s.name] = v

if len(to_sequence(outs)) != len(to_sequence(stmt.outputs)):
breakpoint()
raise InnerError("Number output mismatch, some error happen.")

map_if(
Expand Down
12 changes: 12 additions & 0 deletions tests/test_14_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,14 @@ def operator_is_(x: paddle.Tensor, y: paddle.Tensor):
return (operator.is_(x, x), operator.is_(x, y))


def operator_in_(x: int, y: list):
return x in y


def operator_not_in_(x: int, y: list):
return x not in y


def operator_is_not(x: paddle.Tensor, y: paddle.Tensor):
return (operator.is_not(x, x), operator.is_not(x, y))

Expand Down Expand Up @@ -317,6 +325,10 @@ def test_operator_simple(self):
operator_is_not, paddle.to_tensor(2), paddle.to_tensor(3)
)
self.assert_results(operator_pos, 1)
self.assert_results(operator_in_, 12, [1, 2, 12])
self.assert_results(operator_in_, 12, [1, 2, 3])
self.assert_results(operator_not_in_, 12, [1, 2, 3])
self.assert_results(operator_not_in_, 12, [1, 2, 3])

def test_operator_list(self):
self.assert_results(list_getitem, 1, paddle.to_tensor(2))
Expand Down
69 changes: 69 additions & 0 deletions tests/test_call_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import unittest

from test_case_base import TestCaseBase

import paddle

patched = lambda self, x: x * self.a

patched2 = lambda self, x: x * self.a + 3


class A:
def __init__(self, a):
self.a = a

def __call__(self, x):
return self.add(x)

def add(self, x):
return x + self.a

multi = patched


class B:
def __init__(self, a):
self.a = A(a)

def __call__(self, x, func):
return getattr(self.a, func)(x)

def self_call(self, x, func):
return getattr(self.a, func)(self.a, x)


def foo_1(a, x):
return a(x)


def foo_2(a, x):
return a.multi(x)


def foo_3(b, x):
return b(x, "multi")


def foo_4(b, x):
return b(x, "add")


def foo_5(b, x):
return b.self_call(x, "multi")


class TestExecutor(TestCaseBase):
def test_simple(self):
c = B(13)
c.a.multi = patched2
self.assert_results(foo_1, A(13), paddle.to_tensor(2))
self.assert_results(foo_2, A(13), paddle.to_tensor(2))
self.assert_results(foo_3, B(13), paddle.to_tensor(2))
self.assert_results(foo_4, B(13), paddle.to_tensor(2))
self.assert_results(foo_5, c, paddle.to_tensor(2))
self.assert_results(foo_4, c, paddle.to_tensor(2))


if __name__ == "__main__":
unittest.main()

0 comments on commit 90da495

Please sign in to comment.