diff --git a/sot/opcode_translator/executor/function_graph.py b/sot/opcode_translator/executor/function_graph.py index b21eed9d1..5db4e0bca 100644 --- a/sot/opcode_translator/executor/function_graph.py +++ b/sot/opcode_translator/executor/function_graph.py @@ -24,6 +24,7 @@ from .variables import ( ContainerVariable, DictVariable, + ListVariable, PaddleLayerVariable, TensorVariable, VariableBase, @@ -288,12 +289,22 @@ def _find_tensor_outputs( self, outputs: list[VariableBase] ) -> list[TensorVariable]: output_tensors: list[TensorVariable] = [] + # Find Tensor Variables from outputs. for output in outputs: if isinstance(output.tracker, DummyTracker): if isinstance(output, TensorVariable): output_tensors.append(output) else: + # Guard output that can not be traced. self.add_global_guarded_variable(output) + # Find Tensor Variables from side effects Variables. + for side_effect_var in self.side_effects.variables: + if side_effect_var.proxy.has_changed: + for var in side_effect_var.flatten_items(): + if isinstance(var.tracker, DummyTracker) and isinstance( + var, TensorVariable + ): + output_tensors.append(var) return output_tensors def restore_side_effects(self, variables: list[VariableBase]): @@ -327,3 +338,20 @@ def restore_side_effects(self, variables: list[VariableBase]): self.pycode_gen.gen_pop_top() self.pycode_gen.gen_call_method(1) # call update self.pycode_gen.gen_pop_top() + elif isinstance(var, ListVariable): + # old_list[:] = new_list + + # Reference to the original list. + # load new_list to stack. + var._reconstruct(self.pycode_gen) + # load old_list[:] to stack. + var.reconstruct(self.pycode_gen) + self.pycode_gen.gen_load_const(None) + self.pycode_gen.gen_load_const(None) + self.pycode_gen.gen_build_slice(2) + + # Generate side effects of other variables. + self.restore_side_effects(variables[1:]) + + # Call STROE_SUBSCR to apply side effects. + self.pycode_gen.gen_store_subscr() diff --git a/sot/opcode_translator/executor/mutable_data.py b/sot/opcode_translator/executor/mutable_data.py index 62b82305b..7ec385b31 100644 --- a/sot/opcode_translator/executor/mutable_data.py +++ b/sot/opcode_translator/executor/mutable_data.py @@ -9,41 +9,94 @@ R = TypeVar("R") DataGetter: TypeAlias = Callable[[Any, Any], Any] + MutableDataT = TypeVar("MutableDataT", bound="MutableData") class Mutation: - ... + ABBR: str + + +class MutationSet(Mutation): + """ + Setting a value. + This mutation is used for MutableDictLikeData and MutableListLikeData. + """ + + ABBR = "S" + + def __init__(self, key, value): + self.key = key + self.value = value + + def __repr__(self): + return f"MutationSet({self.key}, {self.value})" + + +class MutationDel(Mutation): + """ + Deleting a value. + This mutation is used for MutableDictLikeData and MutableListLikeData. + """ + + ABBR = "D" + + def __init__(self, key): + self.key = key + + def __repr__(self): + return f"MutationDel({self.key})" class MutationNew(Mutation): - def __init__(self, name, value): - self.name = name + """ + Adding a new value. + This mutation is only used for MutableDictLikeData. + """ + + ABBR = "N" + + def __init__(self, key, value): + self.key = key self.value = value def __repr__(self): - return f"MutationNew({self.name}, {self.value})" + return f"MutationNew({self.key}, {self.value})" -class MutationSet(Mutation): - def __init__(self, name, value): - self.name = name +class MutationInsert(Mutation): + """ + Inserting a value. + This mutation is only used for MutableListLikeData. + """ + + ABBR = "I" + + def __init__(self, index, value): + self.index = index self.value = value def __repr__(self): - return f"MutationSet({self.name}, {self.value})" + return f"MutationInsert({self.index}, {self.value})" -class MutationDel(Mutation): - def __init__(self, name): - self.name = name +class MutationPermutate(Mutation): + """ + Permutating all the values. + This mutation is only used for MutableListLikeData. + """ + + ABBR = "P" + + def __init__(self, permutation): + self.permutation = permutation def __repr__(self): - return f"MutationDel({self.name})" + return f"MutationPermutate({self.permutation})" def record_mutation( - mutation_fn: Callable[Concatenate[MutableDictLikeData, P], Mutation] -) -> Callable[Concatenate[MutableDictLikeData, P], None]: + mutation_fn: Callable[Concatenate[MutableDataT, P], Mutation] +) -> Callable[Concatenate[MutableDataT, P], None]: def wrapper(self, *args: P.args, **kwargs: P.kwargs): mutation = mutation_fn(self, *args, **kwargs) self.records.append(mutation) @@ -56,6 +109,8 @@ class MutableData: An intermediate data structure between data and variable, it records all the mutations. """ + read_cache: list[Any] | dict[str, Any] + class Empty: def __repr__(self): return "Empty()" @@ -86,8 +141,27 @@ def get(self, key): def set(self, key, value): raise NotImplementedError() + def apply( + self, mutation: Mutation, write_cache: dict[str, Any] | list[Any] + ): + raise NotImplementedError() + + def reproduce(self, version: int | None = None): + if version is None: + version = self.version + write_cache = self.read_cache.copy() + for mutation in self.records[:version]: + self.apply(mutation, write_cache) + return write_cache + + def __repr__(self) -> str: + records_abbrs = "".join([mutation.ABBR for mutation in self.records]) + return f"{self.__class__.__name__}({records_abbrs})" + class MutableDictLikeData(MutableData): + read_cache: dict[str, Any] + def __init__(self, data: Any, getter: DataGetter): super().__init__(data, getter) self.read_cache = {} @@ -95,7 +169,7 @@ def __init__(self, data: Any, getter: DataGetter): def clear(self): self.read_cache.clear() - def get(self, key): + def get(self, key: Any): # TODO(SigureMo): Optimize performance of this. write_cache = self.reproduce(self.version) if key not in write_cache: @@ -106,13 +180,13 @@ def get_all(self): original_keys = list(self.original_data.keys()) for mutation in self.records: if isinstance(mutation, MutationNew): - original_keys.append(mutation.name) + original_keys.append(mutation.key) elif isinstance(mutation, MutationDel): - original_keys.remove(mutation.name) + original_keys.remove(mutation.key) return {key: self.get(key) for key in original_keys} @record_mutation - def set(self, key, value) -> Mutation: + def set(self, key: Any, value: Any) -> Mutation: is_new = False if self.is_empty(self.get(key)): is_new = True @@ -126,11 +200,11 @@ def delete(self, key): def apply(self, mutation: Mutation, write_cache: dict[str, Any]): if isinstance(mutation, MutationNew): - write_cache[mutation.name] = mutation.value + write_cache[mutation.key] = mutation.value elif isinstance(mutation, MutationSet): - write_cache[mutation.name] = mutation.value + write_cache[mutation.key] = mutation.value elif isinstance(mutation, MutationDel): - write_cache[mutation.name] = MutableData.Empty() + write_cache[mutation.key] = MutableData.Empty() else: raise ValueError(f"Unknown mutation type {mutation}") @@ -141,3 +215,62 @@ def reproduce(self, version: int | None = None): for mutation in self.records[:version]: self.apply(mutation, write_cache) return write_cache + + +class MutableListLikeData(MutableData): + read_cache: list[Any] + + def __init__(self, data: Any, getter: DataGetter): + super().__init__(data, getter) + self.read_cache = [ + self.getter(self.original_data, idx) for idx in range(len(data)) + ] + + def clear(self): + self.read_cache[:] = [] + + @property + def length(self): + return len(self.reproduce()) + + def get(self, key): + write_cache = self.reproduce(self.version) + return write_cache[key] + + def get_all(self): + return self.reproduce(self.version) + + @record_mutation + def set(self, key: int, value: Any): + return MutationSet(self._regularize_index(key), value) + + @record_mutation + def delete(self, key: int): + return MutationDel(self._regularize_index(key)) + + @record_mutation + def insert(self, index: int, value: Any): + return MutationInsert(self._regularize_index(index), value) + + @record_mutation + def permutate(self, permutation: list[int]): + return MutationPermutate(permutation) + + def _regularize_index(self, index: int): + if index < 0: + index += self.length + return index + + def apply(self, mutation: Mutation, write_cache: list[Any]): + if isinstance(mutation, MutationSet): + write_cache[mutation.key] = mutation.value + elif isinstance(mutation, MutationDel): + write_cache[:] = ( + write_cache[: mutation.key] + write_cache[mutation.key + 1 :] + ) + elif isinstance(mutation, MutationInsert): + write_cache.insert(mutation.index, mutation.value) + elif isinstance(mutation, MutationPermutate): + write_cache[:] = [write_cache[i] for i in mutation.permutation] + else: + raise ValueError(f"Unknown mutation type {mutation}") diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index e5fbfb90c..2f3eb5b7c 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -1313,6 +1313,13 @@ def DICT_MERGE(self, instr: Instruction): self._stack[-instr.arg], dict_value ) + def LIST_APPEND(self, instr: Instruction): + list_value = self.pop() + assert instr.argval > 0 + BuiltinVariable(list.append, self._graph, tracker=DanglingTracker())( + self._stack[-instr.arg], list_value + ) + def LIST_EXTEND(self, instr: Instruction): list_value = self.pop() assert instr.argval > 0 diff --git a/sot/opcode_translator/executor/pycode_generator.py b/sot/opcode_translator/executor/pycode_generator.py index 06a4bc65b..3bbd99a33 100644 --- a/sot/opcode_translator/executor/pycode_generator.py +++ b/sot/opcode_translator/executor/pycode_generator.py @@ -496,6 +496,9 @@ def gen_store_fast(self, name): idx = self._code_options["co_varnames"].index(name) self._add_instr("STORE_FAST", arg=idx, argval=name) + def gen_store_subscr(self): + self._add_instr("STORE_SUBSCR") + def gen_subscribe(self): self._add_instr("BINARY_SUBSCR") @@ -508,6 +511,9 @@ def gen_build_list(self, count): def gen_build_map(self, count): self._add_instr("BUILD_MAP", arg=count, argval=count) + def gen_build_slice(self, argc): + self._add_instr("BUILD_SLICE", arg=argc, argval=argc) + def gen_unpack_sequence(self, count): self._add_instr("UNPACK_SEQUENCE", arg=count, argval=count) diff --git a/sot/opcode_translator/executor/side_effects.py b/sot/opcode_translator/executor/side_effects.py index 9e05106af..9494dfe7e 100644 --- a/sot/opcode_translator/executor/side_effects.py +++ b/sot/opcode_translator/executor/side_effects.py @@ -23,7 +23,8 @@ def __init__(self): self.variables: list[VariableBase] = [] def record_variable(self, variable: VariableBase): - self.variables.append(variable) + if variable not in self.variables: + self.variables.append(variable) def get_proxy( self, diff --git a/sot/opcode_translator/executor/tracker.py b/sot/opcode_translator/executor/tracker.py index d3a2a4178..fb38405f9 100644 --- a/sot/opcode_translator/executor/tracker.py +++ b/sot/opcode_translator/executor/tracker.py @@ -213,7 +213,7 @@ def gen_instructions(self, codegen: PyCodeGen): codegen.gen_load_const(self.value) def trace_value_from_frame(self): - return StringifyExpression(f"{self.value}", {}) + return StringifyExpression(f"{self.value!r}", {}) def __repr__(self) -> str: return f"ConstTracker(value={self.value})" diff --git a/sot/opcode_translator/executor/variable_dispatch.py b/sot/opcode_translator/executor/variable_dispatch.py index 6c21031ae..9d454f996 100644 --- a/sot/opcode_translator/executor/variable_dispatch.py +++ b/sot/opcode_translator/executor/variable_dispatch.py @@ -57,6 +57,54 @@ {}, lambda var, other: var.extend(other), ) +Dispatcher.register( + list.append, + ("ListVariable", "VariableBase"), + {}, + lambda var, other: var.append(other), +) +Dispatcher.register( + list.insert, + ("ListVariable", "ConstantVariable", "VariableBase"), + {}, + lambda var, index, obj: var.insert(index.get_value(), obj), +) +Dispatcher.register( + list.remove, + ("ListVariable", "VariableBase"), + {}, + lambda var, other: var.remove(other), +) +Dispatcher.register( + list.pop, + ("ListVariable", "ConstantVariable"), + {}, + lambda var, other: var.pop(other), +) +Dispatcher.register( + list.pop, + ("ListVariable",), + {}, + lambda var: var.pop(), +) +Dispatcher.register( + list.clear, + ("ListVariable",), + {}, + lambda var: var.clear(), +) +Dispatcher.register( + list.sort, + ("ListVariable",), + {}, + lambda var: var.sort(), +) +Dispatcher.register( + list.reverse, + ("ListVariable",), + {}, + lambda var: var.reverse(), +) Dispatcher.register( operator.add, ("ListVariable", "ListVariable"), @@ -336,7 +384,8 @@ def is_not_func(var: VariableBase, other: VariableBase): # Tensor for unary_fn in UNARY_OPS: # Tensor doesn't support unary +, skip it - if unary_fn in {operator.pos}: + # TODO(SigureMo): deal len and bool + if unary_fn in {operator.pos, len, bool, operator.truth}: continue for magic_method in magic_method_builtin_dispatch(unary_fn): Dispatcher.register( @@ -460,7 +509,7 @@ def data_variable_unary_dispatcher(var: DataVariable, fn): Dispatcher.register( unary_fn, - ("DataVariable"), + ("DataVariable",), {}, partial(data_variable_unary_dispatcher, fn=unary_fn), ) diff --git a/sot/opcode_translator/executor/variables/base.py b/sot/opcode_translator/executor/variables/base.py index a5de240a5..203e68a76 100644 --- a/sot/opcode_translator/executor/variables/base.py +++ b/sot/opcode_translator/executor/variables/base.py @@ -440,12 +440,32 @@ def getattr(self, name: str, default=None): f"{self.__class__.__name__} {self} has no attribute {name}" ) attr = getattr(self.value, name) - if inspect.ismethod(attr): + if inspect.ismethod(attr) or ( + hasattr(attr, "__self__") + and inspect.ismethoddescriptor( + getattr(attr.__self__.__class__, name, None) + ) + ): from .callable import MethodVariable + fn = None + if inspect.ismethoddescriptor( + getattr(attr.__self__.__class__, name, None) + ): + class_var = VariableFactory.from_value( + self.get_type(), + self.graph, + GetAttrTracker(self, "__class__"), + ) + fn = VariableFactory.from_value( + getattr(attr.__self__.__class__, name), + self.graph, + GetAttrTracker(class_var, name), + ) return MethodVariable.wrap_method( value=attr, instance=self, + fn=fn, graph=self.graph, tracker=GetAttrTracker(self, name), method_name=name, @@ -513,7 +533,7 @@ def __call__(self, *args, **kwargs): ) else: fn_var = BuiltinVariable( - unbound_method, + self.value, self.graph, GetAttrTracker(class_var, '__call__'), ) diff --git a/sot/opcode_translator/executor/variables/callable.py b/sot/opcode_translator/executor/variables/callable.py index b107ba644..c6bffc294 100644 --- a/sot/opcode_translator/executor/variables/callable.py +++ b/sot/opcode_translator/executor/variables/callable.py @@ -28,7 +28,7 @@ Tracker, ) from .base import VariableBase, VariableFactory -from .basic import ConstantVariable +from .basic import ConstantVariable, ObjectVariable if TYPE_CHECKING: from ..function_graph import FunctionGraph @@ -217,7 +217,7 @@ def wrap_method( value.__func__, graph, DanglingTracker() ) assert isinstance(instance_var, VariableBase) - assert isinstance(fn_var, FunctionVariable) + assert isinstance(fn_var, (FunctionVariable, ObjectVariable)) method_var = MethodVariable( instance_var, fn_var, diff --git a/sot/opcode_translator/executor/variables/container.py b/sot/opcode_translator/executor/variables/container.py index 1706e08e3..d664cadbf 100644 --- a/sot/opcode_translator/executor/variables/container.py +++ b/sot/opcode_translator/executor/variables/container.py @@ -7,7 +7,7 @@ from ....utils import log_do from ....utils.exceptions import InnerError, NotImplementException from ..guard import StringifyExpression -from ..mutable_data import MutableDictLikeData +from ..mutable_data import MutableDictLikeData, MutableListLikeData from ..pycode_generator import PyCodeGen from ..tracker import ( ConstTracker, @@ -24,6 +24,10 @@ class ContainerVariable(VariableBase): + @property + def init_value(self): + return self.value + def get_items(self) -> list[VariableBase]: raise NotImplementException() @@ -56,12 +60,23 @@ def make_stringify_guard(self) -> StringifyExpression: ), ) len_guard = StringifyExpression( - f"len({frame_value_tracer.expr}) == {len(self)}", + f"len({frame_value_tracer.expr}) == {len(self.init_value)}", frame_value_tracer.free_vars, ) - guard_variables = filter( - lambda var: var.tracker.is_traceable(), self.get_items() - ) + if isinstance(self, (ListVariable, TupleVariable)): + guard_variables = filter( + lambda var: var.tracker.is_traceable(), self.proxy.read_cache + ) + elif isinstance(self, DictVariable): + guard_variables = filter( + lambda var: var.tracker.is_traceable(), + filter( + lambda var: not isinstance(var, MutableDictLikeData.Empty), + self.proxy.read_cache.values(), + ), + ) + else: + raise InnerError(f"Unsupported container type: {type(self)}") return reduce( operator.and_, [len_guard] @@ -79,10 +94,20 @@ def __init__( super().__init__(tracker) self.graph = graph # everything in stack is VariableBase, so just accept the input list is ok + self.proxy = self.graph.side_effects.get_proxy( + MutableListLikeData, val_list, self.proxy_getter + ) self.value = val_list + def proxy_getter(self, data, key): + if key < 0 or key >= len(data): + return MutableListLikeData.Empty() + return VariableFactory.from_value( + data[key], self.graph, tracker=GetItemTracker(self, key) + ) + def get_value(self): - return [self[idx].get_value() for idx in range(len(self))] + return [item.get_value() for item in self.proxy.get_all()] def get_type(self): return list @@ -108,53 +133,57 @@ def main_info(self) -> dict[str, Any]: } def __len__(self): - return len(self.value) + return self.proxy.length def getitem(self, key): - ''' - we need to make sure that: - before an inplace change happens to ListVariable, - the related items should already be wrapped as VariableBase - - if not, tracker might be set to a wrong elem - ''' - if isinstance(key, VariableBase): - raise InnerError( - f"[{self.__class__.__name__}]: recieved {key} as key." + if isinstance(key, int): + res = self.proxy.get(key) + if self.proxy.is_empty(res): + raise InnerError(f"List {self} out of range (index={key})") + return res + elif isinstance(key, slice): + return VariableFactory.from_value( + self.proxy.get_all()[key], + self.graph, + tracker=GetItemTracker(self, key), ) - - retval = self.value[key] - - # if list is an input of funciton, we need make sure __getitem__ returns a VariableBase - retval = VariableFactory.from_value( - retval, self.graph, tracker=GetItemTracker(self, key) - ) - - return retval - - def setitem(self, key, value): - ''' - why setitem is ok: - - case: - def f(x = [t0, t1]) - ... - x[0] = 0 - ... - - 1. if setitem happens after get t0: t0 is a VariableBase (transformed at getitem), so it is ok - 2. if setitem happens before get t0: t0 will not be used - ''' - if isinstance(key, VariableBase): + else: raise InnerError( - f"[{self.__class__.__name__}]: received {key} as key." + f"Unsupported key type {key.__class__.__name__} for ListVariable" ) + def setitem(self, key, value): if not isinstance(value, VariableBase): raise InnerError( f"[{self.__class__.__name__}]: received {value} to set value." ) - self.value[key] = value + if isinstance(key, int): + self.proxy.set(key, value) + elif isinstance(key, slice) and isinstance( + value, (ListVariable, TupleVariable) + ): + start, end, step = key.indices(self.proxy.length) + indices = list(range(start, end, step)) + if step == 1: + # replace a continuous range + for i, idx in enumerate(indices): + self.proxy.delete(idx - i) + for i, item in enumerate(value.get_wrapped_items()): + self.proxy.insert(start + i, item) + else: + # replace some elements + if len(indices) != len(value): + raise InnerError( + f"Attempt to replace {len(indices)} items with {len(value)}" + ) + for i, idx in enumerate(indices): + self.proxy.set(idx, value[i]) + else: + raise InnerError( + f"Unsupported key type {key.__class__.__name__} and value type {value.__class__.__name__} for ListVariable" + ) + + self.graph.side_effects.record_variable(self) return ConstantVariable.wrap_literal(None, self.graph) def __delitem__(self, key): @@ -165,30 +194,125 @@ def delitem(self, key): raise InnerError( f"[{self.__class__.__name__}]: received {key} as key to delete." ) - del self.value[key] + self.proxy.delete(key) + self.graph.side_effects.record_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def insert(self, index: int, value: VariableBase): + self.proxy.insert(index, value) + self.graph.side_effects.record_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def append(self, value: VariableBase): + self.insert(self.proxy.length, value) + self.graph.side_effects.record_variable(self) return ConstantVariable.wrap_literal(None, self.graph) def extend(self, data): - self.value.extend(data.get_wrapped_items()) - return self + for item in data.proxy.get_all(): + self.append(item) + self.graph.side_effects.record_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) def concat(self, list_): assert isinstance(list_, ListVariable) - new_list_variable = ListVariable( - self.get_wrapped_items() + list_.get_wrapped_items(), + return ListVariable( + self.proxy.get_all() + list_.proxy.get_all(), self.graph, DummyTracker([self, list_]), ) - return new_list_variable def repeat(self, length): assert isinstance(length, ConstantVariable) - new_list_variable = ListVariable( - self.get_wrapped_items() * length.value, + return ListVariable( + self.proxy.get_all() * length.value, self.graph, DummyTracker([self, length]), ) - return new_list_variable + + def pop(self, index: ConstantVariable | None = None): + if index is None: + index = ConstantVariable.wrap_literal(-1, self.graph) + res = self.proxy.get(index.get_value()) + self.proxy.delete(index.get_value()) + self.graph.side_effects.record_variable(self) + return res + + def copy(self): + return ListVariable( + self.proxy.get_all(), + self.graph, + DummyTracker([self]), + ) + + def clear(self): + for idx in range(self.proxy.length): + self.delitem(0) + self.graph.side_effects.record_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def remove(self, value): + for idx in range(self.proxy.length): + if self[idx].get_value() == value.get_value(): + self.delitem(idx) + break + else: + raise InnerError(f"List {self} does not contain {value}") + self.graph.side_effects.record_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def sort(self, key=None, reverse=None): + if ( + key is None + or isinstance(key, ConstantVariable) + and key.get_value() is None + ): + key = VariableFactory.from_value( + lambda x: x, self.graph, DanglingTracker() + ) + if reverse is None: + reverse = ConstantVariable.wrap_literal(False, self.graph) + + permutation = list(range(self.proxy.length)) + permutation.sort( + key=lambda x: key.get_value()(self.getitem(x).value), + reverse=reverse.get_value(), + ) + self.proxy.permutate(permutation) + self.graph.side_effects.record_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def reverse(self): + permutation = list(range(self.proxy.length)) + permutation.reverse() + self.proxy.permutate(permutation) + self.graph.side_effects.record_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def getattr(self, name): + from .callable import BuiltinVariable + + method_name_to_builtin_fn = { + "insert": list.insert, + "append": list.append, + "extend": list.extend, + "pop": list.pop, + "copy": list.copy, + "clear": list.clear, + "remove": list.remove, + "sort": list.sort, + "reverse": list.reverse, + } + + if name in method_name_to_builtin_fn: + builtin_fn = method_name_to_builtin_fn[name] + return BuiltinVariable( + builtin_fn, self.graph, DanglingTracker() + ).bind(self, name) + else: + raise NotImplementException( + f"attribute {name} for dict is not implemented" + ) @VariableFactory.register_from_value() def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): @@ -207,8 +331,18 @@ def __init__( ): super().__init__(tracker) self.graph = graph + self.proxy = self.graph.side_effects.get_proxy( + MutableListLikeData, list(val_tuple), self.proxy_getter + ) self.value = val_tuple + def proxy_getter(self, data, key): + if key < 0 or key >= len(data): + return MutableListLikeData.Empty() + return VariableFactory.from_value( + data[key], self.graph, tracker=GetItemTracker(self, key) + ) + def get_value(self): return tuple(self[idx].get_value() for idx in range(len(self))) @@ -236,18 +370,24 @@ def main_info(self) -> dict[str, Any]: } def __len__(self): - return len(self.value) + return self.proxy.length def getitem(self, key): - if isinstance(key, VariableBase): + if isinstance(key, int): + res = self.proxy.get(key) + if self.proxy.is_empty(res): + raise InnerError(f"List {self} out of range (index={key})") + return res + elif isinstance(key, slice): + return VariableFactory.from_value( + tuple(self.proxy.get_all())[key], + self.graph, + tracker=GetItemTracker(self, key), + ) + else: raise InnerError( - f"[{self.__class__.__name__}]: recieved {key} as key." + f"Unsupported key type {key.__class__.__name__} for TupleVariable" ) - retval = self.value[key] - - return VariableFactory.from_value( - retval, graph=self.graph, tracker=GetItemTracker(self, key) - ) def setitem(self, key, value): raise InnerError( @@ -265,7 +405,7 @@ def delitem(self, key): def concat(self, tuple_): assert isinstance(tuple_, TupleVariable) new_tuple_variable = TupleVariable( - self.get_wrapped_items() + tuple_.get_wrapped_items(), + tuple(self.proxy.get_all() + tuple_.proxy.get_all()), self.graph, DummyTracker([self, tuple_]), ) @@ -274,7 +414,7 @@ def concat(self, tuple_): def repeat(self, length): assert isinstance(length, ConstantVariable) new_tuple_variable = TupleVariable( - self.get_wrapped_items() * length.value, + tuple(self.proxy.get_all()) * length.value, self.graph, DummyTracker([self, length]), ) diff --git a/sot/utils/magic_methods.py b/sot/utils/magic_methods.py index c73617c47..bca42ea46 100644 --- a/sot/utils/magic_methods.py +++ b/sot/utils/magic_methods.py @@ -108,4 +108,5 @@ def magic_method_builtin_dispatch(fn: BinaryOp | UnaryOp) -> list[MagicMethod]: return magic_methods elif fn in UNARY_OPS: magic_name = UNARY_OPS_TO_MAGIC_NAMES[fn] + return [MagicMethod(magic_name)] return [] diff --git a/tests/run_all_paddle_ci.sh b/tests/run_all_paddle_ci.sh index 0b783fb68..b826de00a 100644 --- a/tests/run_all_paddle_ci.sh +++ b/tests/run_all_paddle_ci.sh @@ -3,11 +3,7 @@ export STRICT_MODE=0 PADDLE_TEST_BASE=./Paddle/test/dygraph_to_static failed_tests=() disabled_tests=( - ${PADDLE_TEST_BASE}/test_write_python_container.py # side effect - ${PADDLE_TEST_BASE}/test_slice.py # side effect ${PADDLE_TEST_BASE}/test_lac.py # disabled by paddle - ${PADDLE_TEST_BASE}/test_dict.py # side effect - ${PADDLE_TEST_BASE}/test_list.py # side effect ${PADDLE_TEST_BASE}/test_sentiment.py # disabled unitcase by paddle ${PADDLE_TEST_BASE}/test_reinforcement_learning.py # 'CartPoleEnv' object has no attribute 'seed' # tmp = x diff --git a/tests/test-side-effect.py b/tests/test-side-effect.py deleted file mode 100644 index 9a05788d2..000000000 --- a/tests/test-side-effect.py +++ /dev/null @@ -1,62 +0,0 @@ -import unittest - - -def normal_side_effect_1(tensor_x, list_a): - tensor_x = tensor_x + 1 - list_a.append(12) - return tensor_x, list_a - - -def normal_side_effect_2(tensor_x, list_a): - tensor_x = tensor_x + 1 - list_a.append(tensor_x) - return tensor_x, list_a - - -def normal_side_effect_3(list_a): - """list index tracker.""" - del list_a[0] - return list_a[0] - - -def normal_side_effect_5(list_a): - """nested side effect""" - inner_list = [] - inner_list.append(list_a) - inner_list[-1].append(12) - return 12 - # check list_a - - -a = 12 - - -def normal_size_effect_6(tensor_x): - """global""" - global a - a = 1 - return tensor_x + a - - -class CustomObject: - def __init__(self): - self.x = 0 - - -def normal_size_effect_7(cus_obj, t): - """object side effect.""" - t = t + 1 - cus_obj.x = t - return t, cus_obj - - -# class TestNumpyAdd(TestCaseBase): -# @strict_mode_guard(0) -# def test_numpy_add(self): -# x = paddle.to_tensor([2]) -# y = paddle.to_tensor([3]) -# self.assert_results(numpy_add, x, y) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_04_list.py b/tests/test_04_list.py index 8b15073be..ad31ad431 100644 --- a/tests/test_04_list.py +++ b/tests/test_04_list.py @@ -50,8 +50,7 @@ def test_simple(self): self.assert_results(list_getitem_int, 1, paddle.to_tensor(2)) self.assert_results(list_getitem_tensor, 1, paddle.to_tensor(2)) self.assert_results(list_setitem_int, 1, paddle.to_tensor(2)) - # TODO(SigureMo) SideEffects have not been implemented yet, we need to skip them - # self.assert_results(list_setitem_tensor, 1, paddle.to_tensor(2)) + self.assert_results(list_setitem_tensor, 1, paddle.to_tensor(2)) self.assert_results(list_delitem_int, 1, paddle.to_tensor(2)) self.assert_results(list_delitem_tensor, 1, paddle.to_tensor(2)) diff --git a/tests/test_12_for_loop.py b/tests/test_12_for_loop.py index ea258d6c4..e6da87635 100644 --- a/tests/test_12_for_loop.py +++ b/tests/test_12_for_loop.py @@ -148,13 +148,11 @@ def test_resume_stack(self): def run_list_comp(x): out = [s.chunk(2, axis=1) for s in x] - out = [s.chunk(1, axis=1) for s in x] return out class TestListComp(TestCaseBase): - # TODO(SigureMo): Support LIST_APPEND - def error_test_list_comp(self): + def test_list_comp(self): x = [paddle.randn([1, 4]), paddle.randn([1, 4])] self.assert_results(run_list_comp, x) diff --git a/tests/test_15_slice.py b/tests/test_15_slice.py index fba83d99e..2d550038b 100644 --- a/tests/test_15_slice.py +++ b/tests/test_15_slice.py @@ -29,14 +29,14 @@ def build_tuple_slice_with_step(x: list, y: paddle.Tensor): return x[0] + y -class TestExecutor(TestCaseBase): +class TestSlice(TestCaseBase): def test_simple(self): x = list(range(10)) y = paddle.arange(10) - self.assert_results(build_list_slice, x, y) - self.assert_results(build_list_slice_with_step, x, y) - self.assert_results(build_tuple_slice, x, y) - self.assert_results(build_tuple_slice_with_step, x, y) + self.assert_results_with_side_effects(build_list_slice, x, y) + self.assert_results_with_side_effects(build_list_slice_with_step, x, y) + self.assert_results_with_side_effects(build_tuple_slice, x, y) + self.assert_results_with_side_effects(build_tuple_slice_with_step, x, y) class MyLayer(paddle.nn.Layer): @@ -58,7 +58,7 @@ def layer_list_slice(layer, x): class TestLayerList(TestCaseBase): - def test_run(self): + def test_layer_list_slice(self): layer = MyLayer() x = paddle.randn([5, 10]) self.assert_results(layer_list_slice, layer, x) diff --git a/tests/test_mutable_data.py b/tests/test_mutable_data.py index 3d8930f26..1bdec4f7c 100644 --- a/tests/test_mutable_data.py +++ b/tests/test_mutable_data.py @@ -3,6 +3,7 @@ from sot.opcode_translator.executor.mutable_data import ( MutableData, MutableDictLikeData, + MutableListLikeData, ) @@ -48,6 +49,96 @@ def delitem(self, key): self.proxy.delete(key) +class ListVariable(VariableBase): + def __init__(self, data): + self.data = data + self.proxy = MutableListLikeData(data, ListVariable.proxy_getter) + + @staticmethod + def proxy_getter(data, key): + if key < 0 or key >= len(data): + return MutableData.Empty() + return ConstVariable(data[key]) + + def getitem(self, key): + if isinstance(key, int): + res = self.proxy.get(key) + if isinstance(res, MutableData.Empty): + raise IndexError(f"Index {key} out of range") + return res + elif isinstance(key, slice): + return self.proxy.get_all()[key] + else: + raise TypeError(f"Invalid key type {type(key)}") + + def __getitem__(self, key): + return self.getitem(key) + + def setitem(self, key, value): + if isinstance(key, int): + self.proxy.set(key, value) + elif isinstance(key, slice): + start, end, step = key.indices(self.proxy.length) + indices = list(range(start, end, step)) + if step == 1: + # replace a continuous range + for i, idx in enumerate(indices): + self.proxy.delete(idx - i) + for i, item in enumerate(value): + self.proxy.insert(start + i, item) + else: + # replace some elements + if len(indices) != len(value): + raise ValueError( + f"Attempt to replace {len(indices)} items with {len(value)}" + ) + for i, idx in enumerate(indices): + self.proxy.set(idx, value[i]) + + def delitem(self, key): + self.proxy.delete(key) + + def insert(self, index, value): + self.proxy.insert(index, value) + + def append(self, value): + self.proxy.insert(self.proxy.length, value) + + def extend(self, value): + for item in value: + self.append(item) + + def pop(self, index=-1): + res = self.getitem(index) + self.delitem(index) + return res + + def clear(self): + for i in range(self.proxy.length): + self.delitem(0) + + def remove(self, value): + for i in range(self.proxy.length): + if self.getitem(i) == value: + self.delitem(i) + return + raise ValueError(f"Value {value} not found") + + def sort(self, key=None, reverse=False): + if key is None: + key = lambda x: x + permutation = list(range(self.proxy.length)) + permutation.sort( + key=lambda x: key(self.getitem(x).value), reverse=reverse + ) + self.proxy.permutate(permutation) + + def reverse(self): + permutation = list(range(self.proxy.length)) + permutation.reverse() + self.proxy.permutate(permutation) + + class TestMutableDictLikeVariable(unittest.TestCase): def test_getitem(self): data = {"a": 1, "b": 2} @@ -76,5 +167,174 @@ def test_keys(self): self.assertEqual(list(var.proxy.get_all().keys()), ["a", "b"]) +class TestMutableListLikeVariable(unittest.TestCase): + def test_getitem(self): + data = [1, 2, 3] + var = ListVariable(data) + self.assertEqual(var.getitem(0), ConstVariable(1)) + self.assertEqual(var.getitem(1), ConstVariable(2)) + self.assertEqual(var.getitem(2), ConstVariable(3)) + + def test_getitem_slice_1(self): + data = [1, 2, 3, 4, 5, 6, 7] + var = ListVariable(data) + self.assertEqual( + var.getitem(slice(0, 3)), + [ConstVariable(1), ConstVariable(2), ConstVariable(3)], + ) + self.assertEqual( + var.getitem(slice(4, 1, -1)), + [ConstVariable(5), ConstVariable(4), ConstVariable(3)], + ) + self.assertEqual( + var.getitem(slice(1, 5, 2)), + [ConstVariable(2), ConstVariable(4)], + ) + + def test_getitem_slice_2(self): + data = [1, 2, 3, 4, 5, 6, 7] + var = ListVariable(data) + self.assertEqual( + var[0:3], + [ConstVariable(1), ConstVariable(2), ConstVariable(3)], + ) + self.assertEqual( + var[4:1:-1], + [ConstVariable(5), ConstVariable(4), ConstVariable(3)], + ) + self.assertEqual( + var[1:5:2], + [ConstVariable(2), ConstVariable(4)], + ) + + def test_setitem(self): + data = [1, 2, 3] + var = ListVariable(data) + var.setitem(0, ConstVariable(4)) + self.assertEqual(var.getitem(0), ConstVariable(4)) + var.append(ConstVariable(5)) + self.assertEqual(var.getitem(3), ConstVariable(5)) + + def test_setitem_slice_1(self): + data = [1, 2, 3, 4, 5, 6, 7] + var = ListVariable(data) + var.setitem(slice(0, 3), [ConstVariable(4), ConstVariable(5)]) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [4, 5, 4, 5, 6, 7]], + ) + var.setitem( + slice(4, 1, -1), + [ConstVariable(8), ConstVariable(9), ConstVariable(10)], + ) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [4, 5, 10, 9, 8, 7]], + ) + + def test_setitem_slice_2(self): + data = [1, 2, 3, 4, 5, 6, 7] + var = ListVariable(data) + var.setitem(slice(2, 5, 2), [ConstVariable(8), ConstVariable(9)]) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [1, 2, 8, 4, 9, 6, 7]], + ) + + def test_delitem(self): + data = [1, 2, 3] + var = ListVariable(data) + var.delitem(0) + with self.assertRaises(IndexError): + var.getitem(2) + var.pop() + with self.assertRaises(IndexError): + var.getitem(1) + + def test_insert(self): + data = [1, 2, 3] + var = ListVariable(data) + var.insert(0, ConstVariable(4)) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [4, 1, 2, 3]], + ) + var.insert(2, ConstVariable(5)) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [4, 1, 5, 2, 3]], + ) + + def test_append(self): + data = [1, 2, 3] + var = ListVariable(data) + var.append(ConstVariable(4)) + self.assertEqual(var.getitem(3), ConstVariable(4)) + + def test_extend(self): + data = [1, 2, 3] + var = ListVariable(data) + var.extend([ConstVariable(4), ConstVariable(5)]) + self.assertEqual(var.getitem(3), ConstVariable(4)) + self.assertEqual(var.getitem(4), ConstVariable(5)) + + def test_pop(self): + data = [1, 2, 3] + var = ListVariable(data) + self.assertEqual(var.pop(), ConstVariable(3)) + self.assertEqual(var.pop(0), ConstVariable(1)) + + def test_clear(self): + data = [1, 2, 3] + var = ListVariable(data) + var.clear() + self.assertEqual(var.proxy.length, 0) + + def test_remove(self): + data = [1, 2, 3] + var = ListVariable(data) + var.remove(ConstVariable(2)) + self.assertEqual(var.getitem(0), ConstVariable(1)) + self.assertEqual(var.getitem(1), ConstVariable(3)) + with self.assertRaises(ValueError): + var.remove(ConstVariable(2)) + + def test_sort(self): + data = [2, 3, 0, 4, 1, 5] + var = ListVariable(data) + var.sort() + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [0, 1, 2, 3, 4, 5]], + ) + + def test_sort_with_key(self): + data = [-1, -4, 2, 0, 5, -3] + var = ListVariable(data) + var.sort(key=lambda x: x**2) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [0, -1, 2, -3, -4, 5]], + ) + + def test_sort_reverse(self): + data = [2, 3, 0, 4, 1, 5] + var = ListVariable(data) + var.sort(reverse=True) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [5, 4, 3, 2, 1, 0]], + ) + + def test_reverse(self): + data = [2, 3, 0, 4, 1, 5] + var = ListVariable(data) + var.reverse() + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [5, 1, 4, 0, 3, 2]], + ) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_side_effects.py b/tests/test_side_effects.py index 3e971bf61..247c46378 100644 --- a/tests/test_side_effects.py +++ b/tests/test_side_effects.py @@ -5,6 +5,8 @@ from test_case_base import TestCaseBase import paddle +from sot import symbolic_translate +from sot.utils import InnerError def dict_setitem(x): @@ -42,6 +44,67 @@ def dict_nested_2(x): return a, b +def list_append_int(tensor_x, list_a): + tensor_x = tensor_x + 1 + list_a.append(12) + return tensor_x, list_a + + +def list_append_tensor(tensor_x, list_a): + tensor_x = tensor_x + 1 + list_a.append(tensor_x) + return tensor_x, list_a + + +def list_delitem(list_a): + del list_a[0] + return list_a[0] + + +def list_extend(list_a): + list_a.extend([1, 2, 3]) + return list_a[0] + + +def list_nested(list_a): + inner_list = [] + inner_list.append(list_a) + inner_list[-1].append(12) + return 12 + + +def list_insert(list_a): + list_a.insert(0, 1) + return list_a[0] + + +def list_remove(list_a): + list_a.remove(1) + return list_a[0] + + +def list_pop(list_a): + list_a.pop(0) + list_a.pop() + list_a.pop(1) + return list_a[0] + + +def list_clear(list_a): + list_a.clear() + return list_a + + +def list_sort(list_a): + list_a.sort() + return list_a + + +def list_reverse(list_a): + list_a.reverse() + return list_a + + def slice_in_for_loop(x, iter_num=3): x = paddle.to_tensor(x) a = [] @@ -57,6 +120,30 @@ def slice_in_for_loop(x, iter_num=3): return out +# TODO: Global SideEffect +a = 12 + + +def normal_size_effect_6(tensor_x): + """global""" + global a + a = 1 + return tensor_x + a + + +# TODO: Object SideEffect +class CustomObject: + def __init__(self): + self.x = 0 + + +def object_attr(cus_obj, t): + """object side effect.""" + t = t + 1 + cus_obj.x = t + return t, cus_obj + + class TestDictSideEffect(TestCaseBase): def test_dict_setitem(self): self.assert_results_with_side_effects( @@ -97,11 +184,65 @@ def test_dict_nested_2(self): class TestListSideEffect(TestCaseBase): - # TODO(SigureMo): Support list side effects. - def error_test_slice_in_for_loop(self): + def test_list_append(self): + self.assert_results_with_side_effects( + list_append_int, paddle.to_tensor(1), [1, 2, 3] + ) + self.assert_results_with_side_effects( + list_append_tensor, paddle.to_tensor(2), [1, 2, 3] + ) + + def test_list_delitem(self): + self.assert_results_with_side_effects(list_delitem, [1, 2, 3]) + + def test_list_extend(self): + self.assert_results_with_side_effects( + list_extend, [1, 2, 3, 4, 5, 6, 7, 8, 9] + ) + + def test_list_insert(self): + self.assert_results_with_side_effects(list_insert, [1, 2, 3]) + self.assert_results_with_side_effects( + list_insert, [-1, 2, -3, 4, -5, 6, -7, 8, -9] + ) + + def test_list_remove(self): + self.assert_results_with_side_effects(list_remove, [1, 1, 1]) + self.assert_results_with_side_effects(list_remove, [0, 1, 2]) + with self.assertRaises(InnerError): + symbolic_translate(list_remove)([0, 2, 4]) + + def test_list_pop(self): + self.assert_results_with_side_effects(list_pop, [1, 2, 3, 4, 5]) + self.assert_results_with_side_effects( + list_pop, [-1, 2, -3, 4, -5, 6, -7, 8, -9] + ) + + def test_list_clear(self): + self.assert_results_with_side_effects(list_clear, [1, 2, 3, 4, 5]) + self.assert_results_with_side_effects( + list_clear, [-1, 2, -3, 4, -5, 6, -7, 8, -9] + ) + + def test_list_sort(self): + self.assert_results_with_side_effects(list_sort, [2, 1, 7, 3, 4, 6]) + self.assert_results_with_side_effects( + list_sort, [-1, 2, -3, 4, -5, 6, -7, 8, -9] + ) + + def test_list_reverse(self): + self.assert_results_with_side_effects(list_reverse, [1, 2, 3, 4, 5]) + self.assert_results_with_side_effects( + list_reverse, [-1, 2, -3, 4, -5, 6, -7, 8, -9] + ) + + def test_slice_in_for_loop(self): x = 2 self.assert_results_with_side_effects(slice_in_for_loop, x) + def test_list_nested(self): + self.assert_results_with_side_effects(list_nested, [1, 2, 3]) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_sir_rollback.py b/tests/test_sir_rollback.py index 7bd259f9d..b0ff7bc36 100644 --- a/tests/test_sir_rollback.py +++ b/tests/test_sir_rollback.py @@ -45,5 +45,26 @@ def test_rollback(self): assert len(graph.sir_ctx.TOS.statements) == original_length +def fn_with_side_effects_inner(x, y): + x[0] += 10 + x[1] += 20 + x[2] -= 10 + print(y) # print will cause breakgraph + return + + +def fn_with_side_effects(x, y): + x[0] += 1 + fn_with_side_effects_inner(x, y) + return x[0] + y + + +class TestSideEffectRollback(TestCaseBase): + def test_side_effect_rollback(self): + self.assert_results_with_side_effects( + fn_with_side_effects, [1, 2, 3], paddle.to_tensor(42) + ) + + if __name__ == "__main__": unittest.main()