Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
guard all outputs (#148)
Browse files Browse the repository at this point in the history
* change project links to `PaddlePaddle/PaddleSOT`

* bump black to 2023 style

* change project name

* `symbolic_opcode_translator` -> `sot`

* guard all outputs
  • Loading branch information
SigureMo authored Jun 12, 2023
1 parent 947e836 commit 11994fc
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
2 changes: 2 additions & 0 deletions sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,4 +271,6 @@ def _find_tensor_outputs(
output.tracker, DummyTracker
):
output_tensors.append(output)
else:
self.add_global_guarded_variable(output)
return output_tensors
49 changes: 49 additions & 0 deletions tests/test_guard_outputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

import unittest

from test_case_base import (
TestCaseBase,
test_instruction_translator_cache_context,
)

import paddle


def non_operator_related_fn(x: int, y: int):
return x + y


def partial_non_operator_related_fn(x: paddle.Tensor, y: paddle.Tensor, z: int):
a = x + y
return [a, z + z]


class TestGuardOutputs(TestCaseBase):
def test_non_operator_related_fn(self):
with test_instruction_translator_cache_context() as ctx:
self.assert_results(non_operator_related_fn, 1, 2)
self.assertEqual(ctx.translate_count, 1)
self.assert_results(non_operator_related_fn, 3, 4)
self.assertEqual(ctx.translate_count, 2)

def test_partial_non_operator_related_fn(self):
with test_instruction_translator_cache_context() as ctx:
self.assert_results(
partial_non_operator_related_fn,
paddle.to_tensor(1),
paddle.to_tensor(2),
3,
)
self.assertEqual(ctx.translate_count, 1)
self.assert_results(
partial_non_operator_related_fn,
paddle.to_tensor(4),
paddle.to_tensor(5),
6,
)
self.assertEqual(ctx.translate_count, 2)


if __name__ == "__main__":
unittest.main()

0 comments on commit 11994fc

Please sign in to comment.