Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
[Dispatcher] Implement parameterless call support for list and dict (#…
Browse files Browse the repository at this point in the history
…326)

Co-authored-by: SigureMo <[email protected]>
  • Loading branch information
jjyaoao and SigureMo authored Aug 11, 2023
1 parent 5aefaf7 commit 3dc6a08
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 13 deletions.
1 change: 1 addition & 0 deletions sot/opcode_translator/executor/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class Dispatcher:
handlers: dict[
Callable[..., Any], list[tuple[Pattern, Callable[..., Any]]]
] = {}
graph: Any = None

@classmethod
def register(
Expand Down
8 changes: 7 additions & 1 deletion sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
operator_in,
operator_not_in,
)
from .dispatcher import Dispatcher
from .function_graph import FunctionGraph
from .guard import Guard
from .instr_flag import FORMAT_VALUE_FLAG as FV
Expand Down Expand Up @@ -268,7 +269,7 @@ def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction | None:
except Exception as e:
raise InnerError(OpcodeExecutorBase.error_message_summary(e)) from e
finally:
simulator._graph.pycode_gen = None
simulator.cleanup()


def tos_op_wrapper(fn: Callable):
Expand Down Expand Up @@ -1511,6 +1512,11 @@ def __init__(self, frame: types.FrameType, **kwargs):
self._name = "Executor"
self.call_stack[:] = []
super().__init__(frame.f_code, graph)
Dispatcher.graph = graph

def cleanup(self):
self._graph.pycode_gen = None
Dispatcher.graph = None

@event_register("OpcodeExecutor: _prepare_virtual_env", event_level=2)
def _prepare_virtual_env(self):
Expand Down
29 changes: 19 additions & 10 deletions sot/opcode_translator/executor/variable_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,15 @@ def inner(*args, **kwargs):
)

# dict
Dispatcher.register(
dict,
(),
lambda: VariableFactory.from_value(
{},
graph=Dispatcher.graph,
tracker=DummyTracker([]),
),
)
Dispatcher.register(
dict.get,
("DictVariable", "ConstantVariable", optional("VariableBase")),
Expand Down Expand Up @@ -196,6 +205,16 @@ def inner(*args, **kwargs):
)

# list
Dispatcher.register(
list,
(),
lambda: VariableFactory.from_value(
[],
graph=Dispatcher.graph,
tracker=DummyTracker([]),
),
)

Dispatcher.register(
list,
("ContainerVariable | EnumerateVariable",),
Expand Down Expand Up @@ -399,16 +418,6 @@ def dispatch_reversed(var: ContainerVariable):
("ContainerVariable",),
lambda var: var.bool(),
)
Dispatcher.register(
bool,
("ConstantVariable",),
lambda var: var.bool(),
)
Dispatcher.register(
operator.truth,
("ContainerVariable",),
lambda var: var.bool(),
)
Dispatcher.register(
operator.truth,
("ConstantVariable",),
Expand Down
13 changes: 12 additions & 1 deletion tests/test_04_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,15 @@ def list_tensor_min_api(x: paddle.Tensor):
return x.min()


class TestExecutor(TestCaseBase):
def list_no_arguments():
l1 = list() # noqa: C408
l1.append(1)
l2 = list() # noqa: C408
l2.append(2)
return l1[0] + l2[0]


class TestList(TestCaseBase):
def test_simple(self):
self.assert_results(list_getitem_int, 1, paddle.to_tensor(2))
self.assert_results(list_getitem_tensor, 1, paddle.to_tensor(2))
Expand Down Expand Up @@ -258,6 +266,9 @@ def test_simple(self):
)
self.assert_results(list_tensor_min_api, paddle.to_tensor([1, 2, 3]))

def test_list_noargs(self):
self.assert_results(list_no_arguments)


if __name__ == "__main__":
unittest.main()
13 changes: 12 additions & 1 deletion tests/test_05_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,15 @@ def dict_construct_from_comprehension():
return d


class TestExecutor(TestCaseBase):
def dict_no_arguments():
d1 = dict() # noqa: C408
d1.update({1: 2})
d2 = dict() # noqa: C408
d2.update({3: 4})
return d1[1] + d2[3]


class TestDict(TestCaseBase):
def test_build_map(self):
self.assert_results(build_map, 1, paddle.to_tensor(2))

Expand Down Expand Up @@ -199,6 +207,9 @@ def test_construct(self):
self.assert_results(dict_construct_from_tuple)
self.assert_results(dict_construct_from_comprehension)

def test_dict_noargs(self):
self.assert_results(dict_no_arguments)


if __name__ == "__main__":
unittest.main()

0 comments on commit 3dc6a08

Please sign in to comment.