Skip to content

Commit

Permalink
Sort ops.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 569272845
  • Loading branch information
shaobohou authored and TF2JAXDev committed Sep 28, 2023
1 parent bf16e51 commit 53016eb
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions tf2jax/_src/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,20 @@ def wrapped(proto):


_jax_ops = {
# go/keep-sorted start
"Abs": _get_jax_op(jnp.abs, {"T"}),
"Acosh": _get_jax_op(jnp.arccosh, {"T"}),
"Add": _get_jax_op(anp.add, {"T"}),
"AddN": _get_jax_op(
lambda *args: anp.sum_(anp.stack(args, axis=0), axis=0, keepdims=False),
{"T", "N"}),
"AddV2": _get_jax_op(anp.add, {"T"}),
"Angle": _get_jax_op(jnp.angle, {"T", "Tout"}),
"ArgMax": _get_jax_op(jnp.argmax, {"T", "Tidx", "output_type"}),
"ArgMin": _get_jax_op(jnp.argmin, {"T", "Tidx", "output_type"}),
"Acosh": _get_jax_op(jnp.arccosh, {"T"}),
"Angle": _get_jax_op(jnp.angle, {"T", "Tout"}),
"Asinh": _get_jax_op(jnp.arcsinh, {"T"}),
"Atanh": _get_jax_op(jnp.arctanh, {"T"}),
"Atan2": _get_jax_op(jnp.arctan2, {"T"}),
"Atanh": _get_jax_op(jnp.arctanh, {"T"}),
"BesselI0e": _get_jax_op(jax.lax.bessel_i0e, {"T"}),
"BesselI1e": _get_jax_op(jax.lax.bessel_i1e, {"T"}),
"BitwiseAnd": _get_jax_op(jnp.bitwise_and, {"T"}),
Expand Down Expand Up @@ -109,11 +110,10 @@ def wrapped(proto):
"FFT3D": _get_jax_op(
functools.partial(jnp.fft.fftn, axes=(-3, -2, -1,)), {"Tcomplex"}),
"Floor": _get_jax_op(jnp.floor, {"T"}),
"FloorMod": _get_jax_op(anp.mod, {"T"}),
"FloorDiv": _get_jax_op(anp.floor_divide, {"T"}),
"FloorMod": _get_jax_op(anp.mod, {"T"}),
"Greater": _get_jax_op(anp.greater, {"T"}),
"GreaterEqual": _get_jax_op(anp.greater_equal, {"T"}),
"Identity": _get_jax_op(lambda x: x, {"T"}),
"IFFT": _get_jax_op(
functools.partial(jnp.fft.ifftn, axes=(-1,)), {"Tcomplex"}),
"IFFT2D": _get_jax_op(
Expand All @@ -128,12 +128,13 @@ def wrapped(proto):
"IRFFT3D": _get_jax_op(
functools.partial(
jnp.fft.irfftn, axes=(-3, -2, -1,)), {"Tcomplex", "Treal"}),
"Identity": _get_jax_op(lambda x: x, {"T"}),
"Igamma": _get_jax_op(jax.lax.igamma, {"T"}),
"Igammac": _get_jax_op(jax.lax.igammac, {"T"}),
"Imag": _get_jax_op(jax.lax.imag, {"T", "Tout"}),
"IsFinite": _get_jax_op(jnp.isfinite, {"T"}),
"Invert": _get_jax_op(jnp.bitwise_not, {"T"}),
"InvertPermutation": _get_jax_op(anp.invert_permutation, {"T"}),
"IsFinite": _get_jax_op(jnp.isfinite, {"T"}),
"L2Loss": _get_jax_op(lambda x: 0.5 * jnp.sum(jnp.square(x)), {"T"}),
"LeftShift": _get_jax_op(jnp.left_shift, {"T"}),
"Less": _get_jax_op(anp.less, {"T", "incompatible_shape_error"}),
Expand All @@ -144,23 +145,15 @@ def wrapped(proto):
"LogicalAnd": _get_jax_op(jnp.logical_and, {"T"}),
"LogicalNot": _get_jax_op(jnp.logical_not, {"T"}),
"LogicalOr": _get_jax_op(jnp.logical_or, {"T"}),
"Minimum": _get_jax_op(anp.minimum, {"T"}),
"Maximum": _get_jax_op(anp.maximum, {"T"}),
"Minimum": _get_jax_op(anp.minimum, {"T"}),
"Mul": _get_jax_op(anp.multiply, {"T"}),
"Neg": _get_jax_op(anp.negative, {"T"}),
"NoOp": _get_jax_op(lambda: _EMPTY_RETURN_VALUE, set({})),
"NotEqual": _get_jax_op(anp.not_equal, {"T", "incompatible_shape_error"}),
"OnesLike": _get_jax_op(jnp.ones_like, {"T"}),
"PopulationCount": _get_jax_op(jax.lax.population_count, {"T"}),
"Pow": _get_jax_op(anp.power, {"T"}),
"Rank": _get_jax_op(lambda x: np.array(jnp.ndim(x)), {"T"}),
"Real": _get_jax_op(jax.lax.real, {"T", "Tout"}),
"ReadVariableOp": _get_jax_op(lambda x: x, {"dtype"}),
"RealDiv": _get_jax_op(anp.true_divide, {"T"}),
"Reciprocal": _get_jax_op(anp.reciprocal, {"T"}),
"Relu": _get_jax_op(jax.nn.relu, {"T"}),
"Relu6": _get_jax_op(jax.nn.relu6, {"T"}),
"ReverseV2": _get_jax_op(anp.flip, {"T", "Tidx"}),
"RFFT": _get_jax_op(
functools.partial(jnp.fft.rfftn, axes=(-1,)), {"Tcomplex", "Treal"}),
"RFFT2D": _get_jax_op(
Expand All @@ -169,6 +162,14 @@ def wrapped(proto):
"RFFT3D": _get_jax_op(
functools.partial(
jnp.fft.rfftn, axes=(-3, -2, -1,)), {"Tcomplex", "Treal"}),
"Rank": _get_jax_op(lambda x: np.array(jnp.ndim(x)), {"T"}),
"ReadVariableOp": _get_jax_op(lambda x: x, {"dtype"}),
"Real": _get_jax_op(jax.lax.real, {"T", "Tout"}),
"RealDiv": _get_jax_op(anp.true_divide, {"T"}),
"Reciprocal": _get_jax_op(anp.reciprocal, {"T"}),
"Relu": _get_jax_op(jax.nn.relu, {"T"}),
"Relu6": _get_jax_op(jax.nn.relu6, {"T"}),
"ReverseV2": _get_jax_op(anp.flip, {"T", "Tidx"}),
"RightShift": _get_jax_op(jnp.right_shift, {"T"}),
"Round": _get_jax_op(jnp.round, {"T"}),
"Rsqrt": _get_jax_op(jax.lax.rsqrt, {"T"}),
Expand Down Expand Up @@ -203,6 +204,7 @@ def wrapped(proto):
{"T", "Tindices", "Tnumsegments"}),
"Where": _get_jax_op(jnp.argwhere, {"T"}),
"ZerosLike": _get_jax_op(jnp.zeros_like, {"T"}),
# go/keep-sorted end
# The assignment logic is handled in _OpNode and convert().
"AssignAddVariableOp": _get_jax_op(jnp.add, {"dtype"}),
"AssignSubVariableOp": _get_jax_op(jnp.subtract, {"dtype"}),
Expand Down

0 comments on commit 53016eb

Please sign in to comment.