Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Fix cache miss when enumerate obj as resume fn param #253

Merged
merged 2 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,12 @@ def _put_inner(self, var):
)

def add_global_guarded_variable(self, variable: VariableBase):
self._global_guarded_variables.append(variable)
if variable not in self._global_guarded_variables:
self._global_guarded_variables.append(variable)

def remove_global_guarded_variable(self, variable: VariableBase):
if variable in self._global_guarded_variables:
self._global_guarded_variables.remove(variable)

def _find_tensor_outputs(
self, outputs: list[VariableBase]
Expand Down
1 change: 1 addition & 0 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1758,6 +1758,7 @@ def FOR_ITER(self, instr):
if not isinstance(
iterator, (SequenceIterVariable, DictIterVariable)
):
self._graph.remove_global_guarded_variable(iterator)
raise BreakGraphError()
backup_iter_idx = iterator.idx
self._inline_call_for_loop(iterator, instr)
Expand Down
1 change: 1 addition & 0 deletions sot/opcode_translator/executor/opcode_inline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,4 +305,5 @@ def FOR_ITER(self, instr: Instruction):
self._lasti = self.indexof(instr.jump_to)

else:
self._graph.remove_global_guarded_variable(iterator)
raise BreakGraphError("For loop fallback.")
24 changes: 24 additions & 0 deletions tests/test_12_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

import paddle
from sot import symbolic_translate
from sot.opcode_translator.executor.opcode_executor import (
InstructionTranslatorCache,
)


def gener():
Expand Down Expand Up @@ -157,6 +160,27 @@ def test_list_comp(self):
self.assert_results(run_list_comp, x)


def for_enumerate_cache(func_list, x):
out = None
for idx, func in enumerate(func_list):
out = func(x[idx])
return out


class TestEnumerateCache(TestCaseBase):
def test_run(self):
func_list = [
paddle.nn.Linear(10, 10),
]
x = [
paddle.randn([5, 10]),
]

out = symbolic_translate(for_enumerate_cache)(func_list, x)
out = symbolic_translate(for_enumerate_cache)(func_list, x)
self.assert_nest_match(InstructionTranslatorCache().translate_count, 4)


if __name__ == "__main__":
with strict_mode_guard(0 if sys.version_info >= (3, 10) else 1):
unittest.main()