-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Activation complex dtype relu6 #23599
Activation complex dtype relu6 #23599
Conversation
…I, what remains is the data_classes and stateful
If you are working on an open task, please edit the PR description to link to the issue you've created. For more information, please check ToDo List Issues Guide. Thank you 🤗 |
@@ -38,8 +38,8 @@ def thresholded_relu( | |||
return tf.cast(tf.where(x > threshold, x, 0), x.dtype) | |||
|
|||
|
|||
@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version) | |||
def relu6(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: | |||
# @with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line should probably be removed completely rather than just commented out
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
The failed tests seem to be from the tensorflow frontend trying to test with complex numbers, which causes errors in the ground truth. You should probably add complex as an unsupported dtype to specifically the tensorflow frontend function, and change the thing I mentioned above, and then it'll be good to merge |
Added support for complex dtype for
relu6
, as well as_relu6_jax_like
function