diff --git a/sot/__init__.py b/sot/__init__.py index 49c814b5b..c96104a9e 100644 --- a/sot/__init__.py +++ b/sot/__init__.py @@ -1,3 +1,4 @@ +from . import psdb # noqa: F401 from .opcode_translator.breakpoint import BM, add_breakpoint, add_event from .opcode_translator.skip_files import skip_function from .translate import symbolic_translate diff --git a/sot/opcode_translator/executor/function_graph.py b/sot/opcode_translator/executor/function_graph.py index 29706453c..7fff91410 100644 --- a/sot/opcode_translator/executor/function_graph.py +++ b/sot/opcode_translator/executor/function_graph.py @@ -34,6 +34,8 @@ GlobalDelSideEffectRestorer, GlobalSetSideEffectRestorer, ListSideEffectRestorer, + ObjDelSideEffectRestorer, + ObjSetSideEffectRestorer, SideEffectRestorer, SideEffects, ) @@ -552,7 +554,7 @@ def _find_tensor_outputs( self.add_global_guarded_variable(output) # Find Tensor Variables from side effects Variables. for side_effect_var in self.side_effects.proxy_variables: - if side_effect_var.proxy.has_changed: + if isinstance(side_effect_var, (ListVariable, DictVariable)): for var in side_effect_var.flatten_items(): if ( isinstance(var.tracker, DummyTracker) @@ -560,16 +562,21 @@ def _find_tensor_outputs( and side_effect_var.tracker.is_traceable() ): output_tensors.add(var) - if isinstance(var, GlobalVariable): - for record in var.proxy.records: - if ( - isinstance(record, (MutationSet, MutationNew)) - and isinstance( - record.value.tracker, DummyTracker - ) - and isinstance(record.value, TensorVariable) - ): - output_tensors.add(record.value) + else: + if isinstance(side_effect_var, GlobalVariable): + proxy_records = side_effect_var.proxy.records + elif side_effect_var.tracker.is_traceable(): + # for attr side effect + proxy_records = side_effect_var.attr_proxy.records + else: + continue + for record in proxy_records: + if isinstance(record, (MutationSet, MutationNew)): + for var in record.value.flatten_items(): + if isinstance( + var.tracker, DummyTracker + ) and isinstance(var, TensorVariable): + output_tensors.add(var) # Find Tensor in print_stmts for print_stmt in self._print_variables: for var in print_stmt.flatten_items(): @@ -645,7 +652,24 @@ def restore_side_effects(self, variables: list[VariableBase]): restorers.append( GlobalDelSideEffectRestorer(record.key) ) - # TODO: support attribute restore + else: + for record in var.attr_proxy.records[::-1]: + if isinstance(record, (MutationSet, MutationNew)): + restorers.append( + ObjSetSideEffectRestorer( + var, + record.key, + record.value, + ) + ) + elif isinstance(record, MutationDel): + restorers.append( + ObjDelSideEffectRestorer( + var, + record.key, + ) + ) + for restorer in restorers: restorer.pre_gen(self.pycode_gen) for restorer in restorers[::-1]: diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index b68f18f91..0ac31c104 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -976,25 +976,22 @@ def LOAD_METHOD(self, instr: Instruction): self.stack.push(NullVariable()) self.stack.push(method) - def STORE_ATTR(self, instr): + def STORE_ATTR(self, instr: Instruction): obj = self.stack.pop() val = self.stack.pop() key = self._code.co_names[instr.arg] - if isinstance(obj, TensorVariable): - # support tensor variable store attr, like: - # t.stop_gradient = True - obj.graph.call_tensor_method( - "__setattr__", - obj, - VariableFactory().from_value( - key, self._graph, ConstTracker(key) - ), - val, - ) - else: - raise BreakGraphError( - f"STORE_ATTR don't support {type(obj)}.{key}={val}" - ) + key_var = ConstantVariable.wrap_literal(key, self._graph) + BuiltinVariable( + setattr, self._graph, DummyTracker([obj, key_var, val]) + )(obj, key_var, val) + + def DELETE_ATTR(self, instr: Instruction): + obj = self.stack.pop() + key = instr.argval + key_var = ConstantVariable.wrap_literal(key, self._graph) + BuiltinVariable(delattr, self._graph, DummyTracker([obj, key_var]))( + obj, key_var + ) def STORE_DEREF(self, instr: Instruction): if sys.version_info >= (3, 11): diff --git a/sot/opcode_translator/executor/pycode_generator.py b/sot/opcode_translator/executor/pycode_generator.py index 2e9e51e3d..f649a26c9 100644 --- a/sot/opcode_translator/executor/pycode_generator.py +++ b/sot/opcode_translator/executor/pycode_generator.py @@ -762,6 +762,18 @@ def gen_load_attr(self, name: str): idx = self._code_options["co_names"].index(name) self._add_instr("LOAD_ATTR", arg=idx, argval=name) + def gen_store_attr(self, name: str): + if name not in self._code_options["co_names"]: + self._code_options["co_names"].append(name) + idx = self._code_options["co_names"].index(name) + self._add_instr("STORE_ATTR", arg=idx, argval=name) + + def gen_delete_attr(self, name: str): + if name not in self._code_options["co_names"]: + self._code_options["co_names"].append(name) + idx = self._code_options["co_names"].index(name) + self._add_instr("DELETE_ATTR", arg=idx, argval=name) + def gen_load_method(self, name: str): if name not in self._code_options["co_names"]: self._code_options["co_names"].append(name) diff --git a/sot/opcode_translator/executor/side_effects.py b/sot/opcode_translator/executor/side_effects.py index ac6e462a6..a92f7b585 100644 --- a/sot/opcode_translator/executor/side_effects.py +++ b/sot/opcode_translator/executor/side_effects.py @@ -180,3 +180,41 @@ def pre_gen(self, codegen: PyCodeGen): def post_gen(self, codegen: PyCodeGen): codegen.gen_delete_global(self.name) + + +class ObjSetSideEffectRestorer(SideEffectRestorer): + """ + obj.attr = new_value + """ + + def __init__(self, obj: VariableBase, name: str, var: VariableBase): + super().__init__() + self.obj = obj + self.name = name + self.var = var + + def pre_gen(self, codegen: PyCodeGen): + # value + self.var.reconstruct(codegen) + # obj + self.obj.reconstruct(codegen) + + def post_gen(self, codegen: PyCodeGen): + codegen.gen_store_attr(self.name) + + +class ObjDelSideEffectRestorer(SideEffectRestorer): + """ + del obj.attr + """ + + def __init__(self, obj: VariableBase, name: str): + super().__init__() + self.obj = obj + self.name = name + + def pre_gen(self, codegen: PyCodeGen): + self.obj.reconstruct(codegen) + + def post_gen(self, codegen: PyCodeGen): + codegen.gen_delete_attr(self.name) diff --git a/sot/opcode_translator/executor/variable_dispatch.py b/sot/opcode_translator/executor/variable_dispatch.py index f837079b0..43a1cd493 100644 --- a/sot/opcode_translator/executor/variable_dispatch.py +++ b/sot/opcode_translator/executor/variable_dispatch.py @@ -42,6 +42,11 @@ from .variables import DataVariable, TensorVariable +def add_guard(var: VariableBase): + var.graph.add_global_guarded_variable(var) + return var + + def raise_err_handle(error): def inner(*args, **kwargs): raise error @@ -52,7 +57,7 @@ def inner(*args, **kwargs): # slice Dispatcher.register( slice, - ("VariableBase"), + ("VariableBase",), lambda stop: SliceVariable( slice(stop), graph=stop.graph, @@ -406,26 +411,36 @@ def dispatch_dict_fromkeys(seq: ListVariable | TupleVariable, default: VariableB Dispatcher.register( getattr, ("VariableBase", "ConstantVariable", optional("VariableBase")), - lambda var, name, default=None: ( - var.graph.add_global_guarded_variable(name), - var.getattr(name.get_py_value(), default), - )[1], + lambda var, name, default=None: var.getattr( + add_guard(name).get_py_value(), default + ), +) + +# hasattr +Dispatcher.register( + hasattr, + ("VariableBase", "ConstantVariable"), + lambda var, name: var.hasattr(add_guard(name).get_py_value()), +) + +Dispatcher.register( + delattr, + ("VariableBase", "VariableBase"), + lambda var, name: var.delattr(add_guard(name).get_py_value()), +) + +Dispatcher.register( + setattr, + ("VariableBase", "VariableBase", "VariableBase"), + lambda var, name, value: var.setattr(add_guard(name).get_py_value(), value), ) + # len Dispatcher.register( len, ("ContainerVariable | UserDefinedLayerVariable",), lambda var: var.len(), ) -# hasattr -Dispatcher.register( - hasattr, - ("VariableBase", "ConstantVariable"), - lambda var, name: ( - var.graph.add_global_guarded_variable(name), - var.hasattr(name.get_py_value()), - )[1], -) # range # stop diff --git a/sot/opcode_translator/executor/variables/__init__.py b/sot/opcode_translator/executor/variables/__init__.py index 6fee7f0f2..bb79481e3 100644 --- a/sot/opcode_translator/executor/variables/__init__.py +++ b/sot/opcode_translator/executor/variables/__init__.py @@ -22,6 +22,7 @@ from .callable import ( # noqa: F401 BuiltinVariable, CallableVariable, + ClassVariable, FunctionVariable, LayerVariable, MethodVariable, diff --git a/sot/opcode_translator/executor/variables/base.py b/sot/opcode_translator/executor/variables/base.py index 118c81879..1cddc6c39 100644 --- a/sot/opcode_translator/executor/variables/base.py +++ b/sot/opcode_translator/executor/variables/base.py @@ -490,6 +490,20 @@ def getattr(self, name: str, default=None): ) return result + def setattr(self, key: str, value): + from .basic import ConstantVariable + + self.attr_proxy.set(key, value) + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def delattr(self, key: str): + from .basic import ConstantVariable + + self.attr_proxy.delete(key) + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + def __setitem__(self, key, value): return self.setitem(key, value) diff --git a/sot/opcode_translator/executor/variables/basic.py b/sot/opcode_translator/executor/variables/basic.py index 610688142..e1f51b72f 100644 --- a/sot/opcode_translator/executor/variables/basic.py +++ b/sot/opcode_translator/executor/variables/basic.py @@ -510,6 +510,19 @@ def getattr(self, name: str, default=None): else: raise HasNoAttributeError(f"Unknown Tensor attribute: {name}") + def setattr(self, key, val): + # support tensor variable store attr, like: + # t.stop_gradient = True + self.graph.call_tensor_method( + "__setattr__", + self, + VariableFactory().from_value(key, self.graph, ConstTracker(key)), + val, + ) + + def delattr(self, key): + raise BreakGraphError("Don't support TensorVariable delattr") + @VariableFactory.register_from_value() def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): if isinstance(value, (paddle.Tensor, MetaInfo)): @@ -527,12 +540,12 @@ class ObjectVariable(VariableBase): tracker(Tracker): The Tracker object that tracks the information of this variable. """ + make_stringify_guard = object_equal_stringify_guard + def __init__(self, obj, graph, tracker): super().__init__(graph, tracker) self.value = obj - make_stringify_guard = object_equal_stringify_guard - @property def main_info(self) -> dict[str, Any]: return {"value": self.value} @@ -619,6 +632,12 @@ def _reconstruct(self, codegen: PyCodeGen): else: super()._reconstruct(codegen) + def setattr(self, key, val): + raise BreakGraphError("Don't support SliceVariable setattr") + + def delattr(self, key): + raise BreakGraphError("Don't support SliceVariable delattr") + @VariableFactory.register_from_value() def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): if isinstance(value, slice): diff --git a/tests/test_side_effects.py b/tests/test_side_effects.py index 1142aa109..fc2b5e897 100644 --- a/tests/test_side_effects.py +++ b/tests/test_side_effects.py @@ -5,6 +5,7 @@ from test_case_base import TestCaseBase, strict_mode_guard import paddle +import sot from sot import symbolic_translate from sot.utils import InnerError @@ -123,14 +124,39 @@ def slice_in_for_loop(x, iter_num=3): # TODO: Object SideEffect class CustomObject: def __init__(self): - self.x = 0 + self.x = 2 + self.y = paddle.to_tensor(1) + def object_attr_set2(self, x): + self.outputs = [] + self.outputs.append(x) + return self.outputs -def object_attr(cus_obj, t): + +@sot.psdb.check_no_breakgraph +def object_attr_set(cus_obj, t): """object side effect.""" t = t + 1 cus_obj.x = t - return t, cus_obj + return t, cus_obj.x + + +def object_attr_breakgraph(cus_obj, t): + t = t + 1 + sot.psdb.breakgraph() + cus_obj.x = t + sot.psdb.breakgraph() + return t, cus_obj.x + + +@sot.psdb.check_no_breakgraph +def object_attr_tensor_del(cus_obj): + del cus_obj.y + + +@sot.psdb.check_no_breakgraph +def object_attr_int_del(cus_obj): + del cus_obj.x def slice_list_after_change(l): @@ -252,5 +278,42 @@ def test_slice_list_after_change(self): ) +class TestAttrSideEffect(TestCaseBase): + def attr_check(self, func, attr_keys: list[str], cls, *inputs): + cus_obj1 = cls() + cus_obj2 = cls() + sym_output = symbolic_translate(func)(cus_obj1, *inputs) + paddle_output = func(cus_obj2, *inputs) + for key in attr_keys: + self.assert_nest_match( + getattr(cus_obj1, key, f"__MISS_KEY__{key}"), + getattr(cus_obj2, key, f"__MISS_KEY__{key}"), + ) + self.assert_nest_match(sym_output, paddle_output) + + def test_attr_set(self): + self.attr_check(object_attr_set, ["x"], CustomObject, 5) + self.attr_check( + CustomObject.object_attr_set2, ["outputs"], CustomObject, 6 + ) + self.attr_check( + CustomObject.object_attr_set2, + ["outputs"], + CustomObject, + paddle.to_tensor(5), + ) + self.attr_check( + object_attr_set, ["x"], CustomObject, paddle.to_tensor(5) + ) + + def test_attr_del(self): + self.attr_check(object_attr_tensor_del, ["y"], CustomObject) + self.attr_check(object_attr_int_del, ["x"], CustomObject) + + def test_attr_set_breakgraph(self): + self.attr_check(object_attr_breakgraph, ["x"], CustomObject, 100) + self.attr_check(object_attr_breakgraph, ["x"], CustomObject, 1000) + + if __name__ == "__main__": unittest.main()