From b768624116e4afcefc2551e30ece37e27431598d Mon Sep 17 00:00:00 2001 From: wenxindongwork Date: Sat, 3 May 2025 02:48:18 +0000 Subject: [PATCH 01/16] intercept torch.ops.xla.ragged_paged_attention and dispatch to ragged attention v2 implementation --- torchax/torchax/ops/jtorch.py | 45 +- torchax/torchax/ops/ragged_attention.py | 1257 +++++++++++++++++++++++ 2 files changed, 1298 insertions(+), 4 deletions(-) create mode 100644 torchax/torchax/ops/ragged_attention.py diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py index f03e5cbf7a00..6249176436d5 100644 --- a/torchax/torchax/ops/jtorch.py +++ b/torchax/torchax/ops/jtorch.py @@ -17,6 +17,7 @@ import torchax.tensor from torchax.view import View, NarrowInfo import torch.utils._pytree as pytree +from torchax.ops.ragged_attention import ragged_paged_attention as ragged_paged_attention_kernel def register_function(torch_func, **kwargs): @@ -508,7 +509,43 @@ def linalg_tensorsolve(A, b, dims=None): @register_function(torch.nn.functional.linear) def functional_linear(self, weights, bias=None): - res = jnp.einsum("...a,ba->...b", self, weights) - if bias is not None: - res += bias - return res + res = jnp.einsum("...a,ba->...b", self, weights) + if bias is not None: + res += bias + return res + +@register_function(torch.ops.xla.ragged_paged_attention) +def _ragged_paged_attention( + q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] + kv_lens: jax.Array, # i32[max_num_seqs] + page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[max_num_seqs + 1] + num_seqs: jax.Array, # i32[1] + use_kernel: bool = True, + sm_scale: float = 1.0, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = None, + num_kv_pages_per_block: int | None = None, + num_queries_per_block: int | None = None, + vmem_limit_bytes: int | None = None, +): 
+ + return ragged_paged_attention_kernel( + q = q, + kv_pages = kv_pages, + kv_lens = kv_lens, + page_indices = page_indices, + cu_q_lens = cu_q_lens, + num_seqs = num_seqs, + sm_scale = sm_scale, + sliding_window = sliding_window, + soft_cap = soft_cap, + mask_value = mask_value, + num_kv_pages_per_block = num_kv_pages_per_block, + num_queries_per_block = num_queries_per_block, + vmem_limit_bytes = vmem_limit_bytes, +) + + diff --git a/torchax/torchax/ops/ragged_attention.py b/torchax/torchax/ops/ragged_attention.py new file mode 100644 index 000000000000..fdc5a63adc8d --- /dev/null +++ b/torchax/torchax/ops/ragged_attention.py @@ -0,0 +1,1257 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Copied from https://github.com/pytorch/xla/blob/master/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py +since the pypi torchxla package does not have the updated function. +""" + +"""TPU-Friendly Ragged Paged Attention kernel. + +This kernel offers a highly optimized implementation of ragged paged attention, +specifically designed for TPU and compatible with a wide range of model +specifications. It supports mixed prefill and decoding, enhancing throughput +during inference. 
+""" +import functools +import jax +from jax import lax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp + +DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) +# The page size is too small. We only have 32 SREGs in TC. If the pages +# per seq is too large, SREGs will spill. +MAX_PAGES_PER_SEQ = 16 + +# key: +# - q_dtype_name +# - kv_dtype_name +# - num_q_heads_per_blk +# - num_kv_heads_per_blk +# - head_dim +# - page_size +# - max_num_batched_tokens +# - max_model_len = page_size * pages_per_seq +# value: +# - num_kv_pages_per_block +# - num_queries_per_block +TUNED_BLOCK_SIZES = { + 'TPU v6': { + # go/keep-sorted start + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 4096): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 2048): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 4096): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 2048): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 4096): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 512): (4, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 4096): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 1024): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 2048): (128, 32), + ('bfloat16', 'bfloat16', 
32, 8, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 512): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 1024): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 2048): (128, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 512): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 1024): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 2048): (128, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 512): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 1024): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 2048): (128, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 512): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 2048): (8, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 4096): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 2048): (8, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1024): (4, 32), + 
('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 1024): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 2048): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 4096): (128, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 1024): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 2048): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 4096): (128, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 1024): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 2048): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 4096): (128, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 512): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 1024): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 2048): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 4096): (128, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 2048): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 4096): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 
1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 2048): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 4096): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 2048): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 4096): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 512): (8, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 2048): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 4096): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 4096): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 2048): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 4096): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 2048): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 4096): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 4096): (32, 32), + ('bfloat16', 'bfloat16', 
8, 1, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 1024): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 2048): (128, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 512): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 1024): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 2048): (128, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 512): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 1024): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 2048): (128, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 512): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 1024): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 2048): (128, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 512): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1024): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 
256, 4096, 1024): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 1024): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 2048): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 4096): (128, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 1024): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 2048): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 4096): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 1024): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 2048): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 4096): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 1024): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 2048): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 4096): (128, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 
2048): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 4096): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 1024): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 2048): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 4096): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 1024): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 2048): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 4096): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 2048): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 4096): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 512): (8, 32), + # go/keep-sorted end + }, + 'TPU v5': { + # go/keep-sorted start + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 
128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 
8, 128, 32, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 
128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 4096): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 4096): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 4096): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 4096): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 512): (16, 
32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 512): (8, 32), + # go/keep-sorted end + }, +} + + +def next_power_of_2(x: int): + """Finds the smallest power of 2 >= x using bit manipulation. + + Args: + x: The input number (should be an integer). + + Returns: + The smallest integer power of 2 that is >= x. 
+ """ + assert x > 0 + if x == 1: + return 1 + return 1 << (x - 1).bit_length() + + +def simplify_key(key): + """Simplify the key to reduce the number of combinations.""" + ( + q_dtype, + kv_dtype, + num_q_heads_per_blk, + num_kv_heads_per_blk, + head_dim, + page_size, + max_num_batched_tokens, + pages_per_seq, + ) = key + return ( + jnp.dtype(q_dtype).name, + jnp.dtype(kv_dtype).name, + next_power_of_2(num_q_heads_per_blk), + next_power_of_2(num_kv_heads_per_blk), + (head_dim + 127) // 128 * 128, + next_power_of_2(page_size), + next_power_of_2(max_num_batched_tokens), + next_power_of_2(page_size * pages_per_seq), + ) + + +def get_tpu_version() -> int: + """Returns the numeric version of the TPU, or -1 if not on TPU.""" + kind = jax.devices()[0].device_kind + if 'TPU' not in kind: + return -1 + if kind.endswith(' lite'): + kind = kind[:-len(' lite')] + assert kind[:-1] == 'TPU v', kind + return int(kind[-1]) + + +def get_device_name(num_devices: int | None = None): + name = ' '.join(jax.devices()[0].device_kind.split()[:2]) + if num_devices is not None: + name += f'-{num_devices}' + return name + + +def get_tuned_block_sizes( + q_dtype, + kv_dtype, + num_q_heads_per_blk, + num_kv_heads_per_blk, + head_dim, + page_size, + max_num_batched_tokens, + pages_per_seq, +) -> tuple[int, int]: + """Look up for the best (num_kv_pages_per_blk, num_queries_per_blk) from auto-tuned table.""" + tpu_version = get_tpu_version() + if tpu_version < 4: + raise NotImplementedError('TPU version must be 4 or higher.') + key = ( + q_dtype, + kv_dtype, + num_q_heads_per_blk, + num_kv_heads_per_blk, + head_dim, + page_size, + max_num_batched_tokens, + pages_per_seq, + ) + key = simplify_key(key) + device_name = get_device_name() + + # Default block sizes. 
+ bkv, bq = (128, 32) + if tpu_version == 4: + # This default block size is not tuned, only make sure there's no + # OOM in vmem + bkv, bq = (32, 32) + elif device_name in TUNED_BLOCK_SIZES: + if key in TUNED_BLOCK_SIZES[device_name]: + bkv, bq = TUNED_BLOCK_SIZES[device_name][key] + return (min(pages_per_seq, bkv), min(max_num_batched_tokens, bq)) + + +def get_min_page_size(max_model_len, min_page_size=16): + """Recommended min page size for high-performance kernel.""" + return max(next_power_of_2(max_model_len) // MAX_PAGES_PER_SEQ, min_page_size) + + +class MultiPageAsyncCopyDescriptor: + """Descriptor for async copy of multiple K/V pages from HBM.""" + + def __init__( + self, + pages_hbm_ref, # [total_num_pages, page_size, num_combined_kv_heads_per_blk, head_dim] + vmem_buf, # [num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim] + sem, + page_indices_ref, # i32[max_num_seqs, pages_per_seq] + offset, # [seq_idx, kv_pages_start] + ): + self._vmem_buf = vmem_buf + seq_id, kv_pages_start = offset + pages_per_seq = page_indices_ref.shape[1] + self._async_copies = [] + # TODO(jevinjiang): Only fetch dynamic shape in need! This will insert + # a bunch of if-ops. Check the performance when we have benchmarking setup. + for i in range(vmem_buf.shape[0]): + page_idx = kv_pages_start + i + page_idx = jax.lax.select(page_idx < pages_per_seq, page_idx, + pages_per_seq - 1) + self._async_copies.append( + pltpu.make_async_copy( + pages_hbm_ref.at[page_indices_ref[seq_id, page_idx]], + vmem_buf.at[i], + sem, + )) + + def start(self): + """Starts the async copies.""" + for async_copy in self._async_copies: + async_copy.start() + + def wait(self): + for async_copy in self._async_copies: + async_copy.wait() + return self._vmem_buf + + +def ref_ragged_paged_attention( + queries: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] + kv_pages: jax. 
+ Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] + kv_lens: jax.Array, # i32[max_num_seqs] + page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[max_num_seqs + 1] + num_seqs: jax.Array, # i32[1], + *, + sm_scale: float = 1.0, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = DEFAULT_MASK_VALUE, +): + static_validate_inputs( + queries, + kv_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, + ) + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE + _, _, num_combined_kv_heads, head_dim = kv_pages.shape + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 + num_q_heads = queries.shape[1] + assert num_q_heads % num_kv_heads == 0 + num_query_per_kv = num_q_heads // num_kv_heads + outputs = [] + for i in range(num_seqs[0]): + q_start = cu_q_lens[i] + q_end = cu_q_lens[i + 1] + q_len = q_end - q_start + kv_len = kv_lens[i] + indices = page_indices[i] + q = queries[q_start:q_end] + k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads, + head_dim)[:kv_len] + v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, + head_dim)[:kv_len] + k = jnp.repeat(k, num_query_per_kv, axis=1) + v = jnp.repeat(v, num_query_per_kv, axis=1) + attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32) + attn *= sm_scale + q_span = (kv_len - q_len) + jax.lax.broadcasted_iota( + jnp.int32, attn.shape, 1) + kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2) + mask = q_span < kv_span + if sliding_window is not None: + mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span) + if soft_cap is not None: + attn = soft_cap * jnp.tanh(attn / soft_cap) + attn += jnp.where(mask, mask_value, 0.0) + attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) + out = jnp.einsum("hqk,khd->qhd", attn, 
v).astype(queries.dtype) + outputs.append(out) + + return jnp.concatenate(outputs, axis=0) + + +# Expect to run these checks during runtime. +def dynamic_validate_inputs( + q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] + kv_pages: jax. + Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] + kv_lens: jax.Array, # i32[max_num_seqs] + page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[max_num_seqs + 1] + num_seqs: jax.Array, # i32[1] + *, + # These inputs are optional. If not specified, we will not validate them. + sm_scale: float | None = None, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = None, + # Kernel specific params. + num_kv_pages_per_block: int | None = None, + num_queries_per_block: int | None = None, + vmem_limit_bytes: int | None = None, +): + static_validate_inputs( + q, + kv_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + vmem_limit_bytes=vmem_limit_bytes, + ) + max_num_batched_tokens = q.shape[0] + page_size = kv_pages.shape[1] + max_num_seqs, pages_per_seq = page_indices.shape + if num_seqs[0] > max_num_seqs: + raise ValueError(f"{num_seqs[0]=} must be less or equal to {max_num_seqs=}") + max_kv_len = jnp.max(kv_lens) + min_pages_per_seq = cdiv(max_kv_len, page_size) + if pages_per_seq < min_pages_per_seq: + raise ValueError( + f"{pages_per_seq=} must be greater or equal to" + f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}.") + if cu_q_lens[num_seqs[0]] > max_num_batched_tokens: + raise ValueError( + f"Total q tokens {cu_q_lens[num_seqs[0]]} must be less or equal to" + f" {max_num_batched_tokens=}.") + for i in range(num_seqs[0]): + q_len = cu_q_lens[i + 1] - cu_q_lens[i] + kv_len = kv_lens[i] + if q_len > 
kv_len: + raise ValueError( + f"{q_len=} must be less or equal to {kv_len=} at sequence {i}.") + + +# Expect to run these checks during compile time. +def static_validate_inputs( + q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] + kv_pages: jax. + Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] + kv_lens: jax.Array, # i32[max_num_seqs] + page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[max_num_seqs + 1] + num_seqs: jax.Array, # i32[1] + *, + # These inputs are optional. If not specified, we will not validate them. + sm_scale: float | None = None, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = None, + # Kernel specific params. + num_kv_pages_per_block: int | None = None, + num_queries_per_block: int | None = None, + vmem_limit_bytes: int | None = None, +): + _, num_q_heads, head_dim = q.shape + _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 + max_num_seqs, pages_per_seq = page_indices.shape + if num_seqs.shape != (1,): + raise ValueError(f"{num_seqs.shape=} must be (1,)") + if head_dim_k != head_dim: + raise ValueError( + f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}.") + if kv_lens.shape != (max_num_seqs,): + raise ValueError(f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where" + " `max_num_seqs` is `page_indices.shape[0]`.") + if cu_q_lens.shape != (max_num_seqs + 1,): + raise ValueError( + f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where" + " `max_num_seqs` is `page_indices.shape[0]`.") + if (kv_lens.dtype != jnp.int32 or page_indices.dtype != jnp.int32 or + cu_q_lens.dtype != jnp.int32): + raise ValueError( + "The dtype of `kv_lens`, `page_indices`, and `cu_q_lens` must be" + f" int32. 
Got {kv_lens.dtype=}, {page_indices.dtype=}," + f" {cu_q_lens.dtype=}.") + if num_q_heads % num_kv_heads != 0: + raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}") + if sliding_window is not None and sliding_window <= 0: + raise ValueError(f"{sliding_window=} must be positive.") + if soft_cap is not None and soft_cap == 0.0: + raise ValueError(f"{soft_cap=} must not be 0.0.") + if (num_kv_pages_per_block is not None and + not 0 < num_kv_pages_per_block <= pages_per_seq): + raise ValueError( + f"{num_kv_pages_per_block=} must be in range (0, {pages_per_seq}].") + if num_queries_per_block is not None and num_queries_per_block <= 0: + raise ValueError(f"{num_queries_per_block=} must be positive.") + if vmem_limit_bytes is not None and vmem_limit_bytes <= 0: + raise ValueError(f"{vmem_limit_bytes=} must be positive.") + del sm_scale # No constraints on sm_scale. + del mask_value # No constraints on mask_value. + + +def ragged_paged_attention_kernel( + # Prefetch + kv_lens_ref, # [max_num_seqs] + page_indices_ref, # [max_num_seqs, pages_per_seq] + cu_q_lens_ref, # [max_num_seqs + 1] + seq_buf_idx_ref, + # TODO(jevinjiang): if OOM in SMEM, consider pack to other scalar refs. 
+ num_seqs_ref, + # Input + q_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] + kv_pages_hbm_ref, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] + # Output + o_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] + # Scratch + kv_bufs, # [2, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim] + sems, # [2, 2] + l_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] + m_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] + acc_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] + *, + sm_scale: float, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = DEFAULT_MASK_VALUE, +): + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE + num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape + num_seqs = num_seqs_ref[0] + _, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, _ = ( + kv_bufs.shape) + num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2 + num_kv_per_blk = num_kv_pages_per_blk * page_size + num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk + heads_blk_idx, q_blk_idx = ( + pl.program_id(0), + pl.program_id(1), + ) + num_heads_blks = pl.num_programs(0) + init_seq_idx = seq_buf_idx_ref[0] + init_buf_idx = seq_buf_idx_ref[1] + q_len_start = q_blk_idx * num_q_per_blk + q_len_end = q_len_start + num_q_per_blk + + def create_kv_async_copy_descriptors(heads_blk_idx, seq_idx, kv_blk_idx, + buf_idx): + offset = (seq_idx, kv_blk_idx * num_kv_pages_per_blk) + heads_start = heads_blk_idx * num_combined_kv_heads_per_blk + async_copy_kv = MultiPageAsyncCopyDescriptor( + kv_pages_hbm_ref.at[:, :, + pl.ds(heads_start, num_combined_kv_heads_per_blk + ), :], + kv_bufs.at[buf_idx], + sems.at[buf_idx], + page_indices_ref, + offset, + ) + return async_copy_kv + + # TODO(jevinjiang): Add these to Mosaic: + # 1. Support arbitrary strided load/store for any dtype. + # 2. 
Support arbitrary strided load/store for any last dimension. + def strided_load_kv(ref, start, step): + if ref.dtype == jnp.float32: + return ref[start::step, :], ref[start + 1::step, :] + packing = get_dtype_packing(ref.dtype) + assert ref.dtype == jnp.bfloat16 + assert step % packing == 0 + b_start = start // packing + b_step = step // packing + b_ref = ref.bitcast(jnp.uint32) + b = b_ref[b_start::b_step, :] + bk = b << 16 + bv = b & jnp.uint32(0xffff0000) + k = pltpu.bitcast(bk, jnp.float32).astype(jnp.bfloat16) + v = pltpu.bitcast(bv, jnp.float32).astype(jnp.bfloat16) + return k, v + + def fold_on_2nd_minor(vec): + assert vec.dtype == jnp.bfloat16 or vec.dtype == jnp.float32 + assert len(vec.shape) >= 2 + last_dim = vec.shape[-1] + packing = get_dtype_packing(vec.dtype) + if vec.shape[-2] % packing != 0: + vec = vec.astype(jnp.float32) + return vec.reshape(-1, last_dim) + + @pl.when(heads_blk_idx + q_blk_idx == 0) + def prefetch_first_kv_blk(): + async_copy_kv = create_kv_async_copy_descriptors(heads_blk_idx, + init_seq_idx, 0, + init_buf_idx) + async_copy_kv.start() + + def is_cur_q_blk_needed(q_states): + done, cur_seq_idx, _ = q_states + should_run = jnp.logical_and(q_len_start < cu_q_lens_ref[num_seqs], + cur_seq_idx < num_seqs) + return jnp.logical_and(done == 0, should_run) + + def compute_with_cur_q_blk(q_states): + done, cur_seq_idx, cur_buf_idx = q_states + q_start = cu_q_lens_ref[cur_seq_idx] + q_end = cu_q_lens_ref[cur_seq_idx + 1] + q_len = q_end - q_start + kv_len = kv_lens_ref[cur_seq_idx] + + def get_next_prefetch_ids(heads_blk_idx, cur_seq_idx, kv_blk_idx, + cur_buf_idx): + next_kv_blk_idx = kv_blk_idx + 1 + is_last_kv_blk = next_kv_blk_idx * num_kv_per_blk >= kv_len + next_kv_blk_idx = lax.select( + is_last_kv_blk, + 0, + next_kv_blk_idx, + ) + is_cur_seq_end_in_cur_q_blk = q_end <= q_len_end + next_seq_idx = lax.select( + is_last_kv_blk, + lax.select(is_cur_seq_end_in_cur_q_blk, cur_seq_idx + 1, cur_seq_idx), + cur_seq_idx, + ) + is_last_seq = 
next_seq_idx == num_seqs + next_seq_idx = lax.select( + is_last_seq, + 0, + next_seq_idx, + ) + next_heads_blk_idx = lax.select( + is_last_seq, + heads_blk_idx + 1, + heads_blk_idx, + ) + next_buf_idx = lax.select(cur_buf_idx == 0, 1, 0) + return next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx + + def flash_attention( + q, # [num_q_per_blk * num_q_heads_per_kv_head, head_dim] + k, # [num_kv_per_blk, head_dim] + v, # [num_kv_per_blk, head_dim] + head_l_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128] + head_m_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128] + head_acc_ref, # [num_q_per_blk, num_q_heads_per_kv_head, head_dim] + *, + kv_blk_idx, + ): + assert q.shape == ( + num_q_per_blk * num_q_heads_per_kv_head, + head_dim, + ) + assert k.shape == ( + num_kv_per_blk, + head_dim, + ), f"{k.shape=}, {(num_kv_per_blk, head_dim)=} {k.dtype=}" + assert v.shape == (num_kv_per_blk, head_dim) + assert head_m_ref.shape == ( + num_q_per_blk * num_q_heads_per_kv_head, + 128, + ) + assert head_l_ref.shape == ( + num_q_per_blk * num_q_heads_per_kv_head, + 128, + ) + assert head_acc_ref.shape == ( + num_q_per_blk, + num_q_heads_per_kv_head, + head_dim, + ) + kv_len_start = kv_blk_idx * num_kv_per_blk + + def masked_store(ref, val, start, end, group=1): + iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group + mask = jnp.logical_and(iota >= start, iota < end) + pl.store( + ref, idx=tuple(slice(None) for _ in ref.shape), val=val, mask=mask) + + qk = ( + jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) * + sm_scale) + store_start = jnp.maximum(q_start - q_len_start, 0) + store_end = jnp.minimum(q_end - q_len_start, num_q_per_blk) + + @pl.when(kv_blk_idx == 0) + def init_scratch_ref(): + masked_store( + head_m_ref, + jnp.full_like(head_m_ref, -jnp.inf), + store_start, + store_end, + num_q_heads_per_kv_head, + ) + masked_store( + head_l_ref, + jnp.zeros_like(head_l_ref), + store_start, + store_end, + num_q_heads_per_kv_head, + ) + 
masked_store( + head_acc_ref, + jnp.zeros_like(head_acc_ref), + store_start, + store_end, + ) + + row_ids = ((kv_len - q_len) + q_len_start - q_start + + jax.lax.broadcasted_iota( + jnp.int32, + (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk), + 0, + ) // num_q_heads_per_kv_head) + col_ids = kv_len_start + jax.lax.broadcasted_iota( + jnp.int32, + (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk), + 1, + ) + causal_mask = row_ids < col_ids + if sliding_window is not None: + causal_mask = jnp.logical_or(causal_mask, row_ids - sliding_window + >= col_ids) + if soft_cap is not None: + qk = soft_cap * jnp.tanh(qk / soft_cap) + qk += jnp.where(causal_mask, mask_value, 0.0) + m_curr = jnp.max(qk, axis=1, keepdims=True) + s_curr = jnp.exp(qk - m_curr) + qkv = jnp.dot(s_curr, v, preferred_element_type=jnp.float32) + lm_store_shape = head_m_ref.shape + m_curr = jnp.broadcast_to(m_curr, lm_store_shape) + l_curr = jnp.broadcast_to( + s_curr.sum(axis=1, keepdims=True), lm_store_shape) + m_prev = head_m_ref[...] + l_prev = head_l_ref[...] + m_next = jnp.maximum(m_prev, m_curr) + masked_store(head_m_ref, m_next, store_start, store_end, + num_q_heads_per_kv_head) + alpha = jnp.exp(m_prev - m_next) + beta = jnp.exp(m_curr - m_next) + l_alpha = alpha * l_prev + l_next = l_alpha + beta * l_curr + l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + masked_store( + head_l_ref, + l_next_safe, + store_start, + store_end, + num_q_heads_per_kv_head, + ) + + def broadcast_to_shape(arr, shape): + if arr.shape == shape: + return arr + assert len(arr.shape) == len(shape) + assert arr.shape[0] == shape[0] + assert shape[1] % arr.shape[1] == 0 + # no-op concatenation. 
+ return jnp.concatenate([arr for _ in range(shape[1] // arr.shape[1])], + axis=1) + + o_curr = head_acc_ref[...].reshape(-1, head_dim) + l_alpha = broadcast_to_shape(l_alpha, qkv.shape) + beta = broadcast_to_shape(beta, qkv.shape) + l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape) + out = lax.div( + l_alpha * o_curr + beta * qkv, + l_next_safe, + ) + masked_store( + head_acc_ref, + out.reshape(head_acc_ref.shape), + store_start, + store_end, + ) + + def is_valid_kv_blk_in_cur_seq(kv_states): + kv_blk_idx, _ = kv_states + return kv_blk_idx * num_kv_per_blk < kv_len + + def compute_with_kv_blk_in_cur_seq(kv_states): + kv_blk_idx, cur_buf_idx = kv_states + next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx = ( + get_next_prefetch_ids(heads_blk_idx, cur_seq_idx, kv_blk_idx, + cur_buf_idx)) + + @pl.when(next_heads_blk_idx < num_heads_blks) + def prefetch_next_kv_blk(): + # TODO(jevinjiang): reuse the same buffer if it is already prefetched! + # TODO(jevinjiang): only fetch effective dynamic size to hold kv_len and + # DMA to fixed size buffer! + next_async_copy_kv = create_kv_async_copy_descriptors( + next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx) + next_async_copy_kv.start() + + cur_async_copy_kv = create_kv_async_copy_descriptors( + heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx) + kv_ref = cur_async_copy_kv.wait().reshape( + num_kv_pages_per_blk * page_size * num_combined_kv_heads_per_blk, + head_dim, + ) + for kv_head_idx in range(num_kv_heads_per_blk): + q_head_idx = kv_head_idx * num_q_heads_per_kv_head + # TODO(jevinjiang): extra handling for packed type that can start at + # unaligned position! 
+ q = fold_on_2nd_minor(q_ref[:, q_head_idx:q_head_idx + + num_q_heads_per_kv_head, :]) + k, v = strided_load_kv(kv_ref, kv_head_idx * 2, + num_combined_kv_heads_per_blk) + flash_attention( + q, + k, + v, + l_ref.at[kv_head_idx], + m_ref.at[kv_head_idx], + acc_ref.at[:, q_head_idx:q_head_idx + num_q_heads_per_kv_head, :], + kv_blk_idx=kv_blk_idx, + ) + return kv_blk_idx + 1, next_buf_idx + + _, next_buf_idx = lax.while_loop( + is_valid_kv_blk_in_cur_seq, + compute_with_kv_blk_in_cur_seq, + (0, cur_buf_idx), # (kv_blk_idx, buf_idx) + ) + next_seq_idx = lax.select(q_end <= q_len_end, cur_seq_idx + 1, cur_seq_idx) + done = lax.select(q_end < q_len_end, done, 1) + return done, next_seq_idx, next_buf_idx + + _, seq_idx, buf_idx = lax.while_loop( + is_cur_q_blk_needed, + compute_with_cur_q_blk, + (0, init_seq_idx, init_buf_idx), # (done, seq_idx, buf_idx) + ) + # Reset seq_idx for next kv_heads_blk if run out of seqs! + seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0) + seq_buf_idx_ref[1] = buf_idx + o_ref[...] = acc_ref[...].astype(q_ref.dtype) + + +def cdiv(a, b): + assert b != 0 + return (a + b - 1) // b + + +def get_dtype_packing(dtype): + if dtype == jnp.float32: + return 1 + if dtype == jnp.bfloat16: + return 2 + if dtype == jnp.int8: + return 4 + if dtype == jnp.int4: + return 8 + raise ValueError(f"Not implemented: unsupported {dtype=}") + + +def get_min_heads_per_blk(num_q_heads, num_combined_kv_heads, q_dtype, + kv_dtype): + q_packing = get_dtype_packing(q_dtype) + kv_packing = get_dtype_packing(kv_dtype) + + def can_be_xla_fully_tiled(x, packing): + if x % packing != 0: + return False + x //= packing + return x in (1, 2, 4, 8) or x % 8 == 0 + + # TODO(jevinjiang): support unaligned number of heads! + if not can_be_xla_fully_tiled(num_combined_kv_heads, kv_packing): + raise ValueError( + f"Not implemented: {num_combined_kv_heads=} can not be XLA fully tiled." 
+ ) + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 + assert num_q_heads % num_kv_heads == 0 + ratio = num_q_heads // num_kv_heads + # TODO(jevinjiang): we can choose smaller tiling for packed type if large + # second minor tiling is not on. + max_combined_kv_tiling = 8 * kv_packing + min_combined_kv_heads = ( + max_combined_kv_tiling if num_combined_kv_heads % + max_combined_kv_tiling == 0 else num_combined_kv_heads) + min_q_heads = min_combined_kv_heads // 2 * ratio + if can_be_xla_fully_tiled(min_q_heads, q_packing): + return min_q_heads, min_combined_kv_heads + return num_q_heads, num_combined_kv_heads + + +@functools.partial( + jax.jit, + static_argnames=[ + "sm_scale", + "mask_value", + "num_kv_pages_per_block", + "num_queries_per_block", + "vmem_limit_bytes", + "sliding_window", + "soft_cap", + ], +) +def ragged_paged_attention( + q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] + # TODO(jevinjiang): create a write_to_kv_cache kernel! + kv_pages: jax. + Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] + kv_lens: jax.Array, # i32[max_num_seqs] + page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[max_num_seqs + 1] + num_seqs: jax.Array, # i32[1] + *, + sm_scale: float = 1.0, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = DEFAULT_MASK_VALUE, + num_kv_pages_per_block: int | None = None, + num_queries_per_block: int | None = None, + vmem_limit_bytes: int | None = None, +): + """Ragged paged attention that supports mixed prefill and decode. + + Args: + q: concatenated all sequences' queries. + kv_pages: paged KV cache. Normally in HBM. + kv_lens: padded kv lengths. Only the first num_seqs values are valid. + page_indices: the first index indicates which page to use in the kv cache + for each sequence. Only the first num_seqs values are valid. 
+ cu_q_lens: the cumulative sum of the effective query lengths. Similar to + kv_lens, only the first num_seqs+1 values are valid. + num_seqs: the dynamic number of sequences. + sm_scale: the softmax scale which will be applied to the Q@K^T. + sliding_window: the sliding window size for the attention. + soft_cap: the logit soft cap for the attention. + mask_value: mask value for causal mask. + num_kv_pages_per_block: number of kv pages to be processed in one flash + attention block in the pallas kernel. + num_queries_per_block: number of queries to be processed in one flash + attention block in the pallas kernel. + vmem_limit_bytes: the vmem limit for the pallas kernel. + + Returns: + The output of the attention. + """ + static_validate_inputs( + q, + kv_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + vmem_limit_bytes=vmem_limit_bytes, + ) + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE + num_q_tokens, num_q_heads, head_dim = q.shape + _, page_size, num_combined_kv_heads, _ = kv_pages.shape + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 + _, pages_per_seq = page_indices.shape + num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk( + num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype) + num_q_per_blk = num_queries_per_block + num_kv_pages_per_blk = num_kv_pages_per_block + if num_q_per_blk is None or num_kv_pages_per_blk is None: + num_kv_pages_per_blk, num_q_per_blk = get_tuned_block_sizes( + q.dtype, + kv_pages.dtype, + num_q_heads_per_blk, + num_combined_kv_heads_per_blk // 2, + head_dim, + page_size, + num_q_tokens, + pages_per_seq, + ) + num_q_heads_per_kv_head = num_q_heads // num_kv_heads + num_q_blks = cdiv(num_q_tokens, num_q_per_blk) + assert num_combined_kv_heads_per_blk % 
2 == 0 + num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2 + assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0 + num_heads_blks = num_q_heads // num_q_heads_per_blk + grid = (num_heads_blks, num_q_blks) + + def q_index_map(heads_blk_idx, q_blk_idx, *_): + return (q_blk_idx, heads_blk_idx, 0) + + q_block_spec = pl.BlockSpec( + (num_q_per_blk, num_q_heads_per_blk, head_dim), + q_index_map, + ) + in_specs = [ + q_block_spec, + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ] + out_specs = q_block_spec + lm_scratch = pltpu.VMEM( + # TODO(jevinjiang): use 128 instead of 1 is due to Mosaic does not support + # unaligned slicing! + (num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128), + jnp.float32, + ) + acc_scratch = pltpu.VMEM( + (num_q_per_blk, num_q_heads_per_blk, head_dim), + jnp.float32, + ) + double_buf_scratch = pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + num_kv_pages_per_blk, + page_size, + num_combined_kv_heads_per_blk, + head_dim, + ), + kv_pages.dtype, + ) + scratch_shapes = [ + double_buf_scratch, # kv_bufs + pltpu.SemaphoreType.DMA((2,)), # Semaphores for double buffers. 
+ lm_scratch, # l_ref + lm_scratch, # m_ref + acc_scratch, + ] + scalar_prefetches = ( + kv_lens, + page_indices, + cu_q_lens, + jnp.array((0, 0), jnp.int32), # seq_idx, buf_idx + num_seqs, + ) + kernel = pl.pallas_call( + functools.partial( + ragged_paged_attention_kernel, + sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=len(scalar_prefetches), + in_specs=in_specs, + out_specs=out_specs, + grid=grid, + scratch_shapes=scratch_shapes, + ), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=( + "arbitrary", + "arbitrary", + ), + vmem_limit_bytes=vmem_limit_bytes, + ), + out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), + name="ragged_paged_attention_kernel", + ) + + return kernel(*scalar_prefetches, q, kv_pages) \ No newline at end of file From 9a5f8a2f0c7cc56091d6a66c051ea44a6452fd30 Mon Sep 17 00:00:00 2001 From: wenxindongwork Date: Sat, 3 May 2025 03:33:07 +0000 Subject: [PATCH 02/16] add data @property to View class --- torchax/torchax/view.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/torchax/torchax/view.py b/torchax/torchax/view.py index e9272871a4e8..fe235f19af76 100644 --- a/torchax/torchax/view.py +++ b/torchax/torchax/view.py @@ -377,8 +377,12 @@ def device(self): def jax_device(self): return self.jax().device - @property - def ndim(self): - return len(self.shape) - - __repr__ = __str__ + @property + def ndim(self): + return len(self.shape) + + @property + def data(self): + return self + + __repr__ = __str__ From 5df829900baf9a2018c81bdc8c3170a32c8c4365 Mon Sep 17 00:00:00 2001 From: wenxindongwork Date: Sun, 4 May 2025 22:19:25 +0000 Subject: [PATCH 03/16] narrow() returns a View --- torchax/test/test_view.py | 23 ++++++++++++++++------- torchax/torchax/ops/jaten.py | 10 ++++++---- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/torchax/test/test_view.py 
b/torchax/test/test_view.py index 3f5caee5f1f7..6dea3f73b268 100644 --- a/torchax/test/test_view.py +++ b/torchax/test/test_view.py @@ -14,14 +14,23 @@ def setUp(self): torch.manual_seed(0) torchax.enable_globally() + def test_narrow(self): + x = torch.zeros((10, 10), device="jax") + x = x.narrow(0, 0, 5).narrow(0, 0, 5) + y = torch.ones((5, 10), device="jax") + x.copy_(y) + self.assertEqual(type(x), View) + self.assertEqual(x.shape, (5, 10)) + self.assertEqual(x.sum(), 50) + def test_copy_(self): - x = torch.zeros((10, 10), device="jax") - y = torch.ones((5, 5), device="jax") - x[0:5, :][:, 0:5].copy_(y[:, :]) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x[0:5, 0:5].sum(), 25) - self.assertEqual(x.sum(), 25) + x = torch.zeros((10, 10), device="jax") + y = torch.ones((5, 5), device="jax") + x[0:5, :][:, 0:5].copy_(y[:, :]) + self.assertEqual(type(x), Tensor) + self.assertEqual(x.shape, (10, 10)) + self.assertEqual(x[0:5, 0:5].sum(), 25) + self.assertEqual(x.sum(), 25) def test_transivity(self): x = torch.zeros((10, 10), device="jax") diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py index fc8dcc71e466..3f5e30312a96 100644 --- a/torchax/torchax/ops/jaten.py +++ b/torchax/torchax/ops/jaten.py @@ -15,7 +15,7 @@ from torchax.ops import op_base, mappings from torchax import interop from torchax.ops import jax_reimplement -from torchax.view import View +from torchax.view import View, NarrowInfo from torchax.tensor import Tensor # Keys are OpOverload, value is a callable that takes # Tensor @@ -131,6 +131,8 @@ def _aten_copy(x, y, memory_format=None): if isinstance(x, View): x.update(y) return x + if isinstance(y, View): + y = y.torch() if x.ndim == 1 and y.ndim == 0: # case of torch.empty((1,)).copy_(tensor(N)) @@ -402,8 +404,8 @@ def _aten_triu(m, k): return jnp.triu(m, k) -@op(torch.ops.aten.slice) -@op(torch.ops.aten.slice_copy) +@op(torch.ops.aten.slice, is_jax_function=False, is_view_op=True) 
+@op(torch.ops.aten.slice_copy, is_jax_function=False, is_view_op=True) def _aten_slice(self, dim=0, start=None, end=None, step=1): if dim < 0: dim += self.ndim @@ -416,7 +418,7 @@ def _aten_slice(self, dim=0, start=None, end=None, step=1): dims.append(sl) else: dims.append(slice(None, None, None)) - return self[tuple(dims)] + return View(self, NarrowInfo(slices=tuple(dims)), env = self._env) @op(torch.ops.aten.detach) From 0c34344aadbc4c23df7664bfca2c0a02ee065ef0 Mon Sep 17 00:00:00 2001 From: wenxindongwork Date: Mon, 5 May 2025 04:35:27 +0000 Subject: [PATCH 04/16] flatten returns a view view._setitem_ accepts tensor indices --- torchax/test/test_view.py | 43 +++++++++++++------------ torchax/torchax/ops/jaten.py | 13 ++++---- torchax/torchax/tensor.py | 15 +++++---- torchax/torchax/view.py | 62 ++++++++++++++++++++++++++++-------- 4 files changed, 86 insertions(+), 47 deletions(-) diff --git a/torchax/test/test_view.py b/torchax/test/test_view.py index 6dea3f73b268..c48937137e97 100644 --- a/torchax/test/test_view.py +++ b/torchax/test/test_view.py @@ -10,27 +10,28 @@ class TrainTest(unittest.TestCase): - def setUp(self): - torch.manual_seed(0) - torchax.enable_globally() - - def test_narrow(self): - x = torch.zeros((10, 10), device="jax") - x = x.narrow(0, 0, 5).narrow(0, 0, 5) - y = torch.ones((5, 10), device="jax") - x.copy_(y) - self.assertEqual(type(x), View) - self.assertEqual(x.shape, (5, 10)) - self.assertEqual(x.sum(), 50) - - def test_copy_(self): - x = torch.zeros((10, 10), device="jax") - y = torch.ones((5, 5), device="jax") - x[0:5, :][:, 0:5].copy_(y[:, :]) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x[0:5, 0:5].sum(), 25) - self.assertEqual(x.sum(), 25) + def setUp(self): + torch.manual_seed(0) + torchax.enable_globally() + + def test_narrow(self): + x = torch.zeros((10, 10), device="jax") + x = x.narrow(0, 0, 5).narrow(0, 0, 5) + y = torch.ones((5, 10), device="jax") + x.copy_(y) + 
self.assertEqual(type(x), View) + self.assertEqual(x.shape, (5, 10)) + self.assertEqual(x.sum(), 50) + + def test_copy_(self): + x = torch.zeros((10, 10), device="jax") + y = torch.ones((5, 5), device="jax") + x[0:5, :][:, 0:5].copy_(y[:, :]) + self.assertEqual(type(x), Tensor) + self.assertEqual(x.shape, (10, 10)) + self.assertEqual(x[0:5, 0:5].sum(), 25) + self.assertEqual(x.sum(), 25) + def test_transivity(self): x = torch.zeros((10, 10), device="jax") diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py index 3f5e30312a96..a9caf3319a96 100644 --- a/torchax/torchax/ops/jaten.py +++ b/torchax/torchax/ops/jaten.py @@ -15,7 +15,7 @@ from torchax.ops import op_base, mappings from torchax import interop from torchax.ops import jax_reimplement -from torchax.view import View, NarrowInfo +from torchax.view import View, NarrowInfo, ReshapeInfo from torchax.tensor import Tensor # Keys are OpOverload, value is a callable that takes # Tensor @@ -102,13 +102,14 @@ def inner(func): @op( - torch.ops.aten.view_copy, - torch.ops.aten.view, - torch.ops.aten._unsafe_view, - torch.ops.aten.reshape, + torch.ops.aten.view_copy, + torch.ops.aten.view, + torch.ops.aten._unsafe_view, + torch.ops.aten.reshape, + is_jax_function=False, ) def _aten_unsafe_view(x, shape): - return jnp.reshape(x, shape) + return View(x, ReshapeInfo(shape=shape), env=x._env) @op(torch.ops.aten.add.Tensor) diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index 66e2b55994b0..22f3c225e6e3 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -101,13 +101,14 @@ def shape(self): def ndim(self): return len(self._elem.shape) - def flatten(self, start_dim=0, end_dim=-1): - if end_dim == -1: - end_dim = self.ndim - new_shape = ( - self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1:]) - new_elem = jnp.reshape(self._elem, new_shape) - return Tensor(new_elem, self._env) + # def flatten(self, start_dim=0, end_dim=-1): + # if end_dim == -1: + # 
end_dim = self.ndim + # new_shape = ( + # self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1 :] + # ) + # new_elem = jnp.reshape(self._elem, new_shape) + # return Tensor(new_elem, self._env) # return torch.reshape(self, new_shape) def __setitem__(self, key, val): diff --git a/torchax/torchax/view.py b/torchax/torchax/view.py index fe235f19af76..216988f22f83 100644 --- a/torchax/torchax/view.py +++ b/torchax/torchax/view.py @@ -4,6 +4,7 @@ from enum import Enum from typing import Union, List, Tuple, Optional, Any, cast from abc import ABC, abstractmethod +import torch.utils._pytree as pytree # Reference to original PyTorch native functions # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml @@ -114,6 +115,36 @@ def update_tensor(self, new_value: jax.Array, def calculate_output_shape(self, source: jax.Array) -> List[int]: return source[self.slices].shape +class ReshapeInfo(ViewInfo): + """ + Represents a reshape operation on a tensor. + Handles operations like tensor.reshape(1, 2, 3) and tensor.reshape(-1, 1) + """ + + def __init__(self, shape: Tuple[int, ...]) -> None: + """ + Args: + shape: The shape to reshape the tensor to. + E.g. jax_array.reshape(shape) will return the transformed tensor. 
+ """ + super().__init__(ViewInfoType.RESHAPE) + self.shape = shape + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ReshapeInfo): + return False + return self.shape == other.shape + + def transform_tensor(self, jax_array: jax.Array) -> jax.Array: + return jax_array.reshape(self.shape) + + def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array: + original_shape = jax_array.shape + return jax_array.at[...].set(new_value.reshape(original_shape)) + + def calculate_output_shape(self, source: jax.Array) -> List[int]: + return source.reshape(self.shape).shape + class SelectInfo(ViewInfo): """ @@ -313,13 +344,16 @@ def update( intermediate_values.append( view_info.transform_tensor(intermediate_values[-1])) - # TODO: Investigate efficiency of this algorithm - # Update the source array with the new value by - # applying inverse transformations in reverse order - for view_info, parent_array in zip( - reversed(view_infos), reversed(intermediate_values)): - # Apply the inverse transformation to propagate changes back - new_values = view_info.update_tensor(new_values, parent_array) + # TODO: Investigate efficiency of this algorithm + # Update the source array with the new value by + # applying inverse transformations in reverse order + for view_info, parent_array in zip( + reversed(view_infos), reversed(intermediate_values) + ): + assert isinstance(new_values, jax.Array) + assert isinstance(parent_array, jax.Array) + # Apply the inverse transformation to propagate changes back + new_values = view_info.update_tensor(new_values, parent_array) # Update the source tensor with the new values self.replace_source_jax(new_values) @@ -362,12 +396,14 @@ def jax(self) -> jax.Array: result = view_info.transform_tensor(result) return result - def __setitem__(self, indexes, val): - view_infos = self.get_transformation_chain() + [NarrowInfo(indexes)] - self.update(view_infos=view_infos, new_values=val) - - def dim(self): - return self.ndim + 
def __setitem__(self, indexes, val): + # Handle tensor indexing + indexes = pytree.tree_map(lambda x: x.jax() if isinstance(x, torch.Tensor) else x, indexes) + view_infos = self.get_transformation_chain() + [NarrowInfo(indexes)] + self.update(view_infos=view_infos, new_values=val) + + def dim(self): + return self.ndim @property def device(self): From d37bed0c25a7bbf89253ec9866706d9cbf352554 Mon Sep 17 00:00:00 2001 From: wenxindongwork Date: Mon, 5 May 2025 17:20:43 +0000 Subject: [PATCH 05/16] support index_copy_ --- torchax/test/test_view.py | 22 ++++++++++++++++++++++ torchax/torchax/decompositions.py | 5 +++-- torchax/torchax/ops/jaten.py | 2 +- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/torchax/test/test_view.py b/torchax/test/test_view.py index c48937137e97..ad3eba551ba3 100644 --- a/torchax/test/test_view.py +++ b/torchax/test/test_view.py @@ -13,6 +13,28 @@ class TrainTest(unittest.TestCase): def setUp(self): torch.manual_seed(0) torchax.enable_globally() + + def test_index_copy_(self): + x = torch.zeros((10, 10), device="jax") + x_view = x[0, :] + indices = torch.arange(5, device="jax") + new_value = torch.ones((5,), device="jax") + x_view.index_copy_(0, indices, new_value) + self.assertEqual(type(x), Tensor) + self.assertEqual(type(x_view), View) + self.assertEqual(x.shape, (10, 10)) + self.assertEqual(x.sum(), 5) + + def test_flatten(self): + x = torch.zeros((10, 10), device="jax") + x1 = x.flatten(0, 1) + y = torch.ones(100, device="jax") + x1.copy_(y) + self.assertEqual(type(x), Tensor) + self.assertEqual(type(x1), View) + self.assertEqual(x.shape, (10, 10)) + self.assertEqual(x.sum(), 100) + def test_narrow(self): x = torch.zeros((10, 10), device="jax") diff --git a/torchax/torchax/decompositions.py b/torchax/torchax/decompositions.py index f116d42f3d67..3239b0d561b0 100644 --- a/torchax/torchax/decompositions.py +++ b/torchax/torchax/decompositions.py @@ -764,6 +764,7 @@ def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: 
torch.Tensor, ]) MUTABLE_DECOMPOSITION = [ - torch.ops.aten.bernoulli_.Tensor, - torch.ops.aten.bernoulli_.float, + torch.ops.aten.bernoulli_.Tensor, + torch.ops.aten.bernoulli_.float, + torch.ops.aten.index_copy_.default, ] diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py index a9caf3319a96..cb7aea7b4195 100644 --- a/torchax/torchax/ops/jaten.py +++ b/torchax/torchax/ops/jaten.py @@ -57,6 +57,7 @@ torch.ops.aten.scatter_add_: torch.ops.aten.scatter_add, torch.ops.aten.scatter_reduce_.two: torch.ops.aten.scatter_reduce, torch.ops.aten.scatter_: torch.ops.aten.scatter, + torch.ops.aten.index_put_: torch.ops.aten.index_put, } # Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`. @@ -782,7 +783,6 @@ def _aten_empty_strided(sizes, stride, dtype=None, **kwargs): return jnp.empty(sizes, dtype=dtype) -@op(torch.ops.aten.index_put_) @op(torch.ops.aten.index_put) def _aten_index_put(self, indexes, values, accumulate=False): indexes = [slice(None, None, None) if i is None else i for i in indexes] From 1f4239cfc0a6b06d4d3a925a5b335d4ce97d7279 Mon Sep 17 00:00:00 2001 From: wenxindongwork Date: Fri, 9 May 2025 00:04:11 +0000 Subject: [PATCH 06/16] delete commented out code --- torchax/torchax/tensor.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index 22f3c225e6e3..287062befe32 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -101,15 +101,6 @@ def shape(self): def ndim(self): return len(self._elem.shape) - # def flatten(self, start_dim=0, end_dim=-1): - # if end_dim == -1: - # end_dim = self.ndim - # new_shape = ( - # self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1 :] - # ) - # new_elem = jnp.reshape(self._elem, new_shape) - # return Tensor(new_elem, self._env) - # return torch.reshape(self, new_shape) def __setitem__(self, key, val): key, val = self._env.t2j_iso((key, val)) From 
27d5d6608bdd2e24f8761b75752b2f9a8546eef5 Mon Sep 17 00:00:00 2001 From: wenxindongwork Date: Fri, 9 May 2025 00:04:29 +0000 Subject: [PATCH 07/16] simplify logic --- torchax/torchax/view.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchax/torchax/view.py b/torchax/torchax/view.py index 216988f22f83..067870a08619 100644 --- a/torchax/torchax/view.py +++ b/torchax/torchax/view.py @@ -139,8 +139,7 @@ def transform_tensor(self, jax_array: jax.Array) -> jax.Array: return jax_array.reshape(self.shape) def update_tensor(self, new_value: jax.Array, jax_array: jax.Array) -> jax.Array: - original_shape = jax_array.shape - return jax_array.at[...].set(new_value.reshape(original_shape)) + return new_value.reshape(jax_array.shape) def calculate_output_shape(self, source: jax.Array) -> List[int]: return source.reshape(self.shape).shape From f291a3ac745d56235f1c877a8220b2b2e4f08b8e Mon Sep 17 00:00:00 2001 From: wenxindongwork Date: Fri, 9 May 2025 00:06:27 +0000 Subject: [PATCH 08/16] not needed --- torchax/torchax/ops/jtorch.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py index 6249176436d5..b9217e803c2e 100644 --- a/torchax/torchax/ops/jtorch.py +++ b/torchax/torchax/ops/jtorch.py @@ -532,6 +532,9 @@ def _ragged_paged_attention( vmem_limit_bytes: int | None = None, ): +# if vmem_limit_bytes is None: +# vmem_limit_bytes = 64 * 1024 * 1024 + return ragged_paged_attention_kernel( q = q, kv_pages = kv_pages, From bb3763b0daf6ac30b7a7dcb46089d714bd937373 Mon Sep 17 00:00:00 2001 From: wenxindongwork Date: Fri, 9 May 2025 00:06:45 +0000 Subject: [PATCH 09/16] not needed --- torchax/torchax/ops/jtorch.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py index b9217e803c2e..208c757f5240 100644 --- a/torchax/torchax/ops/jtorch.py +++ b/torchax/torchax/ops/jtorch.py @@ -532,8 +532,6 @@ def _ragged_paged_attention( 
vmem_limit_bytes: int | None = None, ): -# if vmem_limit_bytes is None: -# vmem_limit_bytes = 64 * 1024 * 1024 return ragged_paged_attention_kernel( q = q, From 12a157a92f889cdb7f96b0e0083de8e20af11c90 Mon Sep 17 00:00:00 2001 From: wenxindongwork Date: Fri, 9 May 2025 16:08:17 +0000 Subject: [PATCH 10/16] fix type check --- torchax/torchax/interop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchax/torchax/interop.py b/torchax/torchax/interop.py index 460441b308c2..7b675a239f32 100644 --- a/torchax/torchax/interop.py +++ b/torchax/torchax/interop.py @@ -12,6 +12,7 @@ from torchax import tensor from torchax import util import torchax +from torchax.view import View from torchax.types import JaxValue, TorchValue, JaxCallable, TorchCallable @@ -179,7 +180,7 @@ def _jax_view(t: TorchValue) -> JaxValue: # t is an object from torch land # view it as-if it's a jax land object if isinstance(t, torch.Tensor): - assert isinstance(t, tensor.Tensor), type(t) + assert isinstance(t, tensor.Tensor) or isinstance(t, View), type(t) return t.jax() if isinstance(t, type(torch.int32)): return tensor.t2j_dtype(t) From f9d1df16d6dfca59efda5498b0c7d9d4783e3143 Mon Sep 17 00:00:00 2001 From: wenxindongwork Date: Fri, 9 May 2025 16:32:25 +0000 Subject: [PATCH 11/16] intercept torch.ops.xla.dynamo_set_buffer_donor_ --- torchax/torchax/ops/jtorch.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py index 208c757f5240..31781da6aaac 100644 --- a/torchax/torchax/ops/jtorch.py +++ b/torchax/torchax/ops/jtorch.py @@ -514,6 +514,10 @@ def functional_linear(self, weights, bias=None): res += bias return res +@register_function(torch.ops.xla.dynamo_set_buffer_donor_) +def _dynamo_set_buffer_donor(self, donor): + pass + @register_function(torch.ops.xla.ragged_paged_attention) def _ragged_paged_attention( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] From 
b673d4de0cb6eb18955572775f083c73bcd03d73 Mon Sep 17 00:00:00 2001 From: wenxindongwork Date: Fri, 9 May 2025 17:04:24 +0000 Subject: [PATCH 12/16] intercept torch._sync as no-op --- torchax/torchax/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py index d18c983e252e..bc35e8d129a2 100644 --- a/torchax/torchax/__init__.py +++ b/torchax/torchax/__init__.py @@ -127,3 +127,6 @@ def compile(fn, options: Optional[CompileOptions] = None): raise RuntimeError('dynamo mode is not supported yet') elif options.mode == 'export': raise RuntimeError('export mode is not supported yet') + +# Intercept torch._sync as no-op +torch._sync = lambda *args, **kwargs: None From 0f4b40edd1a5ebd328ef8c21b6c928fb627d1365 Mon Sep 17 00:00:00 2001 From: wenxindongwork Date: Fri, 9 May 2025 17:04:37 +0000 Subject: [PATCH 13/16] import ragged_attention from torch_xla --- torchax/torchax/ops/jtorch.py | 3 +- torchax/torchax/ops/ragged_attention.py | 1257 ----------------------- 2 files changed, 1 insertion(+), 1259 deletions(-) delete mode 100644 torchax/torchax/ops/ragged_attention.py diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py index 31781da6aaac..3c5163dbdfa7 100644 --- a/torchax/torchax/ops/jtorch.py +++ b/torchax/torchax/ops/jtorch.py @@ -17,7 +17,6 @@ import torchax.tensor from torchax.view import View, NarrowInfo import torch.utils._pytree as pytree -from torchax.ops.ragged_attention import ragged_paged_attention as ragged_paged_attention_kernel def register_function(torch_func, **kwargs): @@ -536,7 +535,7 @@ def _ragged_paged_attention( vmem_limit_bytes: int | None = None, ): - + from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_paged_attention_kernel return ragged_paged_attention_kernel( q = q, kv_pages = kv_pages, diff --git a/torchax/torchax/ops/ragged_attention.py b/torchax/torchax/ops/ragged_attention.py deleted file mode 
100644 index fdc5a63adc8d..000000000000 --- a/torchax/torchax/ops/ragged_attention.py +++ /dev/null @@ -1,1257 +0,0 @@ -# Copyright 2025 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Copied from https://github.com/pytorch/xla/blob/master/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py -since the pypi torchxla package does not have the updated function. -""" - -"""TPU-Friendly Ragged Paged Attention kernel. - -This kernel offers a highly optimized implementation of ragged paged attention, -specifically designed for TPU and compatible with a wide range of model -specifications. It supports mixed prefill and decoding, enhancing throughput -during inference. -""" -import functools -import jax -from jax import lax -from jax.experimental import pallas as pl -from jax.experimental.pallas import tpu as pltpu -import jax.numpy as jnp - -DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) -# The page size is too small. We only have 32 SREGs in TC. If the pages -# per seq is too large, SREGs will spill. 
-MAX_PAGES_PER_SEQ = 16 - -# key: -# - q_dtype_name -# - kv_dtype_name -# - num_q_heads_per_blk -# - num_kv_heads_per_blk -# - head_dim -# - page_size -# - max_num_batched_tokens -# - max_model_len = page_size * pages_per_seq -# value: -# - num_kv_pages_per_block -# - num_queries_per_block -TUNED_BLOCK_SIZES = { - 'TPU v6': { - # go/keep-sorted start - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1024): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 2048): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 4096): (32, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 512): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1024): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 2048): (16, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 4096): (32, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 512): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1024): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 2048): (16, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 4096): (32, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 512): (4, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1024): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 2048): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 4096): (32, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 512): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 1024): (64, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 128): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 2048): (128, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 256): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 512): (32, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 64): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 1024): (64, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 128): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 2048): (128, 64), - 
('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 256): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 512): (32, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 64): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 1024): (64, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 128): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 2048): (128, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 256): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 512): (32, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 64): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 1024): (64, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 128): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 2048): (128, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 256): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 512): (32, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 64): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1024): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 2048): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 4096): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1024): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 2048): (8, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 4096): (16, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1024): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 2048): (8, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 4096): (16, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1024): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 2048): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 4096): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 1024): (32, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 128): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 2048): (64, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 
1024, 256): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 4096): (128, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 512): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 1024): (32, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 128): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 2048): (64, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 256): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 4096): (128, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 512): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 1024): (32, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 128): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 2048): (64, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 256): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 4096): (128, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 512): (16, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 1024): (32, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 128): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 2048): (64, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 256): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 4096): (128, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 512): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 1024): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 2048): (32, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 256): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 4096): (64, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 512): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 1024): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 2048): (32, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 256): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 4096): (64, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 512): (8, 32), - ('bfloat16', 
'bfloat16', 32, 8, 128, 64, 4096, 1024): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 2048): (32, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 256): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 4096): (64, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 512): (8, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 1024): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 2048): (32, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 256): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 4096): (64, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 512): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1024): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 2048): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 4096): (32, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 512): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1024): (8, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 2048): (16, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 4096): (32, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 512): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1024): (8, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 2048): (16, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 4096): (32, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 512): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1024): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 2048): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 4096): (32, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 512): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 1024): (64, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 128): (8, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 2048): (128, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 256): (16, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 512): (32, 64), - 
('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 64): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 1024): (64, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 128): (8, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 2048): (128, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 256): (16, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 512): (32, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 64): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 1024): (64, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 128): (8, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 2048): (128, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 256): (16, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 512): (32, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 64): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 1024): (64, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 128): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 2048): (128, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 256): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 512): (32, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 64): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1024): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 2048): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 4096): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1024): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 2048): (8, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 4096): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1024): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 2048): (8, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 4096): (16, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1024): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 2048): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 4096): (16, 32), - 
('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 1024): (32, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 128): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 2048): (64, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 256): (8, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 4096): (128, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 512): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 1024): (32, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 128): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 2048): (64, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 256): (8, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 4096): (64, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 512): (16, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 1024): (32, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 128): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 2048): (64, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 256): (8, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 4096): (64, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 512): (16, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 1024): (32, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 128): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 2048): (64, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 256): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 4096): (128, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 512): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 1024): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 2048): (32, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 256): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 4096): (64, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 512): (8, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 1024): (16, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 2048): (32, 64), - ('bfloat16', 
'bfloat16', 8, 1, 128, 64, 2048, 256): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 4096): (64, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 512): (8, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 1024): (16, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 2048): (32, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 256): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 4096): (64, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 512): (8, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 1024): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 2048): (32, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 256): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 4096): (64, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 512): (8, 32), - # go/keep-sorted end - }, - 'TPU v5': { - # go/keep-sorted start - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1024): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 2048): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 512): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1024): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 2048): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 512): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1024): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 2048): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 512): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1024): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 2048): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 512): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 128): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 256): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 64): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 128): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 256): (16, 32), - 
('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 64): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 128): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 256): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 64): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 128): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 256): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 64): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1024): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 2048): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 4096): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1024): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 2048): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 4096): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1024): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 2048): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 4096): (16, 64), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1024): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 2048): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 4096): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 128): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 256): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 512): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 128): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 256): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 512): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 128): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 256): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 512): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 128): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 256): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 512): (16, 
32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 1024): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 256): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 512): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 1024): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 256): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 512): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 1024): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 256): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 512): (8, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 1024): (16, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 256): (4, 32), - ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 512): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1024): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 2048): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 512): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1024): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 2048): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 512): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1024): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 2048): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 512): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1024): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 2048): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 512): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 128): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 256): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 64): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 128): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 256): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 64): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 128): (8, 64), - 
('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 256): (16, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 64): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 128): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 256): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 64): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1024): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 2048): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 4096): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1024): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 2048): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 4096): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1024): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 2048): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 4096): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1024): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 2048): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 4096): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 128): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 256): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 512): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 128): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 256): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 512): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 128): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 256): (8, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 512): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 128): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 256): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 512): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 1024): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 256): (4, 32), - ('bfloat16', 'bfloat16', 
8, 1, 128, 64, 1024, 512): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 1024): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 256): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 512): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 1024): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 256): (4, 64), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 512): (8, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 1024): (16, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 256): (4, 32), - ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 512): (8, 32), - # go/keep-sorted end - }, -} - - -def next_power_of_2(x: int): - """Finds the smallest power of 2 >= x using bit manipulation. - - Args: - x: The input number (should be an integer). - - Returns: - The smallest integer power of 2 that is >= x. - """ - assert x > 0 - if x == 1: - return 1 - return 1 << (x - 1).bit_length() - - -def simplify_key(key): - """Simplify the key to reduce the number of combinations.""" - ( - q_dtype, - kv_dtype, - num_q_heads_per_blk, - num_kv_heads_per_blk, - head_dim, - page_size, - max_num_batched_tokens, - pages_per_seq, - ) = key - return ( - jnp.dtype(q_dtype).name, - jnp.dtype(kv_dtype).name, - next_power_of_2(num_q_heads_per_blk), - next_power_of_2(num_kv_heads_per_blk), - (head_dim + 127) // 128 * 128, - next_power_of_2(page_size), - next_power_of_2(max_num_batched_tokens), - next_power_of_2(page_size * pages_per_seq), - ) - - -def get_tpu_version() -> int: - """Returns the numeric version of the TPU, or -1 if not on TPU.""" - kind = jax.devices()[0].device_kind - if 'TPU' not in kind: - return -1 - if kind.endswith(' lite'): - kind = kind[:-len(' lite')] - assert kind[:-1] == 'TPU v', kind - return int(kind[-1]) - - -def get_device_name(num_devices: int | None = None): - name = ' '.join(jax.devices()[0].device_kind.split()[:2]) - if num_devices is not None: - name += f'-{num_devices}' - return name - - -def get_tuned_block_sizes( - 
q_dtype, - kv_dtype, - num_q_heads_per_blk, - num_kv_heads_per_blk, - head_dim, - page_size, - max_num_batched_tokens, - pages_per_seq, -) -> tuple[int, int]: - """Look up for the best (num_kv_pages_per_blk, num_queries_per_blk) from auto-tuned table.""" - tpu_version = get_tpu_version() - if tpu_version < 4: - raise NotImplementedError('TPU version must be 4 or higher.') - key = ( - q_dtype, - kv_dtype, - num_q_heads_per_blk, - num_kv_heads_per_blk, - head_dim, - page_size, - max_num_batched_tokens, - pages_per_seq, - ) - key = simplify_key(key) - device_name = get_device_name() - - # Default block sizes. - bkv, bq = (128, 32) - if tpu_version == 4: - # This default block size is not tuned, only make sure there's no - # OOM in vmem - bkv, bq = (32, 32) - elif device_name in TUNED_BLOCK_SIZES: - if key in TUNED_BLOCK_SIZES[device_name]: - bkv, bq = TUNED_BLOCK_SIZES[device_name][key] - return (min(pages_per_seq, bkv), min(max_num_batched_tokens, bq)) - - -def get_min_page_size(max_model_len, min_page_size=16): - """Recommended min page size for high-performance kernel.""" - return max(next_power_of_2(max_model_len) // MAX_PAGES_PER_SEQ, min_page_size) - - -class MultiPageAsyncCopyDescriptor: - """Descriptor for async copy of multiple K/V pages from HBM.""" - - def __init__( - self, - pages_hbm_ref, # [total_num_pages, page_size, num_combined_kv_heads_per_blk, head_dim] - vmem_buf, # [num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim] - sem, - page_indices_ref, # i32[max_num_seqs, pages_per_seq] - offset, # [seq_idx, kv_pages_start] - ): - self._vmem_buf = vmem_buf - seq_id, kv_pages_start = offset - pages_per_seq = page_indices_ref.shape[1] - self._async_copies = [] - # TODO(jevinjiang): Only fetch dynamic shape in need! This will insert - # a bunch of if-ops. Check the performance when we have benchmarking setup. 
- for i in range(vmem_buf.shape[0]): - page_idx = kv_pages_start + i - page_idx = jax.lax.select(page_idx < pages_per_seq, page_idx, - pages_per_seq - 1) - self._async_copies.append( - pltpu.make_async_copy( - pages_hbm_ref.at[page_indices_ref[seq_id, page_idx]], - vmem_buf.at[i], - sem, - )) - - def start(self): - """Starts the async copies.""" - for async_copy in self._async_copies: - async_copy.start() - - def wait(self): - for async_copy in self._async_copies: - async_copy.wait() - return self._vmem_buf - - -def ref_ragged_paged_attention( - queries: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - kv_pages: jax. - Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] - kv_lens: jax.Array, # i32[max_num_seqs] - page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] - cu_q_lens: jax.Array, # i32[max_num_seqs + 1] - num_seqs: jax.Array, # i32[1], - *, - sm_scale: float = 1.0, - sliding_window: int | None = None, - soft_cap: float | None = None, - mask_value: float | None = DEFAULT_MASK_VALUE, -): - static_validate_inputs( - queries, - kv_pages, - kv_lens, - page_indices, - cu_q_lens, - num_seqs, - sm_scale=sm_scale, - sliding_window=sliding_window, - soft_cap=soft_cap, - mask_value=mask_value, - ) - if mask_value is None: - mask_value = DEFAULT_MASK_VALUE - _, _, num_combined_kv_heads, head_dim = kv_pages.shape - assert num_combined_kv_heads % 2 == 0 - num_kv_heads = num_combined_kv_heads // 2 - num_q_heads = queries.shape[1] - assert num_q_heads % num_kv_heads == 0 - num_query_per_kv = num_q_heads // num_kv_heads - outputs = [] - for i in range(num_seqs[0]): - q_start = cu_q_lens[i] - q_end = cu_q_lens[i + 1] - q_len = q_end - q_start - kv_len = kv_lens[i] - indices = page_indices[i] - q = queries[q_start:q_end] - k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads, - head_dim)[:kv_len] - v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, - head_dim)[:kv_len] - k = jnp.repeat(k, num_query_per_kv, axis=1) - v = 
jnp.repeat(v, num_query_per_kv, axis=1) - attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32) - attn *= sm_scale - q_span = (kv_len - q_len) + jax.lax.broadcasted_iota( - jnp.int32, attn.shape, 1) - kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2) - mask = q_span < kv_span - if sliding_window is not None: - mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span) - if soft_cap is not None: - attn = soft_cap * jnp.tanh(attn / soft_cap) - attn += jnp.where(mask, mask_value, 0.0) - attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) - out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype) - outputs.append(out) - - return jnp.concatenate(outputs, axis=0) - - -# Expect to run these checks during runtime. -def dynamic_validate_inputs( - q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - kv_pages: jax. - Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] - kv_lens: jax.Array, # i32[max_num_seqs] - page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] - cu_q_lens: jax.Array, # i32[max_num_seqs + 1] - num_seqs: jax.Array, # i32[1] - *, - # These inputs are optional. If not specified, we will not validate them. - sm_scale: float | None = None, - sliding_window: int | None = None, - soft_cap: float | None = None, - mask_value: float | None = None, - # Kernel specific params. 
- num_kv_pages_per_block: int | None = None, - num_queries_per_block: int | None = None, - vmem_limit_bytes: int | None = None, -): - static_validate_inputs( - q, - kv_pages, - kv_lens, - page_indices, - cu_q_lens, - num_seqs, - sm_scale=sm_scale, - sliding_window=sliding_window, - soft_cap=soft_cap, - mask_value=mask_value, - num_kv_pages_per_block=num_kv_pages_per_block, - num_queries_per_block=num_queries_per_block, - vmem_limit_bytes=vmem_limit_bytes, - ) - max_num_batched_tokens = q.shape[0] - page_size = kv_pages.shape[1] - max_num_seqs, pages_per_seq = page_indices.shape - if num_seqs[0] > max_num_seqs: - raise ValueError(f"{num_seqs[0]=} must be less or equal to {max_num_seqs=}") - max_kv_len = jnp.max(kv_lens) - min_pages_per_seq = cdiv(max_kv_len, page_size) - if pages_per_seq < min_pages_per_seq: - raise ValueError( - f"{pages_per_seq=} must be greater or equal to" - f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}.") - if cu_q_lens[num_seqs[0]] > max_num_batched_tokens: - raise ValueError( - f"Total q tokens {cu_q_lens[num_seqs[0]]} must be less or equal to" - f" {max_num_batched_tokens=}.") - for i in range(num_seqs[0]): - q_len = cu_q_lens[i + 1] - cu_q_lens[i] - kv_len = kv_lens[i] - if q_len > kv_len: - raise ValueError( - f"{q_len=} must be less or equal to {kv_len=} at sequence {i}.") - - -# Expect to run these checks during compile time. -def static_validate_inputs( - q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - kv_pages: jax. - Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] - kv_lens: jax.Array, # i32[max_num_seqs] - page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] - cu_q_lens: jax.Array, # i32[max_num_seqs + 1] - num_seqs: jax.Array, # i32[1] - *, - # These inputs are optional. If not specified, we will not validate them. 
- sm_scale: float | None = None, - sliding_window: int | None = None, - soft_cap: float | None = None, - mask_value: float | None = None, - # Kernel specific params. - num_kv_pages_per_block: int | None = None, - num_queries_per_block: int | None = None, - vmem_limit_bytes: int | None = None, -): - _, num_q_heads, head_dim = q.shape - _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape - assert num_combined_kv_heads % 2 == 0 - num_kv_heads = num_combined_kv_heads // 2 - max_num_seqs, pages_per_seq = page_indices.shape - if num_seqs.shape != (1,): - raise ValueError(f"{num_seqs.shape=} must be (1,)") - if head_dim_k != head_dim: - raise ValueError( - f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}.") - if kv_lens.shape != (max_num_seqs,): - raise ValueError(f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where" - " `max_num_seqs` is `page_indices.shape[0]`.") - if cu_q_lens.shape != (max_num_seqs + 1,): - raise ValueError( - f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where" - " `max_num_seqs` is `page_indices.shape[0]`.") - if (kv_lens.dtype != jnp.int32 or page_indices.dtype != jnp.int32 or - cu_q_lens.dtype != jnp.int32): - raise ValueError( - "The dtype of `kv_lens`, `page_indices`, and `cu_q_lens` must be" - f" int32. 
Got {kv_lens.dtype=}, {page_indices.dtype=}," - f" {cu_q_lens.dtype=}.") - if num_q_heads % num_kv_heads != 0: - raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}") - if sliding_window is not None and sliding_window <= 0: - raise ValueError(f"{sliding_window=} must be positive.") - if soft_cap is not None and soft_cap == 0.0: - raise ValueError(f"{soft_cap=} must not be 0.0.") - if (num_kv_pages_per_block is not None and - not 0 < num_kv_pages_per_block <= pages_per_seq): - raise ValueError( - f"{num_kv_pages_per_block=} must be in range (0, {pages_per_seq}].") - if num_queries_per_block is not None and num_queries_per_block <= 0: - raise ValueError(f"{num_queries_per_block=} must be positive.") - if vmem_limit_bytes is not None and vmem_limit_bytes <= 0: - raise ValueError(f"{vmem_limit_bytes=} must be positive.") - del sm_scale # No constraints on sm_scale. - del mask_value # No consstraints on mask_value. - - -def ragged_paged_attention_kernel( - # Prefetch - kv_lens_ref, # [max_num_seqs] - page_indices_ref, # [max_num_seqs, pages_per_seq] - cu_q_lens_ref, # [max_num_seqs + 1] - seq_buf_idx_ref, - # TODO(jevinjiang): if OOM in SMEM, consider pack to other scalar refs. 
- num_seqs_ref, - # Input - q_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] - kv_pages_hbm_ref, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] - # Output - o_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] - # Scratch - kv_bufs, # [2, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim] - sems, # [2, 2] - l_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] - m_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] - acc_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] - *, - sm_scale: float, - sliding_window: int | None = None, - soft_cap: float | None = None, - mask_value: float | None = DEFAULT_MASK_VALUE, -): - if mask_value is None: - mask_value = DEFAULT_MASK_VALUE - num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape - num_seqs = num_seqs_ref[0] - _, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, _ = ( - kv_bufs.shape) - num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2 - num_kv_per_blk = num_kv_pages_per_blk * page_size - num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk - heads_blk_idx, q_blk_idx = ( - pl.program_id(0), - pl.program_id(1), - ) - num_heads_blks = pl.num_programs(0) - init_seq_idx = seq_buf_idx_ref[0] - init_buf_idx = seq_buf_idx_ref[1] - q_len_start = q_blk_idx * num_q_per_blk - q_len_end = q_len_start + num_q_per_blk - - def create_kv_async_copy_descriptors(heads_blk_idx, seq_idx, kv_blk_idx, - buf_idx): - offset = (seq_idx, kv_blk_idx * num_kv_pages_per_blk) - heads_start = heads_blk_idx * num_combined_kv_heads_per_blk - async_copy_kv = MultiPageAsyncCopyDescriptor( - kv_pages_hbm_ref.at[:, :, - pl.ds(heads_start, num_combined_kv_heads_per_blk - ), :], - kv_bufs.at[buf_idx], - sems.at[buf_idx], - page_indices_ref, - offset, - ) - return async_copy_kv - - # TODO(jevinjiang): Add these to Mosaic: - # 1. Support arbitrary strided load/store for any dtype. - # 2. 
Support arbitrary strided load/store for any last dimension. - def strided_load_kv(ref, start, step): - if ref.dtype == jnp.float32: - return ref[start::step, :], ref[start + 1::step, :] - packing = get_dtype_packing(ref.dtype) - assert ref.dtype == jnp.bfloat16 - assert step % packing == 0 - b_start = start // packing - b_step = step // packing - b_ref = ref.bitcast(jnp.uint32) - b = b_ref[b_start::b_step, :] - bk = b << 16 - bv = b & jnp.uint32(0xffff0000) - k = pltpu.bitcast(bk, jnp.float32).astype(jnp.bfloat16) - v = pltpu.bitcast(bv, jnp.float32).astype(jnp.bfloat16) - return k, v - - def fold_on_2nd_minor(vec): - assert vec.dtype == jnp.bfloat16 or vec.dtype == jnp.float32 - assert len(vec.shape) >= 2 - last_dim = vec.shape[-1] - packing = get_dtype_packing(vec.dtype) - if vec.shape[-2] % packing != 0: - vec = vec.astype(jnp.float32) - return vec.reshape(-1, last_dim) - - @pl.when(heads_blk_idx + q_blk_idx == 0) - def prefetch_first_kv_blk(): - async_copy_kv = create_kv_async_copy_descriptors(heads_blk_idx, - init_seq_idx, 0, - init_buf_idx) - async_copy_kv.start() - - def is_cur_q_blk_needed(q_states): - done, cur_seq_idx, _ = q_states - should_run = jnp.logical_and(q_len_start < cu_q_lens_ref[num_seqs], - cur_seq_idx < num_seqs) - return jnp.logical_and(done == 0, should_run) - - def compute_with_cur_q_blk(q_states): - done, cur_seq_idx, cur_buf_idx = q_states - q_start = cu_q_lens_ref[cur_seq_idx] - q_end = cu_q_lens_ref[cur_seq_idx + 1] - q_len = q_end - q_start - kv_len = kv_lens_ref[cur_seq_idx] - - def get_next_prefetch_ids(heads_blk_idx, cur_seq_idx, kv_blk_idx, - cur_buf_idx): - next_kv_blk_idx = kv_blk_idx + 1 - is_last_kv_blk = next_kv_blk_idx * num_kv_per_blk >= kv_len - next_kv_blk_idx = lax.select( - is_last_kv_blk, - 0, - next_kv_blk_idx, - ) - is_cur_seq_end_in_cur_q_blk = q_end <= q_len_end - next_seq_idx = lax.select( - is_last_kv_blk, - lax.select(is_cur_seq_end_in_cur_q_blk, cur_seq_idx + 1, cur_seq_idx), - cur_seq_idx, - ) - is_last_seq = 
next_seq_idx == num_seqs - next_seq_idx = lax.select( - is_last_seq, - 0, - next_seq_idx, - ) - next_heads_blk_idx = lax.select( - is_last_seq, - heads_blk_idx + 1, - heads_blk_idx, - ) - next_buf_idx = lax.select(cur_buf_idx == 0, 1, 0) - return next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx - - def flash_attention( - q, # [num_q_per_blk * num_q_heads_per_kv_head, head_dim] - k, # [num_kv_per_blk, head_dim] - v, # [num_kv_per_blk, head_dim] - head_l_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128] - head_m_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128] - head_acc_ref, # [num_q_per_blk, num_q_heads_per_kv_head, head_dim] - *, - kv_blk_idx, - ): - assert q.shape == ( - num_q_per_blk * num_q_heads_per_kv_head, - head_dim, - ) - assert k.shape == ( - num_kv_per_blk, - head_dim, - ), f"{k.shape=}, {(num_kv_per_blk, head_dim)=} {k.dtype=}" - assert v.shape == (num_kv_per_blk, head_dim) - assert head_m_ref.shape == ( - num_q_per_blk * num_q_heads_per_kv_head, - 128, - ) - assert head_l_ref.shape == ( - num_q_per_blk * num_q_heads_per_kv_head, - 128, - ) - assert head_acc_ref.shape == ( - num_q_per_blk, - num_q_heads_per_kv_head, - head_dim, - ) - kv_len_start = kv_blk_idx * num_kv_per_blk - - def masked_store(ref, val, start, end, group=1): - iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group - mask = jnp.logical_and(iota >= start, iota < end) - pl.store( - ref, idx=tuple(slice(None) for _ in ref.shape), val=val, mask=mask) - - qk = ( - jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) * - sm_scale) - store_start = jnp.maximum(q_start - q_len_start, 0) - store_end = jnp.minimum(q_end - q_len_start, num_q_per_blk) - - @pl.when(kv_blk_idx == 0) - def init_scratch_ref(): - masked_store( - head_m_ref, - jnp.full_like(head_m_ref, -jnp.inf), - store_start, - store_end, - num_q_heads_per_kv_head, - ) - masked_store( - head_l_ref, - jnp.zeros_like(head_l_ref), - store_start, - store_end, - num_q_heads_per_kv_head, - ) - 
masked_store( - head_acc_ref, - jnp.zeros_like(head_acc_ref), - store_start, - store_end, - ) - - row_ids = ((kv_len - q_len) + q_len_start - q_start + - jax.lax.broadcasted_iota( - jnp.int32, - (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk), - 0, - ) // num_q_heads_per_kv_head) - col_ids = kv_len_start + jax.lax.broadcasted_iota( - jnp.int32, - (num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk), - 1, - ) - causal_mask = row_ids < col_ids - if sliding_window is not None: - causal_mask = jnp.logical_or(causal_mask, row_ids - sliding_window - >= col_ids) - if soft_cap is not None: - qk = soft_cap * jnp.tanh(qk / soft_cap) - qk += jnp.where(causal_mask, mask_value, 0.0) - m_curr = jnp.max(qk, axis=1, keepdims=True) - s_curr = jnp.exp(qk - m_curr) - qkv = jnp.dot(s_curr, v, preferred_element_type=jnp.float32) - lm_store_shape = head_m_ref.shape - m_curr = jnp.broadcast_to(m_curr, lm_store_shape) - l_curr = jnp.broadcast_to( - s_curr.sum(axis=1, keepdims=True), lm_store_shape) - m_prev = head_m_ref[...] - l_prev = head_l_ref[...] - m_next = jnp.maximum(m_prev, m_curr) - masked_store(head_m_ref, m_next, store_start, store_end, - num_q_heads_per_kv_head) - alpha = jnp.exp(m_prev - m_next) - beta = jnp.exp(m_curr - m_next) - l_alpha = alpha * l_prev - l_next = l_alpha + beta * l_curr - l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) - masked_store( - head_l_ref, - l_next_safe, - store_start, - store_end, - num_q_heads_per_kv_head, - ) - - def broadcast_to_shape(arr, shape): - if arr.shape == shape: - return arr - assert len(arr.shape) == len(shape) - assert arr.shape[0] == shape[0] - assert shape[1] % arr.shape[1] == 0 - # no-op concatenation. 
- return jnp.concatenate([arr for _ in range(shape[1] // arr.shape[1])], - axis=1) - - o_curr = head_acc_ref[...].reshape(-1, head_dim) - l_alpha = broadcast_to_shape(l_alpha, qkv.shape) - beta = broadcast_to_shape(beta, qkv.shape) - l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape) - out = lax.div( - l_alpha * o_curr + beta * qkv, - l_next_safe, - ) - masked_store( - head_acc_ref, - out.reshape(head_acc_ref.shape), - store_start, - store_end, - ) - - def is_valid_kv_blk_in_cur_seq(kv_states): - kv_blk_idx, _ = kv_states - return kv_blk_idx * num_kv_per_blk < kv_len - - def compute_with_kv_blk_in_cur_seq(kv_states): - kv_blk_idx, cur_buf_idx = kv_states - next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx = ( - get_next_prefetch_ids(heads_blk_idx, cur_seq_idx, kv_blk_idx, - cur_buf_idx)) - - @pl.when(next_heads_blk_idx < num_heads_blks) - def prefetch_next_kv_blk(): - # TODO(jevinjiang): reuse the same buffer if it is already prefetched! - # TODO(jevinjiang): only fetch effective dynamic size to hold kv_len and - # DMA to fixed size buffer! - next_async_copy_kv = create_kv_async_copy_descriptors( - next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx) - next_async_copy_kv.start() - - cur_async_copy_kv = create_kv_async_copy_descriptors( - heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx) - kv_ref = cur_async_copy_kv.wait().reshape( - num_kv_pages_per_blk * page_size * num_combined_kv_heads_per_blk, - head_dim, - ) - for kv_head_idx in range(num_kv_heads_per_blk): - q_head_idx = kv_head_idx * num_q_heads_per_kv_head - # TODO(jevinjiang): extra handlig for packed type that can start at - # unaligned position! 
- q = fold_on_2nd_minor(q_ref[:, q_head_idx:q_head_idx + - num_q_heads_per_kv_head, :]) - k, v = strided_load_kv(kv_ref, kv_head_idx * 2, - num_combined_kv_heads_per_blk) - flash_attention( - q, - k, - v, - l_ref.at[kv_head_idx], - m_ref.at[kv_head_idx], - acc_ref.at[:, q_head_idx:q_head_idx + num_q_heads_per_kv_head, :], - kv_blk_idx=kv_blk_idx, - ) - return kv_blk_idx + 1, next_buf_idx - - _, next_buf_idx = lax.while_loop( - is_valid_kv_blk_in_cur_seq, - compute_with_kv_blk_in_cur_seq, - (0, cur_buf_idx), # (kv_blk_idx, buf_idx) - ) - next_seq_idx = lax.select(q_end <= q_len_end, cur_seq_idx + 1, cur_seq_idx) - done = lax.select(q_end < q_len_end, done, 1) - return done, next_seq_idx, next_buf_idx - - _, seq_idx, buf_idx = lax.while_loop( - is_cur_q_blk_needed, - compute_with_cur_q_blk, - (0, init_seq_idx, init_buf_idx), # (done, seq_idx, buf_idx) - ) - # Reset seq_idx for next kv_heads_blk if run out of seqs! - seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0) - seq_buf_idx_ref[1] = buf_idx - o_ref[...] = acc_ref[...].astype(q_ref.dtype) - - -def cdiv(a, b): - assert b != 0 - return (a + b - 1) // b - - -def get_dtype_packing(dtype): - if dtype == jnp.float32: - return 1 - if dtype == jnp.bfloat16: - return 2 - if dtype == jnp.int8: - return 4 - if dtype == jnp.int4: - return 8 - raise ValueError(f"Not implemented: unsupported {dtype=}") - - -def get_min_heads_per_blk(num_q_heads, num_combined_kv_heads, q_dtype, - kv_dtype): - q_packing = get_dtype_packing(q_dtype) - kv_packing = get_dtype_packing(kv_dtype) - - def can_be_xla_fully_tiled(x, packing): - if x % packing != 0: - return False - x //= packing - return x in (1, 2, 4, 8) or x % 8 == 0 - - # TODO(jevinjiang): support unaligned number of heads! - if not can_be_xla_fully_tiled(num_combined_kv_heads, kv_packing): - raise ValueError( - f"Not implemented: {num_combined_kv_heads=} can not be XLA fully tiled." 
- ) - assert num_combined_kv_heads % 2 == 0 - num_kv_heads = num_combined_kv_heads // 2 - assert num_q_heads % num_kv_heads == 0 - ratio = num_q_heads // num_kv_heads - # TODO(jevinjiang): we can choose smaller tiling for packed type if large - # second minor tiling is not on. - max_combined_kv_tiling = 8 * kv_packing - min_combined_kv_heads = ( - max_combined_kv_tiling if num_combined_kv_heads % - max_combined_kv_tiling == 0 else num_combined_kv_heads) - min_q_heads = min_combined_kv_heads // 2 * ratio - if can_be_xla_fully_tiled(min_q_heads, q_packing): - return min_q_heads, min_combined_kv_heads - return num_q_heads, num_combined_kv_heads - - -@functools.partial( - jax.jit, - static_argnames=[ - "sm_scale", - "mask_value", - "num_kv_pages_per_block", - "num_queries_per_block", - "vmem_limit_bytes", - "sliding_window", - "soft_cap", - ], -) -def ragged_paged_attention( - q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - # TODO(jevinjiang): create a write_to_kv_cache kernel! - kv_pages: jax. - Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] - kv_lens: jax.Array, # i32[max_num_seqs] - page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] - cu_q_lens: jax.Array, # i32[max_num_seqs + 1] - num_seqs: jax.Array, # i32[1] - *, - sm_scale: float = 1.0, - sliding_window: int | None = None, - soft_cap: float | None = None, - mask_value: float | None = DEFAULT_MASK_VALUE, - num_kv_pages_per_block: int | None = None, - num_queries_per_block: int | None = None, - vmem_limit_bytes: int | None = None, -): - """Ragged paged attention that supports mixed prefill and decode. - - Args: - q: concatenated all sequences' queries. - kv_pages: paged K cache. Normally in HBM. - kv_lens: padded kv lengths. Only the first num_seqs values are valid. - page_indices: the first index indicates which page to use in the kv cache - for each sequence. Only the first num_seqs values are valid. 
- cu_q_lens: the cumulative sum of the effective query lengths. Similar to - kv_lens, only the first num_seqs+1 values are valid. - num_seqs: the dynamic number of sequences. - sm_scale: the softmax scale which will be applied to the Q@K^T. - sliding_window: the sliding window size for the attention. - soft_cap: the logit soft cap for the attention. - mask_value: mask value for causal mask. - num_kv_pages_per_block: number of kv pages to be processed in one flash - attention block in the pallas kernel. - num_queries_per_block: number of kv pages to be processed in one flash - attention block in the pallas kernel. - vmem_limit_bytes: the vmem limit for the pallas kernel. - - Returns: - The output of the attention. - """ - static_validate_inputs( - q, - kv_pages, - kv_lens, - page_indices, - cu_q_lens, - num_seqs, - sm_scale=sm_scale, - sliding_window=sliding_window, - soft_cap=soft_cap, - mask_value=mask_value, - num_kv_pages_per_block=num_kv_pages_per_block, - num_queries_per_block=num_queries_per_block, - vmem_limit_bytes=vmem_limit_bytes, - ) - if mask_value is None: - mask_value = DEFAULT_MASK_VALUE - num_q_tokens, num_q_heads, head_dim = q.shape - _, page_size, num_combined_kv_heads, _ = kv_pages.shape - assert num_combined_kv_heads % 2 == 0 - num_kv_heads = num_combined_kv_heads // 2 - _, pages_per_seq = page_indices.shape - num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk( - num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype) - num_q_per_blk = num_queries_per_block - num_kv_pages_per_blk = num_kv_pages_per_block - if num_q_per_blk is None or num_kv_pages_per_blk is None: - num_kv_pages_per_blk, num_q_per_blk = get_tuned_block_sizes( - q.dtype, - kv_pages.dtype, - num_q_heads_per_blk, - num_combined_kv_heads_per_blk // 2, - head_dim, - page_size, - num_q_tokens, - pages_per_seq, - ) - num_q_heads_per_kv_head = num_q_heads // num_kv_heads - num_q_blks = cdiv(num_q_tokens, num_q_per_blk) - assert num_combined_kv_heads_per_blk % 
2 == 0 - num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2 - assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0 - num_heads_blks = num_q_heads // num_q_heads_per_blk - grid = (num_heads_blks, num_q_blks) - - def q_index_map(heads_blk_idx, q_blk_idx, *_): - return (q_blk_idx, heads_blk_idx, 0) - - q_block_spec = pl.BlockSpec( - (num_q_per_blk, num_q_heads_per_blk, head_dim), - q_index_map, - ) - in_specs = [ - q_block_spec, - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - ] - out_specs = q_block_spec - lm_scratch = pltpu.VMEM( - # TODO(jevinjiang): use 128 instead of 1 is due to Mosaic does not support - # unaligned slicing! - (num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128), - jnp.float32, - ) - acc_scratch = pltpu.VMEM( - (num_q_per_blk, num_q_heads_per_blk, head_dim), - jnp.float32, - ) - double_buf_scratch = pltpu.VMEM( - ( - 2, # For double buffering during DMA copies. - num_kv_pages_per_blk, - page_size, - num_combined_kv_heads_per_blk, - head_dim, - ), - kv_pages.dtype, - ) - scratch_shapes = [ - double_buf_scratch, # kv_bufs - pltpu.SemaphoreType.DMA((2,)), # Semaphores for double buffers. 
- lm_scratch, # l_ref - lm_scratch, # m_ref - acc_scratch, - ] - scalar_prefetches = ( - kv_lens, - page_indices, - cu_q_lens, - jnp.array((0, 0), jnp.int32), # seq_idx, buf_idx - num_seqs, - ) - kernel = pl.pallas_call( - functools.partial( - ragged_paged_attention_kernel, - sm_scale=sm_scale, - sliding_window=sliding_window, - soft_cap=soft_cap, - mask_value=mask_value, - ), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=len(scalar_prefetches), - in_specs=in_specs, - out_specs=out_specs, - grid=grid, - scratch_shapes=scratch_shapes, - ), - compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=( - "arbitrary", - "arbitrary", - ), - vmem_limit_bytes=vmem_limit_bytes, - ), - out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), - name="ragged_paged_attention_kernel", - ) - - return kernel(*scalar_prefetches, q, kv_pages) \ No newline at end of file From 55b339f2247073522ef535f0ae94648312ae5f3f Mon Sep 17 00:00:00 2001 From: wenxindongwork Date: Fri, 9 May 2025 17:14:59 +0000 Subject: [PATCH 14/16] wrap xla ops in try catch --- torchax/torchax/ops/jtorch.py | 81 +++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 37 deletions(-) diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py index 3c5163dbdfa7..6e6abc064af3 100644 --- a/torchax/torchax/ops/jtorch.py +++ b/torchax/torchax/ops/jtorch.py @@ -513,43 +513,50 @@ def functional_linear(self, weights, bias=None): res += bias return res -@register_function(torch.ops.xla.dynamo_set_buffer_donor_) -def _dynamo_set_buffer_donor(self, donor): +try: + # TODO: Currently the following ops are wrapped in the try + # catch block because torch.ops.xla is not in the torch ops + # registry. Either we import torch_xla in the upper level, + # or modify the the register_function to support this. 
+ @register_function(torch.ops.xla.dynamo_set_buffer_donor_) + def _dynamo_set_buffer_donor(self, donor): + pass + + @register_function(torch.ops.xla.ragged_paged_attention) + def _ragged_paged_attention( + q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] + kv_lens: jax.Array, # i32[max_num_seqs] + page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] + cu_q_lens: jax.Array, # i32[max_num_seqs + 1] + num_seqs: jax.Array, # i32[1] + use_kernel: bool = True, + sm_scale: float = 1.0, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = None, + num_kv_pages_per_block: int | None = None, + num_queries_per_block: int | None = None, + vmem_limit_bytes: int | None = None, + ): + + from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_paged_attention_kernel + return ragged_paged_attention_kernel( + q = q, + kv_pages = kv_pages, + kv_lens = kv_lens, + page_indices = page_indices, + cu_q_lens = cu_q_lens, + num_seqs = num_seqs, + sm_scale = sm_scale, + sliding_window = sliding_window, + soft_cap = soft_cap, + mask_value = mask_value, + num_kv_pages_per_block = num_kv_pages_per_block, + num_queries_per_block = num_queries_per_block, + vmem_limit_bytes = vmem_limit_bytes, + ) +except Exception as e: pass -@register_function(torch.ops.xla.ragged_paged_attention) -def _ragged_paged_attention( - q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] - kv_lens: jax.Array, # i32[max_num_seqs] - page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] - cu_q_lens: jax.Array, # i32[max_num_seqs + 1] - num_seqs: jax.Array, # i32[1] - use_kernel: bool = True, - sm_scale: float = 1.0, - sliding_window: int | None = None, - soft_cap: float | None = None, - mask_value: float | None = 
None, - num_kv_pages_per_block: int | None = None, - num_queries_per_block: int | None = None, - vmem_limit_bytes: int | None = None, -): - - from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_paged_attention_kernel - return ragged_paged_attention_kernel( - q = q, - kv_pages = kv_pages, - kv_lens = kv_lens, - page_indices = page_indices, - cu_q_lens = cu_q_lens, - num_seqs = num_seqs, - sm_scale = sm_scale, - sliding_window = sliding_window, - soft_cap = soft_cap, - mask_value = mask_value, - num_kv_pages_per_block = num_kv_pages_per_block, - num_queries_per_block = num_queries_per_block, - vmem_limit_bytes = vmem_limit_bytes, -) - From 25eb93889112f2a1157dc997be434d491ca0a92c Mon Sep 17 00:00:00 2001 From: wenxindongwork Date: Mon, 12 May 2025 09:38:03 -0700 Subject: [PATCH 15/16] formatting --- torchax/test/test_view.py | 86 +++++++++++++++---------------- torchax/torchax/decompositions.py | 6 +-- torchax/torchax/ops/jtorch.py | 9 ++-- torchax/torchax/view.py | 51 +++++++++--------- 4 files changed, 75 insertions(+), 77 deletions(-) diff --git a/torchax/test/test_view.py b/torchax/test/test_view.py index ad3eba551ba3..576435ea9b01 100644 --- a/torchax/test/test_view.py +++ b/torchax/test/test_view.py @@ -10,50 +10,48 @@ class TrainTest(unittest.TestCase): - def setUp(self): - torch.manual_seed(0) - torchax.enable_globally() - - def test_index_copy_(self): - x = torch.zeros((10, 10), device="jax") - x_view = x[0, :] - indices = torch.arange(5, device="jax") - new_value = torch.ones((5,), device="jax") - x_view.index_copy_(0, indices, new_value) - self.assertEqual(type(x), Tensor) - self.assertEqual(type(x_view), View) - self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x.sum(), 5) - - def test_flatten(self): - x = torch.zeros((10, 10), device="jax") - x1 = x.flatten(0, 1) - y = torch.ones(100, device="jax") - x1.copy_(y) - self.assertEqual(type(x), Tensor) - self.assertEqual(type(x1), View) - 
self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x.sum(), 100) - - - def test_narrow(self): - x = torch.zeros((10, 10), device="jax") - x = x.narrow(0, 0, 5).narrow(0, 0, 5) - y = torch.ones((5, 10), device="jax") - x.copy_(y) - self.assertEqual(type(x), View) - self.assertEqual(x.shape, (5, 10)) - self.assertEqual(x.sum(), 50) - - def test_copy_(self): - x = torch.zeros((10, 10), device="jax") - y = torch.ones((5, 5), device="jax") - x[0:5, :][:, 0:5].copy_(y[:, :]) - self.assertEqual(type(x), Tensor) - self.assertEqual(x.shape, (10, 10)) - self.assertEqual(x[0:5, 0:5].sum(), 25) - self.assertEqual(x.sum(), 25) - + def setUp(self): + torch.manual_seed(0) + torchax.enable_globally() + + def test_index_copy_(self): + x = torch.zeros((10, 10), device="jax") + x_view = x[0, :] + indices = torch.arange(5, device="jax") + new_value = torch.ones((5,), device="jax") + x_view.index_copy_(0, indices, new_value) + self.assertEqual(type(x), Tensor) + self.assertEqual(type(x_view), View) + self.assertEqual(x.shape, (10, 10)) + self.assertEqual(x.sum(), 5) + + def test_flatten(self): + x = torch.zeros((10, 10), device="jax") + x1 = x.flatten(0, 1) + y = torch.ones(100, device="jax") + x1.copy_(y) + self.assertEqual(type(x), Tensor) + self.assertEqual(type(x1), View) + self.assertEqual(x.shape, (10, 10)) + self.assertEqual(x.sum(), 100) + + def test_narrow(self): + x = torch.zeros((10, 10), device="jax") + x = x.narrow(0, 0, 5).narrow(0, 0, 5) + y = torch.ones((5, 10), device="jax") + x.copy_(y) + self.assertEqual(type(x), View) + self.assertEqual(x.shape, (5, 10)) + self.assertEqual(x.sum(), 50) + + def test_copy_(self): + x = torch.zeros((10, 10), device="jax") + y = torch.ones((5, 5), device="jax") + x[0:5, :][:, 0:5].copy_(y[:, :]) + self.assertEqual(type(x), Tensor) + self.assertEqual(x.shape, (10, 10)) + self.assertEqual(x[0:5, 0:5].sum(), 25) + self.assertEqual(x.sum(), 25) def test_transivity(self): x = torch.zeros((10, 10), device="jax") diff --git 
a/torchax/torchax/decompositions.py b/torchax/torchax/decompositions.py index 3239b0d561b0..47bbb535e253 100644 --- a/torchax/torchax/decompositions.py +++ b/torchax/torchax/decompositions.py @@ -764,7 +764,7 @@ def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, ]) MUTABLE_DECOMPOSITION = [ - torch.ops.aten.bernoulli_.Tensor, - torch.ops.aten.bernoulli_.float, - torch.ops.aten.index_copy_.default, + torch.ops.aten.bernoulli_.Tensor, + torch.ops.aten.bernoulli_.float, + torch.ops.aten.index_copy_.default, ] diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py index 6e6abc064af3..390476be91c4 100644 --- a/torchax/torchax/ops/jtorch.py +++ b/torchax/torchax/ops/jtorch.py @@ -508,10 +508,11 @@ def linalg_tensorsolve(A, b, dims=None): @register_function(torch.nn.functional.linear) def functional_linear(self, weights, bias=None): - res = jnp.einsum("...a,ba->...b", self, weights) - if bias is not None: - res += bias - return res + res = jnp.einsum("...a,ba->...b", self, weights) + if bias is not None: + res += bias + return res + try: # TODO: Currently the following ops are wrapped in the try diff --git a/torchax/torchax/view.py b/torchax/torchax/view.py index 067870a08619..6068a0eebef4 100644 --- a/torchax/torchax/view.py +++ b/torchax/torchax/view.py @@ -343,16 +343,15 @@ def update( intermediate_values.append( view_info.transform_tensor(intermediate_values[-1])) - # TODO: Investigate efficiency of this algorithm - # Update the source array with the new value by - # applying inverse transformations in reverse order - for view_info, parent_array in zip( - reversed(view_infos), reversed(intermediate_values) - ): - assert isinstance(new_values, jax.Array) - assert isinstance(parent_array, jax.Array) - # Apply the inverse transformation to propagate changes back - new_values = view_info.update_tensor(new_values, parent_array) + # TODO: Investigate efficiency of this algorithm + # Update the source array with the new value by + # 
applying inverse transformations in reverse order + for view_info, parent_array in zip( + reversed(view_infos), reversed(intermediate_values)): + assert isinstance(new_values, jax.Array) + assert isinstance(parent_array, jax.Array) + # Apply the inverse transformation to propagate changes back + new_values = view_info.update_tensor(new_values, parent_array) # Update the source tensor with the new values self.replace_source_jax(new_values) @@ -395,14 +394,14 @@ def jax(self) -> jax.Array: result = view_info.transform_tensor(result) return result - def __setitem__(self, indexes, val): - # Handle tensor indexing - indexes = pytree.tree_map(lambda x: x.jax() if isinstance(x, torch.Tensor) else x, indexes) - view_infos = self.get_transformation_chain() + [NarrowInfo(indexes)] - self.update(view_infos=view_infos, new_values=val) - - def dim(self): - return self.ndim + def __setitem__(self, indexes, val): + # Handle tensor indexing + indexes = pytree.tree_map(lambda x: x.jax() if isinstance(x, torch.Tensor) else x, indexes) + view_infos = self.get_transformation_chain() + [NarrowInfo(indexes)] + self.update(view_infos=view_infos, new_values=val) + + def dim(self): + return self.ndim @property def device(self): @@ -412,12 +411,12 @@ def device(self): def jax_device(self): return self.jax().device - @property - def ndim(self): - return len(self.shape) + @property + def ndim(self): + return len(self.shape) + + @property + def data(self): + return self - @property - def data(self): - return self - - __repr__ = __str__ + __repr__ = __str__ From eb0258b443788a4baf516bbbef427f05d96ff4be Mon Sep 17 00:00:00 2001 From: wenxindongwork Date: Mon, 12 May 2025 10:01:04 -0700 Subject: [PATCH 16/16] fix some unit tests --- torchax/torchax/ops/jtorch.py | 4 ++-- torchax/torchax/tensor.py | 2 +- torchax/torchax/view.py | 16 ++++++++++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py index 
390476be91c4..1443d064ab85 100644 --- a/torchax/torchax/ops/jtorch.py +++ b/torchax/torchax/ops/jtorch.py @@ -269,7 +269,7 @@ def getitem(self, indexes): elif isinstance(indexes, list): indexes = tuple(indexes) - def is_narrow_slicing(): + def is_view_slicing(): tensor_free = not pytree.tree_any( lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array), indexes) @@ -277,7 +277,7 @@ def is_narrow_slicing(): [False if isinstance(x, list) else True for x in indexes]) return tensor_free and list_free - if is_narrow_slicing(): + if is_view_slicing(): return View(self, view_info=NarrowInfo(indexes), env=self._env) indexes = self._env.t2j_iso(indexes) diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index 287062befe32..ca386406d10b 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -373,7 +373,7 @@ def load_ops(self): ) def _to_copy(self, the_tensor, new_dtype, new_device): - if isinstance(the_tensor, Tensor): + if isinstance(the_tensor, Tensor) or isinstance(the_tensor, View): arr = the_tensor.jax() if new_dtype is not None and new_dtype != arr.dtype: arr = arr.astype(mappings.t2j_dtype(new_dtype)) diff --git a/torchax/torchax/view.py b/torchax/torchax/view.py index 6068a0eebef4..104a025de50c 100644 --- a/torchax/torchax/view.py +++ b/torchax/torchax/view.py @@ -385,6 +385,10 @@ def create_sub_view(self, view_info: ViewInfo) -> "View": def __str__(self) -> str: return f"View({self.torch()})" + @property + def _elem(self) -> jax.Array: + return self.jax() + def jax(self) -> jax.Array: """ Returns a copy of the source tensor after transformations. @@ -420,3 +424,15 @@ def data(self): return self __repr__ = __str__ + + +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_masked_std_cpu_float32 - NotImplementedError: Cannot copy out of meta tensor; no data! +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_masked_var_cpu_float32 - NotImplementedError: Cannot copy out of meta tensor; no data! 
+# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_masked_std_cpu_int64 - NotImplementedError: Cannot copy out of meta tensor; no data! +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_masked_var_cpu_int64 - NotImplementedError: Cannot copy out of meta tensor; no data! +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_nn_functional_interpolate_bilinear_cpu_float32 - NotImplementedError: Cannot copy out of meta tensor; no data! +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_nn_functional_interpolate_linear_cpu_float32 - NotImplementedError: Cannot copy out of meta tensor; no data! +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_nn_functional_interpolate_trilinear_cpu_float32 - NotImplementedError: Cannot copy out of meta tensor; no data! +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_nn_functional_upsample_bilinear_cpu_float32 - NotImplementedError: Cannot copy out of meta tensor; no data! +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_take_cpu_float32 - AttributeError: 'View' object has no attribute '_elem' +# FAILED test/test_ops.py::TestOpInfoCPU::test_reference_eager_take_cpu_int64 - AttributeError: 'View' object has no attribute '_elem' \ No newline at end of file