From 4986057be69470f1f64eba0a63e18b66399d2f6a Mon Sep 17 00:00:00 2001 From: Jhalak Patel Date: Wed, 6 Nov 2024 10:41:09 -0800 Subject: [PATCH] Address review comments --- tripy/tests/backend/api/test_executable.py | 2 +- tripy/tripy/backend/api/executable.py | 10 ++-------- tripy/tripy/frontend/tensor.py | 6 +++--- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/tripy/tests/backend/api/test_executable.py b/tripy/tests/backend/api/test_executable.py index 40ac8d01a..3b588b463 100644 --- a/tripy/tests/backend/api/test_executable.py +++ b/tripy/tests/backend/api/test_executable.py @@ -87,7 +87,7 @@ def test_signature(self, single_return_executable): assert param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD assert param.annotation == tp.Tensor - assert signature.return_annotation == Sequence[tp.Tensor] + assert signature.return_annotation == tp.Tensor def test_signature_multiple_return_values(self, multiple_return_executable): signature = inspect.signature(multiple_return_executable) diff --git a/tripy/tripy/backend/api/executable.py b/tripy/tripy/backend/api/executable.py index 8c9771531..3f9f2395c 100644 --- a/tripy/tripy/backend/api/executable.py +++ b/tripy/tripy/backend/api/executable.py @@ -49,7 +49,7 @@ def __init__(self, executable, arg_names): for name in self._arg_names: params.append(inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Tensor)) - return_annotation = Tensor if self._executable_signature.get_num_output_args() == 1 else Sequence[Tensor] + return_annotation = Tensor if self._executable_signature.get_num_results() == 1 else Sequence[Tensor] self.__signature__ = inspect.Signature(params, return_annotation=return_annotation) @@ -227,15 +227,9 @@ def add(a, b): print(compiled_add.get_output_info()) """ num_input_args = self._executable_signature.get_num_input_args() - num_output_args = self._executable_signature.get_num_output_args() num_results = self._executable_signature.get_num_results() - assert not (num_output_args and num_results), "Cannot have both output arguments and results" - - if num_output_args: - return [self._get_arg_info(idx + num_input_args) for idx in range(num_output_args)] - else: - return [self._get_result_info(idx) for idx in range(num_results)] + return [self._get_result_info(idx) for idx in range(num_results)] def save(self, path: str) -> None: """ diff --git a/tripy/tripy/frontend/tensor.py b/tripy/tripy/frontend/tensor.py index 28cbb646d..8516c50eb 100644 --- a/tripy/tripy/frontend/tensor.py +++ b/tripy/tripy/frontend/tensor.py @@ -186,11 +186,11 @@ def eval(self) -> runtime.MemRefValue: compiler = Compiler(trt_builder_opt_level=0) executable = compiler.compile(mlir, flat_ir=flat_ir) # Ensure that session and client are available as long as tensor lives. - self.executor = Executor(executable) + executor = Executor(executable) # Upon computing the value of this tensor, we switch it to have a `Storage` # parameter so that it does not need to be computed again. - data = self.executor.execute() - self.executor.stream.synchronize() + data = executor.execute() + executor.stream.synchronize() assert len(data) == 1, "Expects only one output from mlir_tensorrt.compiler executor" data = data[0]