diff --git a/sot/opcode_translator/executor/function_graph.py b/sot/opcode_translator/executor/function_graph.py index a8b0760ac..21f47e8e2 100644 --- a/sot/opcode_translator/executor/function_graph.py +++ b/sot/opcode_translator/executor/function_graph.py @@ -271,4 +271,6 @@ def _find_tensor_outputs( output.tracker, DummyTracker ): output_tensors.append(output) + else: + self.add_global_guarded_variable(output) return output_tensors diff --git a/tests/test_guard_outputs.py b/tests/test_guard_outputs.py new file mode 100644 index 000000000..6c569d115 --- /dev/null +++ b/tests/test_guard_outputs.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import unittest + +from test_case_base import ( + TestCaseBase, + test_instruction_translator_cache_context, +) + +import paddle + + +def non_operator_related_fn(x: int, y: int): + return x + y + + +def partial_non_operator_related_fn(x: paddle.Tensor, y: paddle.Tensor, z: int): + a = x + y + return [a, z + z] + + +class TestGuardOutputs(TestCaseBase): + def test_non_operator_related_fn(self): + with test_instruction_translator_cache_context() as ctx: + self.assert_results(non_operator_related_fn, 1, 2) + self.assertEqual(ctx.translate_count, 1) + self.assert_results(non_operator_related_fn, 3, 4) + self.assertEqual(ctx.translate_count, 2) + + def test_partial_non_operator_related_fn(self): + with test_instruction_translator_cache_context() as ctx: + self.assert_results( + partial_non_operator_related_fn, + paddle.to_tensor(1), + paddle.to_tensor(2), + 3, + ) + self.assertEqual(ctx.translate_count, 1) + self.assert_results( + partial_non_operator_related_fn, + paddle.to_tensor(4), + paddle.to_tensor(5), + 6, + ) + self.assertEqual(ctx.translate_count, 2) + + +if __name__ == "__main__": + unittest.main()