Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jhalakpatel committed Nov 9, 2024
1 parent 75f1abb commit 4986057
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 12 deletions.
2 changes: 1 addition & 1 deletion tripy/tests/backend/api/test_executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_signature(self, single_return_executable):
assert param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
assert param.annotation == tp.Tensor

assert signature.return_annotation == Sequence[tp.Tensor]
assert signature.return_annotation == tp.Tensor

def test_signature_multiple_return_values(self, multiple_return_executable):
signature = inspect.signature(multiple_return_executable)
Expand Down
10 changes: 2 additions & 8 deletions tripy/tripy/backend/api/executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, executable, arg_names):
for name in self._arg_names:
params.append(inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Tensor))

return_annotation = Tensor if self._executable_signature.get_num_output_args() == 1 else Sequence[Tensor]
return_annotation = Tensor if self._executable_signature.get_num_results() == 1 else Sequence[Tensor]

self.__signature__ = inspect.Signature(params, return_annotation=return_annotation)

Expand Down Expand Up @@ -227,15 +227,9 @@ def add(a, b):
print(compiled_add.get_output_info())
"""
num_input_args = self._executable_signature.get_num_input_args()
num_output_args = self._executable_signature.get_num_output_args()
num_results = self._executable_signature.get_num_results()

assert not (num_output_args and num_results), "Cannot have both output arguments and results"

if num_output_args:
return [self._get_arg_info(idx + num_input_args) for idx in range(num_output_args)]
else:
return [self._get_result_info(idx) for idx in range(num_results)]
return [self._get_result_info(idx) for idx in range(num_results)]

def save(self, path: str) -> None:
"""
Expand Down
6 changes: 3 additions & 3 deletions tripy/tripy/frontend/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ def eval(self) -> runtime.MemRefValue:
compiler = Compiler(trt_builder_opt_level=0)
executable = compiler.compile(mlir, flat_ir=flat_ir)
# Ensure that session and client are available as long as tensor lives.
self.executor = Executor(executable)
executor = Executor(executable)
# Upon computing the value of this tensor, we switch it to have a `Storage`
# parameter so that it does not need to be computed again.
data = self.executor.execute()
self.executor.stream.synchronize()
data = executor.execute()
executor.stream.synchronize()
assert len(data) == 1, "Expects only one output from mlir_tensorrt.compiler executor"
data = data[0]

Expand Down

0 comments on commit 4986057

Please sign in to comment.