Skip to content

Commit

Permalink
add tanhshrink op (#1522)
Browse files Browse the repository at this point in the history
  • Loading branch information
beverlylytle authored Dec 12, 2024
1 parent bb20d73 commit 6403646
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 1 deletion.
2 changes: 2 additions & 0 deletions thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,7 @@ def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor:
# Bind torch.nn.functional activation functions as torch-executor operations.
hardswish = _register_torch_operation("hardswish", module=torch.nn.functional)
selu = _register_torch_operation("selu", module=torch.nn.functional)
silu = _register_torch_operation("silu", module=torch.nn.functional)
tanhshrink = _register_torch_operation("tanhshrink", module=torch.nn.functional)


def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = False) -> bool:
Expand All @@ -858,6 +859,7 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F
# Register the torch lowerings for these elementwise-unary ops. Ops whose
# thunder symbol accepts an ``inplace`` flag use the in-place checker; the
# others (silu, tanhshrink here) are always executable by this executor.
_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.selu, selu, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.silu, silu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.tanhshrink, tanhshrink, checker=_always_executable)

#
# Elementwise binary operations
Expand Down
26 changes: 26 additions & 0 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -1804,6 +1804,32 @@ def hardshrink_singularity_fn_producer(sample: SampleInput):
elementwise_unary_ops.append(selu_opinfo)


# Test directives for tanhshrink, kept separate for readability:
#  - torch has no CPU kernel for float16/complex32 tanhshrink, so the
#    core-vs-torch consistency test is expected to fail there;
#  - reduced-precision dtypes on nvfuser need looser tolerances.
_tanhshrink_test_directives = (
    DecorateInfo(
        pytest.mark.xfail,
        "test_core_vs_torch_consistency",
        dtypes=(datatypes.float16, datatypes.complex32),
        devicetypes=(devices.DeviceType.CPU,),
    ),
    DecorateInfo(
        custom_comparator(partial(assert_close, atol=1e-2, rtol=1e-2)),
        executors=("nvfuser",),
        dtypes=(datatypes.float16, datatypes.bfloat16),
    ),
)

tanhshrink_opinfo = OpInfo(
    ltorch.tanhshrink,
    sample_input_generator=elementwise_unary_generator,
    torch_reference=_elementwise_unary_torch(torch.nn.functional.tanhshrink),
    dtypes=(datatypes.inexact,),
    test_directives=_tanhshrink_test_directives,
)
elementwise_unary_ops.append(tanhshrink_opinfo)


round_opinfo = OpInfo(
clang.round,
dtypes=(datatypes.floating, datatypes.exact),
Expand Down
7 changes: 7 additions & 0 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1905,6 +1905,13 @@ def silu(a: TensorLike, /, inplace: bool = False) -> TensorLike:
_inplace_to_out_of_place[silu] = silu, 1


@torchsymbol(torch.nn.functional.tanhshrink)
def tanhshrink(a: TensorLike, /) -> TensorLike:
    """Elementwise tanhshrink: ``a - tanh(a)``.

    Mirrors :func:`torch.nn.functional.tanhshrink`; there is no in-place
    variant, so no ``inplace`` parameter is accepted.
    """
    shrunk = tanh(a)
    return a - shrunk


# NOTE(review): -1 appears to mark ops with no ``inplace`` parameter (compare
# silu, which maps to argument index 1) — confirm against the table's users.
_inplace_to_out_of_place[tanhshrink] = tanhshrink, -1

#
# Elementwise binary operations
#
Expand Down
1 change: 0 additions & 1 deletion thunder/torch/default_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,6 @@
torch.nn.functional.softplus,
torch.nn.functional.softshrink,
torch.nn.functional.softsign,
torch.nn.functional.tanhshrink,
torch.nn.functional.triplet_margin_loss,
torch.nn.functional.triplet_margin_with_distance_loss,
torch.nn.functional.unfold,
Expand Down

0 comments on commit 6403646

Please sign in to comment.