Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Add RangeVariable, EnumerateVariable #222

Closed
wants to merge 16 commits into from
13 changes: 8 additions & 5 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@
DictIterVariable,
DictVariable,
DummyVariable,
EnumerateVariable,
IterVariable,
ListVariable,
MethodVariable,
RangeVariable,
SequenceIterVariable,
TensorIterVariable,
TensorVariable,
Expand Down Expand Up @@ -1137,7 +1139,7 @@ def GET_ITER(self, instr: Instruction):
if isinstance(source_obj, IterVariable):
return self.push(source_obj)

if isinstance(source_obj, (ListVariable, TupleVariable)):
if isinstance(source_obj, (ListVariable, TupleVariable, RangeVariable)):
self.push(
SequenceIterVariable(
source_obj, self._graph, GetIterTracker(source_obj)
Expand Down Expand Up @@ -1661,9 +1663,6 @@ def _break_graph_in_for_loop(
)

# 5.2 load loop body inputs
def update_locals(name, variable):
self._locals[name] = variable
return variable

for name in loop_inputs[:-1]:
self._graph.pycode_gen.gen_load_fast(name)
Expand Down Expand Up @@ -1722,6 +1721,7 @@ def _inline_call_for_loop(
self._graph,
DanglingTracker(),
)

input_vars = [self._locals[name] for name in inputs[:-1]] + [iterator]
ret = fn(*input_vars)
for name, val in zip(inputs[:-1], ret[:-1]):
Expand Down Expand Up @@ -1762,11 +1762,14 @@ def FOR_ITER(self, instr):
self._graph.add_global_guarded_variable(iterator)

# TODO: need to support TensorIterVariable.next

try:
if not isinstance(
iterator, (SequenceIterVariable, DictIterVariable)
iterator,
(SequenceIterVariable, DictIterVariable, EnumerateVariable),
):
raise BreakGraphError()

backup_iter_idx = iterator.idx
self._inline_call_for_loop(iterator, instr)
self._lasti = self.indexof(instr.jump_to)
Expand Down
6 changes: 5 additions & 1 deletion sot/opcode_translator/executor/opcode_inline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .variables import (
CellVariable,
DictIterVariable,
EnumerateVariable,
IterVariable,
SequenceIterVariable,
)
Expand Down Expand Up @@ -207,7 +208,10 @@ def FOR_ITER(self, instr):
self._graph.add_global_guarded_variable(iterator)

# simply get the next element
if isinstance(iterator, (SequenceIterVariable, DictIterVariable)):
if isinstance(
iterator,
(SequenceIterVariable, DictIterVariable, EnumerateVariable),
):
try:
self.push(iterator.next())
except StopIteration:
Expand Down
132 changes: 123 additions & 9 deletions sot/opcode_translator/executor/variable_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
)
from .dispatcher import Dispatcher
from .tracker import DummyTracker
from .variables import VariableBase, VariableFactory
from .variables import (
ConstantVariable,
EnumerateVariable,
VariableBase,
VariableFactory,
)

if TYPE_CHECKING:
from .variables import (
ConstantVariable,
DataVariable,
NumpyVariable,
TensorVariable,
)
from .variables import DataVariable, NumpyVariable, TensorVariable


# just a function for operator.in
Expand Down Expand Up @@ -67,6 +67,31 @@ def operator_BAD(left, right):
)

# dict
# dict(DictVariable) -> a new DictVariable holding a shallow copy of the
# wrapped items (mirrors builtin dict(d)).
Dispatcher.register(
    dict,
    ("DictVariable",),
    {},
    lambda var: VariableFactory.from_value(
        dict(var.get_wrapped_items()),
        graph=var.graph,
        tracker=DummyTracker([var]),
    ),
)
# dict(ListVariable | TupleVariable) -> DictVariable; each wrapped item is
# expected to be a (key_var, value_var) pair, with the key unwrapped to its
# concrete value.
Dispatcher.register(
    dict,
    ("ListVariable | TupleVariable",),
    {},
    lambda var: VariableFactory.from_value(
        {
            key_var.get_value(): value_var
            for key_var, value_var in var.get_wrapped_items()
        },
        graph=var.graph,
        tracker=DummyTracker([var]),
    ),
)


Dispatcher.register(
dict.keys,
("DictVariable",),
Expand All @@ -93,6 +118,29 @@ def operator_BAD(left, right):
lambda var, other: var.update(other),
)
# list
# list(ContainerVariable | EnumerateVariable) -> ListVariable with the
# wrapped items materialized eagerly.
Dispatcher.register(
    list,
    ("ContainerVariable | EnumerateVariable",),
    {},
    lambda var: VariableFactory.from_value(
        list(var.get_wrapped_items()),
        graph=var.graph,
        tracker=DummyTracker([var]),
    ),
)

# tuple
# tuple(ContainerVariable | EnumerateVariable) -> TupleVariable, built the
# same way as the list conversion above.
Dispatcher.register(
    tuple,
    ("ContainerVariable | EnumerateVariable",),
    {},
    lambda var: VariableFactory.from_value(
        tuple(var.get_wrapped_items()),
        graph=var.graph,
        tracker=DummyTracker([var]),
    ),
)

Dispatcher.register(
list.extend,
("ListVariable", "ListVariable | TupleVariable"),
Expand Down Expand Up @@ -195,10 +243,68 @@ def operator_BAD(left, right):
# len
Dispatcher.register(
len,
("ContainerVariable",),
("ContainerVariable | PaddleLayerVariable",),
{},
lambda var: var.len(),
)


# range
# range(stop) — single-argument form.
Dispatcher.register(
    range,
    ("ConstantVariable",),
    {},
    lambda stop: VariableFactory.from_value(
        range(stop.get_value()), graph=stop.graph, tracker=DummyTracker([stop])
    ),
)

# range(start, stop) — two-argument form.
Dispatcher.register(
    range,
    ("ConstantVariable", "ConstantVariable"),
    {},
    lambda start, stop: VariableFactory.from_value(
        range(start.get_value(), stop.get_value()),
        graph=stop.graph,
        tracker=DummyTracker([start, stop]),
    ),
)
# range(start, stop, step) — three-argument form.
Dispatcher.register(
    range,
    ("ConstantVariable", "ConstantVariable", "ConstantVariable"),
    {},
    lambda start, stop, step: VariableFactory.from_value(
        range(start.get_value(), stop.get_value(), step.get_value()),
        graph=stop.graph,
        tracker=DummyTracker([start, stop, step]),
    ),
)
# TODO(zmh): Modify
# enumerate
# enumerate(iterable) -> EnumerateVariable wrapping the source iterable.
Dispatcher.register(
    enumerate,
    (
        "ListVariable | TupleVariable | RangeVariable | DictVariable | TensorVariable | PaddleLayerVariable",
    ),
    {},
    lambda var: EnumerateVariable.from_iterator(
        var, graph=var.graph, tracker=DummyTracker([var])
    ),
)

# isinstance
# Both operands are unwrapped to their concrete values before the real
# isinstance call; the result is re-wrapped as a literal constant.
# NOTE(review): presumably `right` wraps an actual type object — confirm
# against call sites.
Dispatcher.register(
    isinstance,
    ("VariableBase", "VariableBase"),
    {},
    lambda left, right: ConstantVariable.wrap_literal(
        isinstance(left.get_value(), right.get_value()), left.graph
    ),
)

# bool
Dispatcher.register(
bool,
Expand All @@ -214,7 +320,7 @@ def operator_BAD(left, right):
)
Dispatcher.register(
operator.truth,
("ContainerVariable",),
("ContainerVariable | TensorVariable",),
{},
lambda var: var.bool(),
)
Expand All @@ -225,6 +331,14 @@ def operator_BAD(left, right):
lambda var: var.bool(),
)

# str
# str(ConstantVariable) delegates to the variable's `str` helper, which
# wraps str(value) as a new variable.
Dispatcher.register(
    str,
    ("ConstantVariable",),
    {},
    lambda var: var.str(),
)

# getitem
# TODO: Should pass its Variable into the getitem and perform operations such as getting value in the getitem. like this:https://github.com/PaddlePaddle/PaddleSOT/pull/198#discussion_r1241110949
Dispatcher.register(
Expand Down
2 changes: 2 additions & 0 deletions sot/opcode_translator/executor/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@
ContainerVariable,
DictVariable,
ListVariable,
RangeVariable,
TupleVariable,
)
from .iter import ( # noqa: F401
DictIterVariable,
EnumerateVariable,
IterVariable,
SequenceIterVariable,
TensorIterVariable,
Expand Down
17 changes: 17 additions & 0 deletions sot/opcode_translator/executor/variables/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ def bool_not(self):
not bool(self.get_value()), self.graph, DummyTracker([self])
)

def str(self):
return VariableFactory.from_value(
str(self.value), self.graph, DummyTracker([self])
)

@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):
if isinstance(value, ConstTypes):
Expand Down Expand Up @@ -214,6 +219,18 @@ def __init__(
self.var_name = TensorVariable.var_name_generator.next()
self.graph = graph

def __len__(self):
if self.meta.shape[0] == -1:
raise BreakGraphError(
"length of tensor variable with first dimension == -1"
)
return self.meta.shape[0]

def bool(self):
return VariableFactory.from_value(
bool(self.value), self.graph, DummyTracker([self])
)

def get_value(self):
if self.value is None:
raise InnerError("Can not get value from a inner tensor variable.")
Expand Down
6 changes: 6 additions & 0 deletions sot/opcode_translator/executor/variables/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,12 @@ def __init__(
super().__init__(layer, graph, tracker)
self.name = self.layer_name_generator.next()

def __len__(self):
return len(self.value)

def len(self):
return ConstantVariable.wrap_literal(len(self), self.graph)

def get_symbol(self) -> Symbol:
return Symbol(self.name)

Expand Down
70 changes: 70 additions & 0 deletions sot/opcode_translator/executor/variables/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,76 @@ def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):
return None


class RangeVariable(ContainerVariable):
    """Variable wrapping a builtin ``range`` object.

    Holds the concrete ``range`` value and exposes the container protocol
    (``getitem`` / ``__len__`` / item enumeration) used by the executor,
    plus reconstruction bytecode for guard failures.
    """

    def __init__(
        self,
        val_range: range,
        graph: FunctionGraph,
        tracker: Tracker,
    ):
        super().__init__(tracker)
        self.graph = graph
        self.value = val_range

    def get_type(self):
        return range

    def get_value(self):
        return self.value

    def getitem(self, key):
        """Index the range with a concrete key, returning a wrapped constant.

        Raises:
            InnerError: if ``key`` is still a wrapped variable — callers must
                unwrap keys before indexing.
        """
        if isinstance(key, VariableBase):
            raise InnerError(
                # fixed typo: "recieved" -> "received"
                f"[{self.__class__.__name__}]: received {key} as key."
            )

        retval = self.value[key]
        return ConstantVariable.wrap_literal(retval, self.graph)

    def get_items(self):
        size = len(self)
        return [self[idx] for idx in range(size)]

    def get_wrapped_items(self):
        return self.get_items()

    def __len__(self):
        return len(self.value)

    def _reconstruct(self, codegen: PyCodeGen):
        # Emit bytecode that rebuilds this range on the stack.
        codegen.gen_load_global("range")
        # A range object always stores concrete start/stop/step (defaults
        # 0 and 1), so the 3-argument form is always valid.
        codegen.gen_load_const(self.value.start)
        codegen.gen_load_const(self.value.stop)
        codegen.gen_load_const(self.value.step)
        codegen.gen_call_function(3)

    @VariableFactory.register_from_value()
    def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):
        if isinstance(value, range):
            return RangeVariable(value, graph, tracker)
        return None

    @property
    def debug_name(self) -> str:
        # Render as "start:stop:step". The None guards are defensive only:
        # a real range always has int start/stop/step.
        return ":".join(
            [
                str(self.value.start) if self.value.start is not None else "",
                str(self.value.stop) if self.value.stop is not None else "",
                str(self.value.step) if self.value.step is not None else "",
            ]
        )

    @debug_name.setter
    def debug_name(self, name):
        # The name is derived from the range itself; assignment is a no-op.
        pass

    @property
    def main_info(self) -> dict[str, Any]:
        return {"value": self.value}


class DictVariable(ContainerVariable):
def __init__(
self,
Expand Down
Loading
Loading