Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
[Dispatcher] Support map (#386)
Browse files Browse the repository at this point in the history
Co-authored-by: SigureMo <[email protected]>
  • Loading branch information
ranchongzhi and SigureMo authored Sep 27, 2023
1 parent 58159d7 commit 5fc1ee6
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 25 deletions.
31 changes: 8 additions & 23 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,29 +1478,14 @@ def JUMP_IF_TRUE_OR_POP(self, instr: Instruction):

def UNPACK_SEQUENCE(self, instr: Instruction):
sequence = self.stack.pop()

'''
TODO: To unpack iterator
To unpack is easy, just like:
seq = tuple(sequence.get_py_value())
But what is the `source` when iterator returned a value ?
'''
if not isinstance(
sequence, (ListVariable, TupleVariable, TensorVariable)
):
raise FallbackError(f"Unpack {sequence} is not implemented.")

assert (
len(sequence) == instr.arg
), f"Want unpack {sequence} to {instr.arg}, but the len is {len(sequence)}."

for i in range(instr.arg - 1, -1, -1):
self.stack.push(
BuiltinVariable(
operator.getitem, self._graph, DanglingTracker()
)(sequence, i)
)
seq_iter = BuiltinVariable(iter, self._graph, DanglingTracker())(
sequence
)
unpacked = []
for _ in range(instr.arg):
unpacked.append(seq_iter.next())
for item in reversed(unpacked):
self.stack.push(item)

def UNPACK_EX(self, instr: Instruction):
getitem = BuiltinVariable(
Expand Down
14 changes: 14 additions & 0 deletions sot/opcode_translator/executor/variable_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
DictVariable,
EnumerateVariable,
ListVariable,
MapVariable,
NumpyVariable,
RangeVariable,
SliceVariable,
Expand Down Expand Up @@ -485,6 +486,19 @@ def dispatch_dict_fromkeys(seq: ListVariable | TupleVariable, default: VariableB
)


# map
Dispatcher.register(
map,
(
"CallableVariable",
"VariableBase",
),
lambda fn, var: MapVariable.from_iterator(
fn, var, graph=var.graph, tracker=DummyTracker([var])
),
)


# reversed
@Dispatcher.register_decorator(reversed)
def dispatch_reversed(var: ContainerVariable):
Expand Down
1 change: 1 addition & 0 deletions sot/opcode_translator/executor/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .iter import ( # noqa: F401
EnumerateVariable,
IterVariable,
MapVariable,
SequenceIterVariable,
UserDefinedIterVariable,
)
54 changes: 52 additions & 2 deletions sot/opcode_translator/executor/variables/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from typing import TYPE_CHECKING, Any

from ....utils import FallbackError
from ....utils import BreakGraphError, FallbackError
from ..pycode_generator import PyCodeGen
from ..tracker import ConstTracker, DummyTracker
from .base import VariableBase
from .basic import ConstantVariable
from .container import TupleVariable
from .container import ContainerVariable, TupleVariable

if TYPE_CHECKING:
from ..function_graph import FunctionGraph
Expand Down Expand Up @@ -133,7 +133,57 @@ def from_iterator(value, graph: FunctionGraph | None, tracker: Tracker):
return UserDefinedIterVariable(value, graph, tracker)


class MapVariable(SequenceIterVariable):
"""
MapVariable holds a SequenceIterVariable and return a Iterable Variable after map function
"""

def __init__(self, func, val_iterator, graph, tracker):
super().__init__(val_iterator, graph, tracker)
self.func = func

def next(self):
return self.func(self.hold.next())

def to_list(self) -> list:
retval = []
while True:
try:
retval.append(self.func(self.hold.next()))
except StopIteration:
break
return retval

def has_side_effect(self) -> bool:
return self.hold.has_side_effect()

def _reconstruct(self, codegen: PyCodeGen):
if self.has_side_effect():
super()._reconstruct(codegen)
else:
codegen.gen_load_global("map", push_null=True)
self.func.reconstruct(codegen)
self.hold.reconstruct(codegen)
codegen.gen_call_function(2)

@staticmethod
def from_iterator(
func, value, graph: FunctionGraph | None, tracker: Tracker
):
iter_variable = (
value.get_iter() if isinstance(value, ContainerVariable) else value
)

if isinstance(iter_variable, IterVariable):
return MapVariable(func, iter_variable, graph, tracker)
else:
return UserDefinedIterVariable(value, graph, tracker)


# what UserDefinedIterVariable holds doesn't matter, because use user defined iterator will trigger break graph
class UserDefinedIterVariable(IterVariable):
def __init__(self, obj, graph, tracker):
super().__init__(obj, graph, tracker)

def next(self):
raise BreakGraphError("Break graph when using user defined iterator")
110 changes: 110 additions & 0 deletions tests/test_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from __future__ import annotations

import unittest
from typing import Iterable

from test_case_base import TestCaseBase, strict_mode_guard

import sot
from sot.psdb import check_no_breakgraph


def double_num(num: float | int):
return num * 2


def double_num_with_breakgraph(num: float | int):
sot.psdb.breakgraph()
return num * 2


@check_no_breakgraph
def test_map_list(x: list):
return list(map(double_num, x))


@check_no_breakgraph
def test_map_list_comprehension(x: list):
return [i for i in map(double_num, x)] # noqa: C416


@check_no_breakgraph
def test_map_tuple(x: tuple):
return tuple(map(double_num, x))


@check_no_breakgraph
def test_map_tuple_comprehension(x: tuple):
return [i for i in map(double_num, x)] # noqa: C416


@check_no_breakgraph
def test_map_range(x: Iterable):
return list(map(double_num, x))


@check_no_breakgraph
def test_map_range_comprehension(x: Iterable):
return [i for i in map(double_num, x)] # noqa: C416


def add_dict_prefix(key: str):
return f"dict_{key}"


@check_no_breakgraph
def test_map_dict(x: dict):
return list(map(add_dict_prefix, x))


@check_no_breakgraph
def test_map_dict_comprehension(x: dict):
return [i for i in map(add_dict_prefix, x)] # noqa: C416


def test_map_list_with_breakgraph(x: list):
return list(map(double_num_with_breakgraph, x))


@check_no_breakgraph
def test_map_unpack(x: list):
a, b, c, d = map(double_num, x)
return a, b, c, d


@check_no_breakgraph
def test_map_for_loop(x: list):
res = 0
for i in map(double_num, x):
res += i
return res


class TestMap(TestCaseBase):
def test_map(self):
self.assert_results(test_map_list, [1, 2, 3, 4])
self.assert_results(test_map_tuple, (1, 2, 3, 4))
self.assert_results(test_map_range, range(5))
self.assert_results(test_map_dict, {"a": 1, "b": 2, "c": 3})

def test_map_comprehension(self):
self.assert_results(test_map_list_comprehension, [1, 2, 3, 4])
self.assert_results(test_map_tuple_comprehension, (1, 2, 3, 4))
self.assert_results(test_map_range_comprehension, range(5))
self.assert_results(
test_map_dict_comprehension, {"a": 1, "b": 2, "c": 3}
)

def test_map_with_breakgraph(self):
with strict_mode_guard(0):
self.assert_results(test_map_list_with_breakgraph, [1, 2, 3, 4])

def test_map_unpack(self):
self.assert_results(test_map_unpack, [1, 2, 3, 4])

def test_map_for_loop(self):
self.assert_results(test_map_for_loop, [7, 8, 9, 10])


if __name__ == "__main__":
unittest.main()

0 comments on commit 5fc1ee6

Please sign in to comment.