Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[SideEffect] add side effects support for ObjectVariable #393

Merged
merged 20 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
GlobalDelSideEffectRestorer,
GlobalSetSideEffectRestorer,
ListSideEffectRestorer,
ObjDelSideEffectRestorer,
ObjSetSideEffectRestorer,
SideEffectRestorer,
SideEffects,
)
Expand All @@ -43,6 +45,7 @@
GlobalVariable,
ListVariable,
NullVariable,
ObjectVariable,
PaddleLayerVariable,
TensorVariable,
VariableBase,
Expand Down Expand Up @@ -555,7 +558,7 @@ def _find_tensor_outputs(
var, TensorVariable
):
output_tensors.add(var)
if isinstance(var, GlobalVariable):
if isinstance(var, (GlobalVariable, ObjectVariable)):
for record in var.proxy.records:
if (
isinstance(record, (MutationSet, MutationNew))
Expand Down Expand Up @@ -641,6 +644,24 @@ def restore_side_effects(self, variables: list[VariableBase]):
GlobalDelSideEffectRestorer(record.key)
)
# TODO: support attribute restore
elif isinstance(var, ObjectVariable):
for record in var.proxy.records[::-1]:
if isinstance(record, (MutationSet, MutationNew)):
restorers.append(
ObjSetSideEffectRestorer(
var,
record.key,
record.value,
)
)
elif isinstance(record, MutationDel):
restorers.append(
ObjDelSideEffectRestorer(
var,
record.key,
)
)

for restorer in restorers:
restorer.pre_gen(self.pycode_gen)
for restorer in restorers[::-1]:
Expand Down
10 changes: 6 additions & 4 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ def LOAD_METHOD(self, instr: Instruction):
self.stack.push(NullVariable())
self.stack.push(method)

def STORE_ATTR(self, instr):
def STORE_ATTR(self, instr: Instruction):
obj = self.stack.pop()
val = self.stack.pop()
key = self._code.co_names[instr.arg]
Expand All @@ -985,9 +985,11 @@ def STORE_ATTR(self, instr):
val,
)
else:
raise BreakGraphError(
f"STORE_ATTR don't support {type(obj)}.{key}={val}"
)
obj.setattr(key, val)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

直接用 BuiltinVariable dispatch 一下就行了,就没有 if-else 了


def DELETE_ATTR(self, instr: Instruction):
obj = self.stack.pop()
obj.delattr(instr.argval)

def STORE_DEREF(self, instr):
namemap = self._code.co_cellvars + self._code.co_freevars
Expand Down
12 changes: 12 additions & 0 deletions sot/opcode_translator/executor/pycode_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,18 @@ def gen_load_attr(self, name: str):
idx = self._code_options["co_names"].index(name)
self._add_instr("LOAD_ATTR", arg=idx, argval=name)

def gen_store_attr(self, name: str):
if name not in self._code_options["co_names"]:
self._code_options["co_names"].append(name)
idx = self._code_options["co_names"].index(name)
self._add_instr("STORE_ATTR", arg=idx, argval=name)

def gen_delete_attr(self, name: str):
if name not in self._code_options["co_names"]:
self._code_options["co_names"].append(name)
idx = self._code_options["co_names"].index(name)
self._add_instr("DELETE_ATTR", arg=idx, argval=name)

def gen_load_method(self, name: str):
if name not in self._code_options["co_names"]:
self._code_options["co_names"].append(name)
Expand Down
58 changes: 57 additions & 1 deletion sot/opcode_translator/executor/side_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar

from .mutable_data import MutableData
from .variables import VariableBase
from .variables import ObjectVariable, VariableBase

if TYPE_CHECKING:
from .mutable_data import DataGetter
Expand Down Expand Up @@ -180,3 +180,59 @@ def pre_gen(self, codegen: PyCodeGen):

def post_gen(self, codegen: PyCodeGen):
codegen.gen_delete_global(self.name)


class ObjSetSideEffectRestorer(SideEffectRestorer):
"""
class CustomObject:
def __init__(self):
self.x = 0

def attr_set(cus_obj):
cus_obj.x = 2

cus_obj = CustomObject()
attr_set(cus_obj)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

和其他示例一样,只写生成代码

obj.attr = new_value

"""

def __init__(self, obj: ObjectVariable, name: str, var: VariableBase):
super().__init__()
self.obj = obj
self.name = name
self.var = var

def pre_gen(self, codegen: PyCodeGen):
self.var.reconstruct(codegen)

def post_gen(self, codegen: PyCodeGen):
# TODO(gouzil): Find this name, doc: cus_obj
codegen.gen_load_fast(self.obj.tracker.name)
codegen.gen_store_attr(self.name)


class ObjDelSideEffectRestorer(SideEffectRestorer):
"""
class CustomObject:
def __init__(self):
self.x = 0

def attr_set(cus_obj):
del cus_obj.x

cus_obj = CustomObject()
attr_set(cus_obj)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

"""

def __init__(self, obj: ObjectVariable, name: str):
super().__init__()
self.obj = obj
self.name = name

def pre_gen(self, codegen: PyCodeGen):
# do nothing
...

def post_gen(self, codegen: PyCodeGen):
# TODO(gouzil): Find this name, doc: cus_obj
codegen.gen_load_fast(self.obj.tracker.name)
codegen.gen_delete_attr(self.name)
2 changes: 1 addition & 1 deletion sot/opcode_translator/executor/variable_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def inner(*args, **kwargs):
# slice
Dispatcher.register(
slice,
("VariableBase"),
("VariableBase",),
lambda stop: SliceVariable(
slice(stop),
graph=stop.graph,
Expand Down
14 changes: 14 additions & 0 deletions sot/opcode_translator/executor/variables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,20 @@ def getattr(self, name: str, default=None):
)
return result

def setattr(self, key: str, value):
from .basic import ConstantVariable

self.attr_proxy.set(key, value)
self.graph.side_effects.record_proxy_variable(self)
return ConstantVariable.wrap_literal(None, self.graph)

def delattr(self, key: str):
from .basic import ConstantVariable

self.attr_proxy.delete(key)
self.graph.side_effects.record_proxy_variable(self)
return ConstantVariable.wrap_literal(None, self.graph)

def __setitem__(self, key, value):
return self.setitem(key, value)

Expand Down
17 changes: 15 additions & 2 deletions sot/opcode_translator/executor/variables/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,12 @@ def getattr(self, name: str, default=None):
else:
raise HasNoAttributeError(f"Unknown Tensor attribute: {name}")

def setattr(self, key, val):
raise BreakGraphError("Don't support TensorVariable setattr")

def delattr(self, key):
raise BreakGraphError("Don't support TensorVariable delattr")

@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
if isinstance(value, (paddle.Tensor, MetaInfo)):
Expand All @@ -527,11 +533,12 @@ class ObjectVariable(VariableBase):
tracker(Tracker): The Tracker object that tracks the information of this variable.
"""

make_stringify_guard = object_equal_stringify_guard

def __init__(self, obj, graph, tracker):
super().__init__(graph, tracker)
self.value = obj

make_stringify_guard = object_equal_stringify_guard
self.proxy = self.attr_proxy

@property
def main_info(self) -> dict[str, Any]:
Expand Down Expand Up @@ -619,6 +626,12 @@ def _reconstruct(self, codegen: PyCodeGen):
else:
super()._reconstruct(codegen)

def setattr(self, key, val):
raise BreakGraphError("Don't support SliceVariable setattr")

def delattr(self, key):
raise BreakGraphError("Don't support SliceVariable delattr")

@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
if isinstance(value, slice):
Expand Down
2 changes: 2 additions & 0 deletions sot/opcode_translator/executor/variables/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ def __init__(
self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker
):
super().__init__(layer, graph, tracker)
self.proxy = self.attr_proxy
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这是为什么?proxyattr_proxy 不是一种东西


def call_function(self, /, *args, **kwargs):
fn_var = UserDefinedFunctionVariable(
Expand Down Expand Up @@ -634,6 +635,7 @@ class ClassVariable(CallableVariable):
def __init__(self, class_: type, graph: FunctionGraph, tracker: Tracker):
super().__init__(graph, tracker)
self.value = class_
self.proxy = self.attr_proxy

def get_py_value(self, allow_tensor=False):
return self.value
Expand Down
28 changes: 26 additions & 2 deletions tests/test_side_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,15 @@ def __init__(self):
self.x = 0


def object_attr(cus_obj, t):
def object_attr_set(cus_obj, t):
"""object side effect."""
t = t + 1
cus_obj.x = t
return t, cus_obj
return t, cus_obj.x
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单测太少了,多加一些吧



def object_attr_del(cus_obj):
del cus_obj.x


def slice_list_after_change(l):
Expand Down Expand Up @@ -252,5 +256,25 @@ def test_slice_list_after_change(self):
)


class TestATTRSideEffect(TestCaseBase):
def attr_check(self, func, attr_keys: list[str], obj, *inputs):
cus_obj1 = obj()
cus_obj2 = obj()
sym_output = symbolic_translate(func)(cus_obj1, *inputs)
paddle_output = func(cus_obj2, *inputs)
for key in attr_keys:
self.assert_nest_match(
getattr(cus_obj1, key, "Key does not exist"),
getattr(cus_obj2, key, "Key does not exist"),
)
self.assert_nest_match(sym_output, paddle_output)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同样最好加一个机制,在跑动态图和 SOT 前都先 store 一个 attr,然后跑完就恢复

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有 demo 嘛,目前是只验证传入的那个修改的 key 的值


def test_attr_set(self):
self.attr_check(object_attr_set, ["x"], CustomObject, 5)

def test_attr_del(self):
self.attr_check(object_attr_del, ["x"], CustomObject)


if __name__ == "__main__":
unittest.main()
Loading