diff --git a/jaxlib/triton/compat.py b/jaxlib/triton/compat.py
index 3c2125042636..0e5d948263ad 100644
--- a/jaxlib/triton/compat.py
+++ b/jaxlib/triton/compat.py
@@ -672,13 +672,13 @@ def max(x: tensor, y: tensor) -> tensor:
     assert x.shape == y.shape
     if x.dtype.is_floating():
       # TODO(slebedev): Triton promotes bfloat16 to float32 and back here.
-      return tensor(arith_dialect.maxnumf(x.handle, y.handle), x.dtype)
+      return tensor(arith_dialect.maxnumf(x.handle, y.handle), x.type)
     if not x.dtype.is_int():
       raise NotImplementedError(f"unsupported dtypes: {x.dtype} and {y.dtype}")
     elif x.dtype.is_int_signed():
-      return tensor(arith_dialect.maxsi(x.handle, y.handle), x.dtype)
+      return tensor(arith_dialect.maxsi(x.handle, y.handle), x.type)
     else:
-      return tensor(arith_dialect.maxui(x.handle, y.handle), x.dtype)
+      return tensor(arith_dialect.maxui(x.handle, y.handle), x.type)

   @staticmethod
   def min(x: tensor, y: tensor) -> tensor:
@@ -686,13 +686,13 @@ def min(x: tensor, y: tensor) -> tensor:
     assert x.shape == y.shape
     if x.dtype.is_floating():
       # TODO(slebedev): Triton promotes bfloat16 to float32 and back here.
-      return tensor(arith_dialect.minnumf(x.handle, y.handle), x.dtype)
+      return tensor(arith_dialect.minnumf(x.handle, y.handle), x.type)
     if not x.dtype.is_int():
       raise NotImplementedError(f"unsupported dtypes: {x.dtype} and {y.dtype}")
     elif x.dtype.is_int_signed():
-      return tensor(arith_dialect.minsi(x.handle, y.handle), x.dtype)
+      return tensor(arith_dialect.minsi(x.handle, y.handle), x.type)
     else:
-      return tensor(arith_dialect.minui(x.handle, y.handle), x.dtype)
+      return tensor(arith_dialect.minui(x.handle, y.handle), x.type)

   sin = libdevice_extern_elementwise({
       (float32,): ("__nv_sinf", float32),
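
The change passes `x.type` (the full tensor type) instead of `x.dtype` (the element dtype) when wrapping the result of the arith op. A minimal sketch of why that distinction matters, using hypothetical simplified classes rather than the actual ones in compat.py: if the wrapper derives its shape from the type it is given, rebuilding a ranked result from the element dtype silently drops the block shape, while the full type preserves it.

```python
# Hypothetical simplification of the dtype-vs-type distinction; the real
# classes in jaxlib/triton/compat.py differ in detail.
class dtype:
  def __init__(self, name):
    self.name = name

class block_type:
  """A ranked tensor type: an element dtype plus a block shape."""
  def __init__(self, element_ty, shape):
    self.element_ty = element_ty
    self.shape = shape

class tensor:
  def __init__(self, handle, type):
    self.handle = handle
    self.type = type
    # Element dtype and shape are derived from the full type.
    self.dtype = type.element_ty if isinstance(type, block_type) else type
    self.shape = type.shape if isinstance(type, block_type) else []

f32 = dtype("fp32")
x = tensor(object(), block_type(f32, [128]))

# Rebuilding the result from x.dtype loses the block shape ...
bad = tensor(object(), x.dtype)
assert bad.shape == []
# ... while x.type keeps it, matching the shape of the underlying value.
good = tensor(object(), x.type)
assert good.shape == [128]
```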