Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[SideEffect] add side effects support for ObjectVariable #393

Merged
merged 20 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import psdb # noqa: F401
from .opcode_translator.breakpoint import BM, add_breakpoint, add_event
from .opcode_translator.skip_files import skip_function
from .translate import symbolic_translate
Expand Down
48 changes: 36 additions & 12 deletions sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
GlobalDelSideEffectRestorer,
GlobalSetSideEffectRestorer,
ListSideEffectRestorer,
ObjDelSideEffectRestorer,
ObjSetSideEffectRestorer,
SideEffectRestorer,
SideEffects,
)
Expand Down Expand Up @@ -552,24 +554,29 @@ def _find_tensor_outputs(
self.add_global_guarded_variable(output)
# Find Tensor Variables from side effects Variables.
for side_effect_var in self.side_effects.proxy_variables:
if side_effect_var.proxy.has_changed:
if isinstance(side_effect_var, (ListVariable, DictVariable)):
for var in side_effect_var.flatten_items():
if (
isinstance(var.tracker, DummyTracker)
and isinstance(var, TensorVariable)
and side_effect_var.tracker.is_traceable()
):
output_tensors.add(var)
if isinstance(var, GlobalVariable):
for record in var.proxy.records:
if (
isinstance(record, (MutationSet, MutationNew))
and isinstance(
record.value.tracker, DummyTracker
)
and isinstance(record.value, TensorVariable)
):
output_tensors.add(record.value)
else:
if isinstance(side_effect_var, GlobalVariable):
proxy_records = side_effect_var.proxy.records
elif side_effect_var.tracker.is_traceable():
# for attr side effect
proxy_records = side_effect_var.attr_proxy.records
else:
continue
for record in proxy_records:
if isinstance(record, (MutationSet, MutationNew)):
for var in record.value.flatten_items():
if isinstance(
var.tracker, DummyTracker
) and isinstance(var, TensorVariable):
output_tensors.add(var)
# Find Tensor in print_stmts
for print_stmt in self._print_variables:
for var in print_stmt.flatten_items():
Expand Down Expand Up @@ -645,7 +652,24 @@ def restore_side_effects(self, variables: list[VariableBase]):
restorers.append(
GlobalDelSideEffectRestorer(record.key)
)
# TODO: support attribute restore
else:
for record in var.attr_proxy.records[::-1]:
if isinstance(record, (MutationSet, MutationNew)):
restorers.append(
ObjSetSideEffectRestorer(
var,
record.key,
record.value,
)
)
elif isinstance(record, MutationDel):
restorers.append(
ObjDelSideEffectRestorer(
var,
record.key,
)
)

for restorer in restorers:
restorer.pre_gen(self.pycode_gen)
for restorer in restorers[::-1]:
Expand Down
29 changes: 13 additions & 16 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,25 +976,22 @@ def LOAD_METHOD(self, instr: Instruction):
self.stack.push(NullVariable())
self.stack.push(method)

def STORE_ATTR(self, instr):
def STORE_ATTR(self, instr: Instruction):
obj = self.stack.pop()
val = self.stack.pop()
key = self._code.co_names[instr.arg]
if isinstance(obj, TensorVariable):
# support tensor variable store attr, like:
# t.stop_gradient = True
obj.graph.call_tensor_method(
"__setattr__",
obj,
VariableFactory().from_value(
key, self._graph, ConstTracker(key)
),
val,
)
else:
raise BreakGraphError(
f"STORE_ATTR don't support {type(obj)}.{key}={val}"
)
key_var = ConstantVariable.wrap_literal(key, self._graph)
BuiltinVariable(
setattr, self._graph, DummyTracker([obj, key_var, val])
)(obj, key_var, val)

def DELETE_ATTR(self, instr: Instruction):
obj = self.stack.pop()
key = instr.argval
key_var = ConstantVariable.wrap_literal(key, self._graph)
BuiltinVariable(delattr, self._graph, DummyTracker([obj, key_var]))(
obj, key_var
)

def STORE_DEREF(self, instr: Instruction):
if sys.version_info >= (3, 11):
Expand Down
12 changes: 12 additions & 0 deletions sot/opcode_translator/executor/pycode_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,18 @@ def gen_load_attr(self, name: str):
idx = self._code_options["co_names"].index(name)
self._add_instr("LOAD_ATTR", arg=idx, argval=name)

def gen_store_attr(self, name: str):
if name not in self._code_options["co_names"]:
self._code_options["co_names"].append(name)
idx = self._code_options["co_names"].index(name)
self._add_instr("STORE_ATTR", arg=idx, argval=name)

def gen_delete_attr(self, name: str):
if name not in self._code_options["co_names"]:
self._code_options["co_names"].append(name)
idx = self._code_options["co_names"].index(name)
self._add_instr("DELETE_ATTR", arg=idx, argval=name)

def gen_load_method(self, name: str):
if name not in self._code_options["co_names"]:
self._code_options["co_names"].append(name)
Expand Down
38 changes: 38 additions & 0 deletions sot/opcode_translator/executor/side_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,41 @@ def pre_gen(self, codegen: PyCodeGen):

def post_gen(self, codegen: PyCodeGen):
codegen.gen_delete_global(self.name)


class ObjSetSideEffectRestorer(SideEffectRestorer):
"""
obj.attr = new_value
"""

def __init__(self, obj: VariableBase, name: str, var: VariableBase):
super().__init__()
self.obj = obj
self.name = name
self.var = var

def pre_gen(self, codegen: PyCodeGen):
# value
self.var.reconstruct(codegen)
# obj
self.obj.reconstruct(codegen)

def post_gen(self, codegen: PyCodeGen):
codegen.gen_store_attr(self.name)


class ObjDelSideEffectRestorer(SideEffectRestorer):
"""
del obj.attr
"""

def __init__(self, obj: VariableBase, name: str):
super().__init__()
self.obj = obj
self.name = name

def pre_gen(self, codegen: PyCodeGen):
self.obj.reconstruct(codegen)

def post_gen(self, codegen: PyCodeGen):
codegen.gen_delete_attr(self.name)
43 changes: 29 additions & 14 deletions sot/opcode_translator/executor/variable_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@
from .variables import DataVariable, TensorVariable


def add_guard(var: VariableBase):
var.graph.add_global_guarded_variable(var)
return var


def raise_err_handle(error):
def inner(*args, **kwargs):
raise error
Expand All @@ -52,7 +57,7 @@ def inner(*args, **kwargs):
# slice
Dispatcher.register(
slice,
("VariableBase"),
("VariableBase",),
lambda stop: SliceVariable(
slice(stop),
graph=stop.graph,
Expand Down Expand Up @@ -406,26 +411,36 @@ def dispatch_dict_fromkeys(seq: ListVariable | TupleVariable, default: VariableB
Dispatcher.register(
getattr,
("VariableBase", "ConstantVariable", optional("VariableBase")),
lambda var, name, default=None: (
var.graph.add_global_guarded_variable(name),
var.getattr(name.get_py_value(), default),
)[1],
lambda var, name, default=None: var.getattr(
add_guard(name).get_py_value(), default
),
)

# hasattr
Dispatcher.register(
hasattr,
("VariableBase", "ConstantVariable"),
lambda var, name: var.hasattr(add_guard(name).get_py_value()),
)

Dispatcher.register(
delattr,
("VariableBase", "VariableBase"),
lambda var, name: var.delattr(add_guard(name).get_py_value()),
)

Dispatcher.register(
setattr,
("VariableBase", "VariableBase", "VariableBase"),
lambda var, name, value: var.setattr(add_guard(name).get_py_value(), value),
)

# len
Dispatcher.register(
len,
("ContainerVariable | UserDefinedLayerVariable",),
lambda var: var.len(),
)
# hasattr
Dispatcher.register(
hasattr,
("VariableBase", "ConstantVariable"),
lambda var, name: (
var.graph.add_global_guarded_variable(name),
var.hasattr(name.get_py_value()),
)[1],
)

# range
# stop
Expand Down
1 change: 1 addition & 0 deletions sot/opcode_translator/executor/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .callable import ( # noqa: F401
BuiltinVariable,
CallableVariable,
ClassVariable,
FunctionVariable,
LayerVariable,
MethodVariable,
Expand Down
14 changes: 14 additions & 0 deletions sot/opcode_translator/executor/variables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,20 @@ def getattr(self, name: str, default=None):
)
return result

def setattr(self, key: str, value):
from .basic import ConstantVariable

self.attr_proxy.set(key, value)
self.graph.side_effects.record_proxy_variable(self)
return ConstantVariable.wrap_literal(None, self.graph)

def delattr(self, key: str):
from .basic import ConstantVariable

self.attr_proxy.delete(key)
self.graph.side_effects.record_proxy_variable(self)
return ConstantVariable.wrap_literal(None, self.graph)

def __setitem__(self, key, value):
return self.setitem(key, value)

Expand Down
23 changes: 21 additions & 2 deletions sot/opcode_translator/executor/variables/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,19 @@ def getattr(self, name: str, default=None):
else:
raise HasNoAttributeError(f"Unknown Tensor attribute: {name}")

def setattr(self, key, val):
# support tensor variable store attr, like:
# t.stop_gradient = True
self.graph.call_tensor_method(
"__setattr__",
self,
VariableFactory().from_value(key, self.graph, ConstTracker(key)),
val,
)

def delattr(self, key):
raise BreakGraphError("Don't support TensorVariable delattr")

@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
if isinstance(value, (paddle.Tensor, MetaInfo)):
Expand All @@ -527,12 +540,12 @@ class ObjectVariable(VariableBase):
tracker(Tracker): The Tracker object that tracks the information of this variable.
"""

make_stringify_guard = object_equal_stringify_guard

def __init__(self, obj, graph, tracker):
super().__init__(graph, tracker)
self.value = obj

make_stringify_guard = object_equal_stringify_guard

@property
def main_info(self) -> dict[str, Any]:
return {"value": self.value}
Expand Down Expand Up @@ -619,6 +632,12 @@ def _reconstruct(self, codegen: PyCodeGen):
else:
super()._reconstruct(codegen)

def setattr(self, key, val):
raise BreakGraphError("Don't support SliceVariable setattr")

def delattr(self, key):
raise BreakGraphError("Don't support SliceVariable delattr")

@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
if isinstance(value, slice):
Expand Down
Loading
Loading