Skip to content

Commit

Permalink
Speeds up datatype constraints tests
Browse files Browse the repository at this point in the history
Makes datatype constraints tests significantly faster by not evaluating
inputs.
  • Loading branch information
pranavm-nvidia committed Nov 4, 2024
1 parent 3a77a99 commit 3ac751b
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions tripy/tests/constraints/object_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@
def tensor_builder(init, dtype, namespace):
if init is None:
out = tp.ones(dtype=namespace[dtype], shape=(3, 2))
out.eval()
return out
elif not isinstance(init, tp.Tensor):
return init

out = init
if dtype is not None:
out = tp.cast(out, dtype=namespace[dtype])
out.eval()
# Need to evaluate when casting because we run into MLIR-TRT bugs while deriving upper bounds.
out.eval()
return out


Expand All @@ -47,8 +47,6 @@ def tensor_list_builder(init, dtype, namespace):
out = [tp.ones(shape=(3, 2), dtype=namespace[dtype]) for _ in range(2)]
else:
out = [tp.cast(tens, dtype=namespace[dtype]) for tens in init]
for t in out:
t.eval()
return out


Expand Down Expand Up @@ -132,7 +130,7 @@ def default_builder(init, dtype, namespace):
"pad": {"pad": [(0, 1), (1, 0)]},
"permute": {"perm": [1, 0]},
"prod": {"dim": 0},
"quantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0},
"quantize": {"input": tp.ones((3, 2)), "scale": tp.Tensor([1, 1, 1]), "dim": 0},
"repeat": {"repeats": 2, "dim": 0},
"reshape": {"shape": [6]},
"resize": {
Expand Down

0 comments on commit 3ac751b

Please sign in to comment.