-
Notifications
You must be signed in to change notification settings - Fork 26
[SideEffect] add side effects support for ObjectVariable #393
Changes from 7 commits
d7a1ad2
962720f
7d67941
b6cd0e5
ebdbfb0
17622cb
db66dbe
f33f00a
a134040
bcec675
2c2fae0
446105c
44f9b99
719f89e
cb9a4cd
25b6d73
842f5c2
4431dab
3cdcad7
34220de
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar | ||
|
||
from .mutable_data import MutableData | ||
from .variables import VariableBase | ||
from .variables import ObjectVariable, VariableBase | ||
|
||
if TYPE_CHECKING: | ||
from .mutable_data import DataGetter | ||
|
@@ -180,3 +180,59 @@ def pre_gen(self, codegen: PyCodeGen): | |
|
||
def post_gen(self, codegen: PyCodeGen): | ||
codegen.gen_delete_global(self.name) | ||
|
||
|
||
class ObjSetSideEffectRestorer(SideEffectRestorer): | ||
""" | ||
class CustomObject: | ||
def __init__(self): | ||
self.x = 0 | ||
|
||
def attr_set(cus_obj): | ||
cus_obj.x = 2 | ||
|
||
cus_obj = CustomObject() | ||
attr_set(cus_obj) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 和其他示例一样,只写生成代码
|
||
""" | ||
|
||
def __init__(self, obj: ObjectVariable, name: str, var: VariableBase): | ||
super().__init__() | ||
self.obj = obj | ||
self.name = name | ||
self.var = var | ||
|
||
def pre_gen(self, codegen: PyCodeGen): | ||
self.var.reconstruct(codegen) | ||
|
||
def post_gen(self, codegen: PyCodeGen): | ||
# TODO(gouzil): Find this name, doc: cus_obj | ||
codegen.gen_load_fast(self.obj.tracker.name) | ||
codegen.gen_store_attr(self.name) | ||
|
||
|
||
class ObjDelSideEffectRestorer(SideEffectRestorer): | ||
""" | ||
class CustomObject: | ||
def __init__(self): | ||
self.x = 0 | ||
|
||
def attr_set(cus_obj): | ||
del cus_obj.x | ||
|
||
cus_obj = CustomObject() | ||
attr_set(cus_obj) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 |
||
""" | ||
|
||
def __init__(self, obj: ObjectVariable, name: str): | ||
super().__init__() | ||
self.obj = obj | ||
self.name = name | ||
|
||
def pre_gen(self, codegen: PyCodeGen): | ||
# do nothing | ||
... | ||
|
||
def post_gen(self, codegen: PyCodeGen): | ||
# TODO(gouzil): Find this name, doc: cus_obj | ||
codegen.gen_load_fast(self.obj.tracker.name) | ||
codegen.gen_delete_attr(self.name) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -414,6 +414,7 @@ def __init__( | |
self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker | ||
): | ||
super().__init__(layer, graph, tracker) | ||
self.proxy = self.attr_proxy | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这是为什么? |
||
|
||
def call_function(self, /, *args, **kwargs): | ||
fn_var = UserDefinedFunctionVariable( | ||
|
@@ -634,6 +635,7 @@ class ClassVariable(CallableVariable): | |
def __init__(self, class_: type, graph: FunctionGraph, tracker: Tracker): | ||
super().__init__(graph, tracker) | ||
self.value = class_ | ||
self.proxy = self.attr_proxy | ||
|
||
def get_py_value(self, allow_tensor=False): | ||
return self.value | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -126,11 +126,15 @@ def __init__(self): | |
self.x = 0 | ||
|
||
|
||
def object_attr(cus_obj, t): | ||
def object_attr_set(cus_obj, t): | ||
"""object side effect.""" | ||
t = t + 1 | ||
cus_obj.x = t | ||
return t, cus_obj | ||
return t, cus_obj.x | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 单测太少了,多加一些吧 |
||
|
||
|
||
def object_attr_del(cus_obj): | ||
del cus_obj.x | ||
|
||
|
||
def slice_list_after_change(l): | ||
|
@@ -252,5 +256,25 @@ def test_slice_list_after_change(self): | |
) | ||
|
||
|
||
class TestATTRSideEffect(TestCaseBase): | ||
def attr_check(self, func, attr_keys: list[str], obj, *inputs): | ||
cus_obj1 = obj() | ||
cus_obj2 = obj() | ||
sym_output = symbolic_translate(func)(cus_obj1, *inputs) | ||
paddle_output = func(cus_obj2, *inputs) | ||
for key in attr_keys: | ||
self.assert_nest_match( | ||
getattr(cus_obj1, key, "Key does not exist"), | ||
getattr(cus_obj2, key, "Key does not exist"), | ||
) | ||
self.assert_nest_match(sym_output, paddle_output) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同样最好加一个机制,在跑动态图和 SOT 前都先 store 一个 attr,然后跑完就恢复 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 有 demo 嘛,目前是只验证传入的那个修改的 key 的值 |
||
|
||
def test_attr_set(self): | ||
self.attr_check(object_attr_set, ["x"], CustomObject, 5) | ||
|
||
def test_attr_del(self): | ||
self.attr_check(object_attr_del, ["x"], CustomObject) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接用 BuiltinVariable dispatch 一下就行了,就没有 if-else 了