diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index dcb32217..324db567 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -1478,29 +1478,14 @@ def JUMP_IF_TRUE_OR_POP(self, instr: Instruction): def UNPACK_SEQUENCE(self, instr: Instruction): sequence = self.stack.pop() - - ''' - TODO: To unpack iterator - To unpack is easy, just like: - seq = tuple(sequence.get_py_value()) - - But what is the `source` when iterator returned a value ? - ''' - if not isinstance( - sequence, (ListVariable, TupleVariable, TensorVariable) - ): - raise FallbackError(f"Unpack {sequence} is not implemented.") - - assert ( - len(sequence) == instr.arg - ), f"Want unpack {sequence} to {instr.arg}, but the len is {len(sequence)}." - - for i in range(instr.arg - 1, -1, -1): - self.stack.push( - BuiltinVariable( - operator.getitem, self._graph, DanglingTracker() - )(sequence, i) - ) + seq_iter = BuiltinVariable(iter, self._graph, DanglingTracker())( + sequence + ) + unpacked = [] + for _ in range(instr.arg): + unpacked.append(seq_iter.next()) + for item in reversed(unpacked): + self.stack.push(item) def UNPACK_EX(self, instr: Instruction): getitem = BuiltinVariable( diff --git a/sot/opcode_translator/executor/variable_dispatch.py b/sot/opcode_translator/executor/variable_dispatch.py index 20013e3d..2d81163b 100644 --- a/sot/opcode_translator/executor/variable_dispatch.py +++ b/sot/opcode_translator/executor/variable_dispatch.py @@ -30,6 +30,7 @@ DictVariable, EnumerateVariable, ListVariable, + MapVariable, NumpyVariable, RangeVariable, SliceVariable, @@ -485,6 +486,19 @@ def dispatch_dict_fromkeys(seq: ListVariable | TupleVariable, default: VariableB ) +# map +Dispatcher.register( + map, + ( + "CallableVariable", + "VariableBase", + ), + lambda fn, var: MapVariable.from_iterator( + fn, var, graph=var.graph, tracker=DummyTracker([var]) + ), +) + + # reversed @Dispatcher.register_decorator(reversed) def dispatch_reversed(var: ContainerVariable): diff --git a/sot/opcode_translator/executor/variables/__init__.py b/sot/opcode_translator/executor/variables/__init__.py index 214625d2..991bb8a5 100644 --- a/sot/opcode_translator/executor/variables/__init__.py +++ b/sot/opcode_translator/executor/variables/__init__.py @@ -43,6 +43,7 @@ from .iter import ( # noqa: F401 EnumerateVariable, IterVariable, + MapVariable, SequenceIterVariable, UserDefinedIterVariable, ) diff --git a/sot/opcode_translator/executor/variables/iter.py b/sot/opcode_translator/executor/variables/iter.py index 5113c12d..354cb690 100644 --- a/sot/opcode_translator/executor/variables/iter.py +++ b/sot/opcode_translator/executor/variables/iter.py @@ -2,12 +2,12 @@ from typing import TYPE_CHECKING, Any -from ....utils import FallbackError +from ....utils import BreakGraphError, FallbackError from ..pycode_generator import PyCodeGen from ..tracker import ConstTracker, DummyTracker from .base import VariableBase from .basic import ConstantVariable -from .container import TupleVariable +from .container import ContainerVariable, TupleVariable if TYPE_CHECKING: from ..function_graph import FunctionGraph @@ -133,7 +133,57 @@ def from_iterator(value, graph: FunctionGraph | None, tracker: Tracker): return UserDefinedIterVariable(value, graph, tracker) +class MapVariable(SequenceIterVariable): + """ + MapVariable holds a SequenceIterVariable and return a Iterable Variable after map function + """ + + def __init__(self, func, val_iterator, graph, tracker): + super().__init__(val_iterator, graph, tracker) + self.func = func + + def next(self): + return self.func(self.hold.next()) + + def to_list(self) -> list: + retval = [] + while True: + try: + retval.append(self.func(self.hold.next())) + except StopIteration: + break + return retval + + def has_side_effect(self) -> bool: + return self.hold.has_side_effect() + + def _reconstruct(self, codegen: PyCodeGen): + if self.has_side_effect(): + super()._reconstruct(codegen) + else: + codegen.gen_load_global("map", push_null=True) + self.func.reconstruct(codegen) + self.hold.reconstruct(codegen) + codegen.gen_call_function(2) + + @staticmethod + def from_iterator( + func, value, graph: FunctionGraph | None, tracker: Tracker + ): + iter_variable = ( + value.get_iter() if isinstance(value, ContainerVariable) else value + ) + + if isinstance(iter_variable, IterVariable): + return MapVariable(func, iter_variable, graph, tracker) + else: + return UserDefinedIterVariable(value, graph, tracker) + + # what UserDefinedIterVariable holds doesn't matter, because use user defined iterator will trigger break graph class UserDefinedIterVariable(IterVariable): def __init__(self, obj, graph, tracker): super().__init__(obj, graph, tracker) + + def next(self): + raise BreakGraphError("Break graph when using user defined iterator") diff --git a/tests/test_map.py b/tests/test_map.py new file mode 100644 index 00000000..0b7560c7 --- /dev/null +++ b/tests/test_map.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import unittest +from typing import Iterable + +from test_case_base import TestCaseBase, strict_mode_guard + +import sot +from sot.psdb import check_no_breakgraph + + +def double_num(num: float | int): + return num * 2 + + +def double_num_with_breakgraph(num: float | int): + sot.psdb.breakgraph() + return num * 2 + + +@check_no_breakgraph +def test_map_list(x: list): + return list(map(double_num, x)) + + +@check_no_breakgraph +def test_map_list_comprehension(x: list): + return [i for i in map(double_num, x)] # noqa: C416 + + +@check_no_breakgraph +def test_map_tuple(x: tuple): + return tuple(map(double_num, x)) + + +@check_no_breakgraph +def test_map_tuple_comprehension(x: tuple): + return [i for i in map(double_num, x)] # noqa: C416 + + +@check_no_breakgraph +def test_map_range(x: Iterable): + return list(map(double_num, x)) + + +@check_no_breakgraph +def test_map_range_comprehension(x: Iterable): + return [i for i in map(double_num, x)] # noqa: C416 + + +def add_dict_prefix(key: str): + return f"dict_{key}" + + +@check_no_breakgraph +def test_map_dict(x: dict): + return list(map(add_dict_prefix, x)) + + +@check_no_breakgraph +def test_map_dict_comprehension(x: dict): + return [i for i in map(add_dict_prefix, x)] # noqa: C416 + + +def test_map_list_with_breakgraph(x: list): + return list(map(double_num_with_breakgraph, x)) + + +@check_no_breakgraph +def test_map_unpack(x: list): + a, b, c, d = map(double_num, x) + return a, b, c, d + + +@check_no_breakgraph +def test_map_for_loop(x: list): + res = 0 + for i in map(double_num, x): + res += i + return res + + +class TestMap(TestCaseBase): + def test_map(self): + self.assert_results(test_map_list, [1, 2, 3, 4]) + self.assert_results(test_map_tuple, (1, 2, 3, 4)) + self.assert_results(test_map_range, range(5)) + self.assert_results(test_map_dict, {"a": 1, "b": 2, "c": 3}) + + def test_map_comprehension(self): + self.assert_results(test_map_list_comprehension, [1, 2, 3, 4]) + self.assert_results(test_map_tuple_comprehension, (1, 2, 3, 4)) + self.assert_results(test_map_range_comprehension, range(5)) + self.assert_results( + test_map_dict_comprehension, {"a": 1, "b": 2, "c": 3} + ) + + def test_map_with_breakgraph(self): + with strict_mode_guard(0): + self.assert_results(test_map_list_with_breakgraph, [1, 2, 3, 4]) + + def test_map_unpack(self): + self.assert_results(test_map_unpack, [1, 2, 3, 4]) + + def test_map_for_loop(self): + self.assert_results(test_map_for_loop, [7, 8, 9, 10]) + + +if __name__ == "__main__": + unittest.main()