From f76b2484e1077abb5d18622227900b29f4818801 Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Wed, 19 Jul 2023 17:40:07 +0800 Subject: [PATCH] Only skip Sequential forward in container.py (#281) --- .../executor/variables/base.py | 4 +++- .../executor/variables/basic.py | 11 ++++++++++ .../executor/variables/container.py | 8 +++++-- sot/opcode_translator/skip_files.py | 11 +++++----- tests/test_15_slice.py | 21 +++++++++++++++++++ 5 files changed, 47 insertions(+), 8 deletions(-) diff --git a/sot/opcode_translator/executor/variables/base.py b/sot/opcode_translator/executor/variables/base.py index c672133ff..3a00dfbfe 100644 --- a/sot/opcode_translator/executor/variables/base.py +++ b/sot/opcode_translator/executor/variables/base.py @@ -338,7 +338,9 @@ def _reconstruct(self, codegen: PyCodeGen): """ Abstract method to construct an opcode and append it into codegen.instructions """ - raise NotImplementException() + raise NotImplementException( + 'VariableBase._reconstruct() do not implement' + ) def flatten_items(self) -> list[VariableBase]: """ diff --git a/sot/opcode_translator/executor/variables/basic.py b/sot/opcode_translator/executor/variables/basic.py index f530924d6..6512d3eb1 100644 --- a/sot/opcode_translator/executor/variables/basic.py +++ b/sot/opcode_translator/executor/variables/basic.py @@ -512,6 +512,17 @@ def main_info(self) -> dict[str, Any]: def get_py_value(self, allow_tensor=False): return self.value + def _reconstruct(self, codegen: PyCodeGen): + # TODO(dev): Consider the case where there are tensors in the slice + if all( + isinstance(x, int) or x is None + for x in [self.value.start, self.value.stop, self.value.step] + ): + self.graph.add_global_guarded_variable(self) + codegen.gen_load_const(self.value) + else: + super()._reconstruct(codegen) + @VariableFactory.register_from_value() def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): if isinstance(value, slice): diff --git a/sot/opcode_translator/executor/variables/container.py b/sot/opcode_translator/executor/variables/container.py index a0c44ce5c..01cda83b2 100644 --- a/sot/opcode_translator/executor/variables/container.py +++ b/sot/opcode_translator/executor/variables/container.py @@ -29,10 +29,14 @@ def init_value(self): return self.value def get_items(self) -> list[VariableBase]: - raise NotImplementException() + raise NotImplementException( + 'ContainerVariable.get_items do not implement' + ) def __len__(self): - raise NotImplementException() + raise NotImplementException( + 'ContainerVariable.__len__ do not implement' + ) def len(self): return VariableFactory.from_value( diff --git a/sot/opcode_translator/skip_files.py b/sot/opcode_translator/skip_files.py index 2b8983806..f05f3ef81 100644 --- a/sot/opcode_translator/skip_files.py +++ b/sot/opcode_translator/skip_files.py @@ -40,6 +40,8 @@ import numpy import setuptools +import paddle + from ..utils import log @@ -111,11 +113,10 @@ def _module_dir(m: types.ModuleType): f"^({'|'.join(map(re.escape, skip_file_names))})" ) -no_skip_file_names = {paddle_path + 'nn/layer/container.py'} - - customed_skip_code = set() +no_skip_code = {paddle.nn.Sequential.forward.__code__} + def need_skip_path(filepath: str) -> bool: """ @@ -127,8 +128,6 @@ def need_skip_path(filepath: str) -> bool: Returns: bool: True if the file should be skipped. """ - if filepath in no_skip_file_names: - return False if not filepath.startswith("<"): filepath = os.path.abspath(filepath) return bool(skip_file_name_re.match(filepath)) @@ -139,6 +138,8 @@ def skip_function(function): def need_skip(pycode): + if pycode in no_skip_code: + return False if pycode in customed_skip_code: log(3, f"Skip frame by code: {pycode}") return True diff --git a/tests/test_15_slice.py b/tests/test_15_slice.py index 8f5d59563..d9ff77d9f 100644 --- a/tests/test_15_slice.py +++ b/tests/test_15_slice.py @@ -85,5 +85,26 @@ def test_tensor_subscript_ellipsis(self): self.assert_results(tensor_subscript_ellipsis, x, y) +class LayerListNet(paddle.nn.Layer): + def __init__(self) -> None: + super().__init__() + self.layer_list = paddle.nn.LayerList( + [paddle.nn.Linear(5, 5), paddle.nn.Linear(5, 5)] + ) + + def forward(self, x): + out = self.layer_list[0](x) + for layer in self.layer_list[1:]: + out = layer(out) + return out + + +class TestLayerListSlice(TestCaseBase): + def test_layer_list_slice(self): + x = paddle.randn([2, 5]) + net = LayerListNet() + self.assert_results(layer_list_slice, net, x) + + if __name__ == "__main__": unittest.main()