Skip to content

Commit

Permalink
limit custom comparison to nvfuser executor
Browse files Browse the repository at this point in the history
  • Loading branch information
beverlylytle committed Dec 12, 2024
1 parent 510d925 commit bfe98e3
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -1808,7 +1808,7 @@ def hardshrink_singularity_fn_producer(sample: SampleInput):
ltorch.tanhshrink,
dtypes=(datatypes.inexact,),
sample_input_generator=elementwise_unary_generator,
torch_reference=torch.nn.functional.tanhshrink,
torch_reference=_elementwise_unary_torch(torch.nn.functional.tanhshrink),
test_directives=(
# Torch doesn't support CPU float16 or complex32 tanhshrink
DecorateInfo(
Expand All @@ -1819,6 +1819,7 @@ def hardshrink_singularity_fn_producer(sample: SampleInput):
),
DecorateInfo(
custom_comparator(partial(assert_close, atol=1e-2, rtol=1e-2)),
executors=("nvfuser",),
dtypes=(
datatypes.float16,
datatypes.bfloat16,
Expand Down

0 comments on commit bfe98e3

Please sign in to comment.