Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Only skip Sequential forward in container.py (#281)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f authored Jul 19, 2023
1 parent 6a6ff40 commit f76b248
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 8 deletions.
4 changes: 3 additions & 1 deletion sot/opcode_translator/executor/variables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,9 @@ def _reconstruct(self, codegen: PyCodeGen):
"""
Abstract method to construct an opcode and append it into codegen.instructions
"""
raise NotImplementException()
raise NotImplementException(
'VariableBase._reconstruct() do not implement'
)

def flatten_items(self) -> list[VariableBase]:
"""
Expand Down
11 changes: 11 additions & 0 deletions sot/opcode_translator/executor/variables/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,17 @@ def main_info(self) -> dict[str, Any]:
def get_py_value(self, allow_tensor=False):
return self.value

def _reconstruct(self, codegen: PyCodeGen):
# TODO(dev): Consider the case where there are tensors in the slice
if all(
isinstance(x, int) or x is None
for x in [self.value.start, self.value.stop, self.value.step]
):
self.graph.add_global_guarded_variable(self)
codegen.gen_load_const(self.value)
else:
super()._reconstruct(codegen)

@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
if isinstance(value, slice):
Expand Down
8 changes: 6 additions & 2 deletions sot/opcode_translator/executor/variables/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@ def init_value(self):
return self.value

def get_items(self) -> list[VariableBase]:
raise NotImplementException()
raise NotImplementException(
'ContainerVariable.get_items do not implement'
)

def __len__(self):
raise NotImplementException()
raise NotImplementException(
'ContainerVariable.__len__ do not implement'
)

def len(self):
return VariableFactory.from_value(
Expand Down
11 changes: 6 additions & 5 deletions sot/opcode_translator/skip_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import numpy
import setuptools

import paddle

from ..utils import log


Expand Down Expand Up @@ -111,11 +113,10 @@ def _module_dir(m: types.ModuleType):
f"^({'|'.join(map(re.escape, skip_file_names))})"
)

no_skip_file_names = {paddle_path + 'nn/layer/container.py'}


customed_skip_code = set()

no_skip_code = {paddle.nn.Sequential.forward.__code__}


def need_skip_path(filepath: str) -> bool:
"""
Expand All @@ -127,8 +128,6 @@ def need_skip_path(filepath: str) -> bool:
Returns:
bool: True if the file should be skipped.
"""
if filepath in no_skip_file_names:
return False
if not filepath.startswith("<"):
filepath = os.path.abspath(filepath)
return bool(skip_file_name_re.match(filepath))
Expand All @@ -139,6 +138,8 @@ def skip_function(function):


def need_skip(pycode):
if pycode in no_skip_code:
return False
if pycode in customed_skip_code:
log(3, f"Skip frame by code: {pycode}")
return True
Expand Down
21 changes: 21 additions & 0 deletions tests/test_15_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,26 @@ def test_tensor_subscript_ellipsis(self):
self.assert_results(tensor_subscript_ellipsis, x, y)


class LayerListNet(paddle.nn.Layer):
def __init__(self) -> None:
super().__init__()
self.layer_list = paddle.nn.LayerList(
[paddle.nn.Linear(5, 5), paddle.nn.Linear(5, 5)]
)

def forward(self, x):
out = self.layer_list[0](x)
for layer in self.layer_list[1:]:
out = layer(out)
return out


class TestLayerListSlice(TestCaseBase):
def test_layer_list_slice(self):
x = paddle.randn([2, 5])
net = LayerListNet()
self.assert_results(layer_list_slice, net, x)


if __name__ == "__main__":
unittest.main()

0 comments on commit f76b248

Please sign in to comment.