Skip to content

Commit

Permalink
[jax_triton] Add user-specified name field to serialized format.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 557415723
  • Loading branch information
chr1sj0nes authored and The jax_triton Authors committed Aug 16, 2023
1 parent a25b1ba commit ad44ac8
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def triton_kernel_call_lowering(
*array_args,
fn,
scalar_args,
name,
call_name,
out_shapes,
grid,
Expand All @@ -273,6 +274,7 @@ def triton_kernel_call_lowering(
raise NotImplementedError(
"`input_output_aliases` only supported on `jaxlib>=0.3.22")

kernel_call_name = name
args = list(ctx.avals_in)
arg_dtypes = list(map(get_triton_type, ctx.avals_in))
for idx, dtype, v in scalar_args:
Expand Down Expand Up @@ -413,12 +415,15 @@ def prune_configs(configs, named_args):
ir.RankedTensorType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
for shape in out_shapes
]

if jaxlib.version.__version_info__ >= (0, 4, 15):
call_proto = kernel_call.to_proto(kernel_call_name, serialized_metadata)
else:
call_proto = kernel_call.to_proto(serialized_metadata)
return jaxlib.hlo_helpers.custom_call(
call_target_name=call_name,
out_types=out_types,
operands=array_args,
backend_config=zlib.compress(kernel_call.to_proto(serialized_metadata)),
backend_config=zlib.compress(call_proto),
operand_layouts=avals_to_layouts(ctx.avals_in),
result_layouts=avals_to_layouts(ctx.avals_out),
operand_output_aliases=dict(input_output_aliases),
Expand All @@ -444,7 +449,8 @@ def triton_call(
kernel: triton.JITFunction,
out_shape: Union[ShapeDtype, Sequence[ShapeDtype]],
grid: GridOrLambda,
call_name: str = "triton_kernel_call",
name: str = "",
call_name: str = "triton_kernel_call", # TODO(cjfj): Remove this.
num_warps: int = 4,
num_stages: int = 2,
input_output_aliases: Optional[Dict[int, int]] = None,
Expand Down Expand Up @@ -565,6 +571,7 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
*array_args,
fn=kernel,
scalar_args=tuple(scalar_args),
name=name,
call_name=call_name,
out_shapes=tuple(flat_out_shapes),
grid=grid,
Expand Down

0 comments on commit ad44ac8

Please sign in to comment.