diff --git a/.gitignore b/.gitignore index 2ffda46d..d83dad6f 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,6 @@ __pycache__/ *.py[cod] *$py.class - # C extensions *.so @@ -98,6 +97,7 @@ celerybeat-schedule # Environments .env +.history .venv env/ venv/ diff --git a/padded_flash_attn.py b/padded_flash_attn.py new file mode 100644 index 00000000..e0df9788 --- /dev/null +++ b/padded_flash_attn.py @@ -0,0 +1,415 @@ +from functools import partial +import math +import jax +from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import BlockSizes +from jaxtyping import Array, Float16 +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from jax import numpy as jnp, lax +import numpy as np + + +NUM_LANES=128 +NUM_SUBLANES=8 +DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) +QK_DOT_DIM_NUMBERS_SEQ_MAJOR = (((1,), (1,)), ((), ())) # RHS transposed +SV_DOT_DIM_NUMBERS_SEQ_MAJOR = (((1,), (0,)), ((), ())) # standard matmul +save_residuals=False + +DensePaddedAttentionReturnType = ( + Float16[Array, "heads q_seq_len head_dim"] | + tuple[ + Float16[Array, "heads q_seq_len head_dim"] , # out + tuple[ + Float16[Array, "heads q_seq_len lanes"], # l + Float16[Array, "heads q_seq_len lanes"] # m + ] + ] +) + +def _dense_padded_flash_attn_fwd_kernel( + # Inputs + q_ref, + k_ref, + v_ref, + # Outputs + o_ref, + logsumexp_ref, + max_logits_ref, + # Scratch + m_scratch_ref, + l_scratch_ref, + o_scratch_ref, + mask_scratch_ref, + qk_scratch_ref, + # statics + *, + kv_padding: int, + mask_value:float, + kv_steps: int, + bkv_compute:int, + head_dim_v: int, +): + h = pl.program_id(0) + bq_i = pl.program_id(1) + bkv_j = pl.program_id(2) + + # initialize accumulation tensors in scratch every bq_i + should_initialize = bkv_j == 0 + is_last_k_block = bkv_j == kv_steps - 1 + padding_exists = kv_padding > 0 + masking_is_needed_for_block = is_last_k_block & padding_exists + @pl.when(should_initialize) + def init(): + o_scratch_ref[...] = jnp.zeros_like(o_scratch_ref) + m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value) + l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref) + + @pl.when(masking_is_needed_for_block) + def init_actual_mask(): + # Initialize the full (bq, bkv) mask for the last block + col_indices = jnp.arange(bkv_compute) + # Calculate how many tokens in this chunk are NOT padding + num_real_tokens = bkv_compute - kv_padding # Shape (block_kv,) + mask_row = col_indices < num_real_tokens # True for real, False for padding + mask_scratch_ref[...] = jnp.broadcast_to(mask_row, mask_scratch_ref.shape) + + + def mask(qk, bkv_slice): + mask_arr = mask_scratch_ref[:, bkv_slice] + qk = jnp.where(mask_arr, qk, mask_value) + return qk + + should_write = bkv_j == kv_steps - 1 + padding_exists = kv_padding > 0 + masking_is_needed_for_block = should_write & padding_exists + num_iters = ( + k_ref.shape[0] // bkv_compute + ) + # body for lax.fori loop over bkv compute blocks + def body(kv_compute_index, _): + # Reads from VMEM to VREG + # compute BKV_COMPUTE slice from BK k sequence + # idea compute slices before in prefetch scalars + slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) + # Softmax stats + m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] + # BQ read + q = q_ref[...] 
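+ # Online-softmax update for this bkv_compute chunk (flash-attention recurrence):
+ #   m_next = max(m_prev, rowmax(qk))
+ #   alpha  = exp(m_prev - m_next)
+ #   l_next = alpha * l_prev + rowsum(exp(qk - m_next))
+ #   o_next = alpha * o_prev + exp(qk - m_next) @ v
+ # Padding columns in the final chunk of the last kv block are set to mask_value
+ # first, so they contribute effectively zero after the exponential.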
+ # BKV_COMPUTE read + k = k_ref[slice_k, :] # HEAD DIM minor + # QK + qk_dims = QK_DOT_DIM_NUMBERS_SEQ_MAJOR # TODO: option to support K transpose + qk = lax.dot_general(q, k, qk_dims, preferred_element_type=jnp.float32) + qk_scratch_ref[...] = qk + + @pl.when(jnp.logical_and(kv_compute_index==num_iters-1,should_write)) + def mask(): + # mask_arr = ] + # qk_t = qk_scratch_ref[...] + qk_scratch_ref[...] = jnp.where(mask_scratch_ref[...], qk_scratch_ref[...], mask_value) + + qk = qk_scratch_ref[...] + # Running max + m_curr = qk.max(axis=-1)[:, None] + m_next = jnp.maximum(m_prev, m_curr) + + # Current numerator + bkv_repeats = bkv_compute//NUM_LANES + exp = jnp.exp + s_curr = exp(qk - pltpu.repeat(m_next, bkv_repeats, axis=1)) + l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,)) + + # Correction factor + alpha = exp(m_prev - m_next) + + # Accumulate denominator + l_next = l_curr + alpha * l_prev + + # Update softmax stats + m_scratch_ref[...], l_scratch_ref[...] = m_next, l_next + + # numerator * V + sv_dims = SV_DOT_DIM_NUMBERS_SEQ_MAJOR + v = v_ref[slice_k, :] + v = v.astype(jnp.float32) + o_curr = lax.dot_general(s_curr, v, sv_dims) + + # Accumulate unnormalized O + head_dim_v_repeats = head_dim_v//NUM_LANES + alpha_o = pltpu.repeat(alpha, head_dim_v_repeats, axis=1) + o_scratch_ref[...] = alpha_o * o_scratch_ref[...] + o_curr + + + lax.fori_loop( + lower=0, + upper=num_iters, + body_fun=body, + init_val=None, + unroll=True) + + @pl.when(should_write) + def end(): + l = l_scratch_ref[...] + head_dim_v_repeats = head_dim_v//NUM_LANES + l_inv = pltpu.repeat(1.0 / l, head_dim_v_repeats, axis=1) + o_ref[...] = (o_scratch_ref[...] * l_inv).astype(o_ref.dtype) + + # UNCOMMENT FOR SAVING SOFTMAX STATS + # if logsumexp_ref is not None: + # log = jnp.log # allow log base 2 + # logsumexp = m_scratch_ref[...] + log(l) + # logsumexp_ref[...] = logsumexp.astype(logsumexp_ref.dtype) + # if max_logits_ref is not None: + # max_logits_ref[...] 
= m_scratch_ref[...].astype(max_logits_ref.dtype) + + + +def _dense_padded_flash_attn_custom_fwd( + q: Float16[Array, "heads q_seq_len head_dim"], + k: Float16[Array, "heads kv_seq_len head_dim"], + v: Float16[Array, "heads kv_seq_len head_dim"], + block_sizes: BlockSizes, + kv_padding: int, +)-> DensePaddedAttentionReturnType: + head_dim = q.shape[-1] + num_heads = q.shape[0] + q_sequence_len = q.shape[1] + kv_sequence_len = k.shape[1] + + # Block specs + # Input block specs + in_specs = [ + # BQ + pl.BlockSpec( + block_shape= (None, block_sizes.block_q, head_dim), + index_map= lambda h, bq_i, bkv_j : (h, bq_i, 0) + ), + # BK + pl.BlockSpec( + block_shape= (None, block_sizes.block_kv, head_dim), + index_map= lambda h, bq_i, bkv_j : (h, bkv_j, 0) + ), + # BV + pl.BlockSpec( + block_shape= (None, block_sizes.block_kv, head_dim), + index_map= lambda h, bq_i, bkv_j : (h, bkv_j, 0) + ), + ] + # Output block specs + out_specs = [ + # out + pl.BlockSpec( + block_shape= (None, block_sizes.block_q, head_dim), + index_map= lambda h, bq_i, bkv_j : (h, bq_i, 0) + ), + ] + # Output Shape + out_shapes = [ + # out + jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), + ] + + if save_residuals: + out_specs += [pl.BlockSpec( # logsumexp + block_shape=(None, block_sizes.block_q, NUM_LANES), + index_map= lambda h, bq_i, bkv_j : (h, bq_i, 0) + ), + # max_logits + pl.BlockSpec( + block_shape=(None, block_sizes.block_q, NUM_LANES), + index_map= lambda h, bq_i, bkv_j : (h, bq_i, 0) + )] + + out_shapes += [ + # logsumexp + jax.ShapeDtypeStruct((num_heads, q_sequence_len, NUM_LANES), jnp.float32), + # max_logits + jax.ShapeDtypeStruct((num_heads, q_sequence_len, NUM_LANES), jnp.float32), + ] + else: + out_specs += [None, None] + out_shapes += [None, None] + + # Scratch shapes m,l,o,mask + scratch_shapes = [ + pltpu.VMEM( # m_scratch + shape=(block_sizes.block_q, NUM_LANES), + dtype=jnp.float32), + pltpu.VMEM( # l_scratch + shape=(block_sizes.block_q, NUM_LANES), + dtype=jnp.float32), + pltpu.VMEM( # o_scratch + shape=(block_sizes.block_q, head_dim), + dtype=jnp.float32), + pltpu.VMEM( # mask + shape=(block_sizes.block_q, block_sizes.block_kv_compute), + dtype=jnp.bool), + pltpu.VMEM( + shape=(block_sizes.block_q, block_sizes.block_kv_compute), + dtype=jnp.float32 + ) + ] + + # Grid + + num_bq_blocks = q_sequence_len // block_sizes.block_q + num_bkv_blocks = kv_sequence_len // block_sizes.block_kv + grid = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + grid = (num_heads, num_bq_blocks, num_bkv_blocks), + in_specs=in_specs, + out_specs=out_specs, + scratch_shapes=scratch_shapes + ) + + # Compiler Params + compiler_params=pltpu.CompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary") + ) + + + # Cost estimate + def _bytes(x: jax.Array | jax.ShapeDtypeStruct | None) -> int: + if x is None: + return 0 + if jnp.issubdtype(x.dtype, jnp.floating): + info = jnp.finfo + elif jnp.issubdtype(x.dtype, jnp.integer): + info = jnp.iinfo + else: + raise ValueError(f"Unsupported dtype: {x.dtype}") + return math.ceil(math.prod(x.shape) * info(x.dtype).bits / 8) + + def _fwd_cost_estimate( + q: Float16[Array, "heads q_seq_len head_dim"], + k: Float16[Array, "heads kv_seq_len head_dim"], + v: Float16[Array, "heads kv_seq_len head_dim"], + out_shapes: list[jax.ShapeDtypeStruct], + ) -> pl.CostEstimate | None: + num_q_heads, q_seq_len, head_dim_qk = q.shape + kv_seq_len, head_dim_v = v.shape[-2:] + + matmul_flops = ( + 2 * q_seq_len * kv_seq_len * head_dim_qk + + 2 * kv_seq_len * kv_seq_len * head_dim_v + ) 
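+ # matmul_flops above covers the Q @ K^T and softmax(QK) @ V products for a
+ # single head; total_flops below scales it by the number of query heads.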
+ + # This is an upper bound because `mask_sparsity` is actually the mean + # sparsity of the non-fully masked **blocks**. + total_flops = num_q_heads * matmul_flops + + # Count expensive exp() calls + transcendentals = num_q_heads * q_seq_len * kv_seq_len + + inputs_ = [q, k, v] + input_bytes = sum(map(_bytes, inputs_)) + output_bytes = sum(map(_bytes, out_shapes)) + return pl.CostEstimate( + flops=int(total_flops), + transcendentals=int(transcendentals), + bytes_accessed=int(input_bytes + output_bytes), + ) + + vmem_inputs = [ + q, + k, + v, + ] + cost_estimate = _fwd_cost_estimate(*vmem_inputs, out_shapes) + + ## Pallas call + kv_steps = kv_sequence_len//block_sizes.block_kv + kv_compute_iters = block_sizes.block_kv//block_sizes.block_kv_compute + dense_padded_attn_fwd_kernel = partial( + _dense_padded_flash_attn_fwd_kernel, + kv_padding=kv_padding, + mask_value = DEFAULT_MASK_VALUE, + kv_steps=kv_steps, + bkv_compute=block_sizes.block_kv_compute, + head_dim_v=head_dim + ) + with jax.named_scope("dense_padded_flash_attn_fwd"): + all_out = pl.pallas_call( + kernel=dense_padded_attn_fwd_kernel, + grid_spec = grid, + compiler_params=compiler_params, + cost_estimate=cost_estimate, + out_shape=out_shapes, + name="dense_padded_flash_attn_fwd", + )( + q, k, v + ) + out, logsumexp, max_logits = all_out + return out, (logsumexp, max_logits) + + + +def _dense_padded_flash_attn_custom_vjp( + q: Float16[Array, "heads q_seq_len head_dim"], + k: Float16[Array, "heads kv_seq_len head_dim"], + v: Float16[Array, "heads kv_seq_len head_dim"], + block_sizes: BlockSizes, + kv_padding: int, +)-> DensePaddedAttentionReturnType: + return _dense_padded_flash_attn_custom_fwd( + q, k, v, block_sizes, kv_padding + ) + + +@partial( + jax.jit, + static_argnames=("block_sizes", "kv_padding"), +) +def _dense_padded_flash_attention( + q: Float16[Array, "heads q_seq_len head_dim"], + k: Float16[Array, "heads kv_seq_len head_dim"], + v: Float16[Array, "heads kv_seq_len head_dim"], + *, + block_sizes: BlockSizes, + kv_padding: int, +)-> DensePaddedAttentionReturnType: + return _dense_padded_flash_attn_custom_vjp( + q, k, v, block_sizes, kv_padding + ) + + +@jax.tree_util.register_pytree_node_class +class DensePaddedAttention: + + def __init__(self, block_sizes: BlockSizes, kv_padding: int): + self.block_sizes = block_sizes + self.kv_padding = kv_padding + def __call__(self, + q: Float16[Array, "heads q_seq_len head_dim"], + k: Float16[Array, "heads kv_seq_len head_dim"], + v: Float16[Array, "heads kv_seq_len head_dim"], + ): + return _dense_padded_flash_attention( + q, k, v, + block_sizes=self.block_sizes, + kv_padding=self.kv_padding) + + def tree_flatten(self): + """Flattens the PyTree. + + Returns: + A tuple of dynamic children (none) and static auxiliary data. 
+ """ + # All attributes are static, so they go into aux_data + aux_data = (self.block_sizes, self.kv_padding) + # No dynamic children + children = () + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Reconstructs the PyTree from static and dynamic data.""" + # Unpack the static data + block_sizes, kv_padding = aux_data + # No dynamic children to unpack + return cls(block_sizes, kv_padding) + +def make_dense_padded_attention(block_sizes: BlockSizes, kv_padding: int): + return DensePaddedAttention(block_sizes=block_sizes, kv_padding=kv_padding) + diff --git a/splash_attn_benchmark.py b/splash_attn_benchmark.py new file mode 100644 index 00000000..738f2381 --- /dev/null +++ b/splash_attn_benchmark.py @@ -0,0 +1,395 @@ +import functools +import math +import time +from typing import Optional, Callable, Tuple +import flax.linen as nn +from flax import nnx +import jax +from jax.ad_checkpoint import checkpoint_name +from jax.sharding import PartitionSpec +import jax.numpy as jnp +from jax.experimental import shard_map +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel +from einops import rearrange +from enum import Enum +from jax.sharding import Mesh +from jax.experimental import mesh_utils +from flax.linen import partitioning as nn_partitioning +from padded_flash_attn import make_dense_padded_attention + + +Mesh = jax.sharding.Mesh +AxisNames = tuple[str, ...] +BlockSizes = splash_attention_kernel.BlockSizes + +class Masking(Enum): + FULL = 1 + PADDING = 2 + SEGMENT = 3 + +def _reshape_heads_to_head_dim(tensor): + # takes a tensor of shape [b, h, s, d] and reshapes to [b, s, h * d] + # This is used to transform the output of flash attention back into the format of other attention outputs + b, h, s, d = tensor.shape + tensor = jnp.transpose(tensor, axes=[0, 2, 1, 3]) + reshaped_tensor = jnp.reshape(tensor, (b, -1, h * d)) + return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor")) + +def _unflatten_heads(tensor, heads): + # reshapes from [b, s, h * d] to [b, h, s, d] (input format to flash format) + batch, seq, heads_and_dim_head = tensor.shape + tensor = tensor.reshape(batch, seq, heads, heads_and_dim_head // heads) + # Transpose to ('batch', 'heads', 'length', 'kv') + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + return tensor + + +def _reshape_data_for_flash(tensor, heads): + """ + Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim. + Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of + blocks is divisible by the number of shards. + """ + if tensor.ndim != 4: + tensor = _unflatten_heads(tensor, heads) + return tensor + +def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1): + """ + Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim. + Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of + blocks is divisible by the number of shards. + """ + tensor = _reshape_data_for_flash(tensor, heads) + + # Pad head_dim to 128 if less than that. + kv_size = tensor.shape[-1] + head_dim_pad = 0 + if kv_size < 128: + head_dim_pad = 128 - kv_size + + # Pad seq_len with sharding constraints. + seq_len = tensor.shape[2] + + # 1. 
First, pad seq_len to be a multiple of flash_block_size + rem = seq_len % flash_block_size + if rem != 0: + seq_len_padded_pre = seq_len + (flash_block_size - rem) + else: + seq_len_padded_pre = seq_len + + # 2. Ensure num_blocks is divisible by num_shards + num_blocks = seq_len_padded_pre // flash_block_size + if num_blocks % num_shards != 0: + num_blocks += num_shards - (num_blocks % num_shards) + + final_padded_len = num_blocks * flash_block_size + seq_len_pad = final_padded_len - seq_len + + if kv_size < 128 or seq_len_pad != 0: + npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad)) + tensor = jnp.pad(tensor, npad) + + return tensor, kv_size, seq_len + +NUM_LANES = 128 +def pad_kv_seq_to_lanes(tensor): + seq_len = tensor.shape[2] + if seq_len % NUM_LANES != 0: + seq_len_pad = seq_len + (NUM_LANES - (seq_len % NUM_LANES)) + npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, 0)) + tensor = jnp.pad(tensor, npad) + return tensor, seq_len + + +@functools.partial(jax.jit, static_argnames=("heads", + "mesh", + "axis_names_q", + "axis_names_kv", + "flash_block_sizes", + "dtype", + "attention_kernel", + "mask_type" + )) +def _tpu_flash_attention( + query: jax.Array, + key: jax.Array, + value: jax.Array, + heads: int, + mesh: Mesh, + axis_names_q: AxisNames, + axis_names_kv: AxisNames, + flash_block_sizes: BlockSizes, + dtype: jnp.dtype = jnp.float32, + attention_kernel: str = "flash", + mask_type: Masking = Masking.FULL +) -> jax.Array: + + num_fsdp_shards = mesh.shape["fsdp"] + query = _reshape_data_for_flash(query, heads) + key = _reshape_data_for_flash(key, heads) + value = _reshape_data_for_flash(value, heads) + q_axis_names = nn.logical_to_mesh_axes(axis_names_q) + kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv) + block_sizes = flash_block_sizes + @functools.partial( + shard_map.shard_map, + mesh=mesh, + in_specs=(q_axis_names, kv_axis_names, kv_axis_names), + out_specs=q_axis_names, + check_rep=False, + ) + def wrap_flash_attention(query, key, value): + + query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_sizes.block_q) + key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv) + value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv) + + q_padded_len = query.shape[2] + kv_padded_len = key.shape[2] + jax.debug.print("q_orig_len {q_orig_len}, padded_len: {q_padded_len}, kv_orig_len {kv_orig_len}, padded_len: {kv_padded_len}", + q_orig_len=query_seq_len, + q_padded_len=q_padded_len, + kv_orig_len=key_seq_len, + kv_padded_len=kv_padded_len, + ) + + if mask_type == Masking.FULL and attention_kernel != "dense_padded": + mask = splash_attention_mask.FullMask(_shape=(q_padded_len, kv_padded_len)) + multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) + segment_ids = None + elif mask_type == Masking.PADDING and attention_kernel != "dense_padded": + padding_mask = splash_attention_mask.PaddingMask( + shape=(q_padded_len, kv_padded_len), + q_seq_len=query_seq_len, + kv_seq_len=key_seq_len + ) + multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(padding_mask,) * query.shape[1]) + segment_ids = None + elif mask_type == Masking.SEGMENT and attention_kernel != "dense_padded": + q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0) + q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32) + kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0) + kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32) + segment_ids = 
splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids) + mask = splash_attention_mask.FullMask(_shape=(q_padded_len, kv_padded_len)) + multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) + + # jax.debug.print("Is cross attention: {is_cross_attention}, q_padded_len: {q_padded_len}, kv_padded_len: {kv_padded_len}", is_cross_attention=is_cross_attention, q_padded_len=q_padded_len, kv_padded_len=kv_padded_len) + # make_splash_mha is wrapped around shardmap and seq and head is already + # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1. + + + if attention_kernel == "flash": + splash_kernel = splash_attention_kernel.make_splash_mha( + mask=multi_head_mask, + head_shards=1, # the sizes of the axis is sharding over heads + q_seq_shards=1, # the sizes of the axis is sharding over seq_len + block_sizes=block_sizes, + save_residuals=True if attention_kernel == "ring" else False + ) + vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0), out_axes=0) + if segment_ids: + vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None), out_axes=0) + attention_output = vmapped_splash(query, key, value, segment_ids) + else: + vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0), out_axes=0) + attention_output = vmapped_splash(query, key, value) + elif attention_kernel == "ring": + splash_kernel = splash_attention_kernel.make_splash_mha( + mask=multi_head_mask, + head_shards=1, # the sizes of the axis is sharding over heads + q_seq_shards=1, # the sizes of the axis is sharding over seq_len + block_sizes=block_sizes, + save_residuals=True if attention_kernel == "ring" else False + ) + if num_fsdp_shards > 1: + out, (lse,) = vmapped_splash(query, key, value, segment_ids) + m = lse.astype(jnp.float32) + l = jnp.exp(lse - m) + o = out.astype(jnp.float32) * l[..., None] + + perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)] + + k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm) + v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm) + + def ring_scan_body(carry, _): + m, l, o, k_current, v_current = carry + k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm) + v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm) + + out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids) + + m_chunk = lse_chunk.astype(jnp.float32) + m_old = m + m = jnp.maximum(m_old, m_chunk) + + exp_m_diff = jnp.exp(m_old - m) + exp_m_chunk_diff = jnp.exp(m_chunk - m) + + l = l * exp_m_diff + jnp.exp(lse_chunk - m) + o = o * exp_m_diff[..., None] + o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32) + + # Return the updated state for the next iteration + return (m, l, o, k_next, v_next), None + + initial_carry = (m, l, o, k1, v1) + (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1) + + attention_output = o_final / l_final[..., None] + elif attention_kernel == "dense_padded": + padded_kv_len = key.shape[1] - key_seq_len + dense_padded_attention_kernel = make_dense_padded_attention(block_sizes=block_sizes, kv_padding=padded_kv_len) + vmapped_splash = jax.vmap(dense_padded_attention_kernel, in_axes=(0, 0, 0), out_axes=0) + attention_output, _ = vmapped_splash(query, key, value) + else: + raise ValueError(f"Unknown attention kernel: {attention_kernel}") + + return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) + + devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"] 
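+ # Run the shard_mapped attention, then fold heads back into the hidden dim
+ # ([b, h, s, d] -> [b, s, h * d]) so the output matches the other attention paths.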
+ x = wrap_flash_attention(query, key, value) + x = _reshape_heads_to_head_dim(x) + + return x + +# MESH AXES +DATA = "data" +FSDP = "fsdp" +TENSOR = "tensor" +# LOGICAL AXES +BATCH = "activation_batch" +D_KV = "activation_kv" +ATTN_HEAD = "activation_attn_heads" +ATTN_Q_LENGTH = "activation_attn_q_length" +ATTN_KV_LENGTH = "activation_attn_kv_length" +# LOGICAL AXES mapping to qkv tensor axes +axis_names_q = (BATCH, ATTN_HEAD, ATTN_Q_LENGTH, D_KV) +axis_names_kv = (BATCH, ATTN_HEAD, ATTN_KV_LENGTH, D_KV) + +### LOGICAL AXES TO PHYSICAL AXES MAPPING ### +RING_ATTENTION_AXIS_RULES = [ + [BATCH, DATA], + [ATTN_HEAD, None], + [ATTN_Q_LENGTH, FSDP], + [ATTN_KV_LENGTH, FSDP], + [D_KV, None] + +] + +SEQUENCE_PARALLEL_AXIS_RULES = [ + [BATCH, DATA], + [ATTN_HEAD, None], + [ATTN_Q_LENGTH, FSDP], + [ATTN_KV_LENGTH, None], + [D_KV, None ] +] + +TENSOR_PARALLEL_AXIS_RULES = [ + [BATCH, DATA], + [ATTN_HEAD, TENSOR], + [ATTN_Q_LENGTH, None], + [ATTN_KV_LENGTH, None], + [D_KV, None] +] + + +def main(): + BQ = [3024] + BKV = [2048] + BKV_COMPUTE = [1024] + rng = jax.random.key(1) + query = jax.random.normal(rng,(2, 40, 75600, 128), dtype=jnp.bfloat16) + rng = jax.random.key(2) + key = jax.random.normal(rng,(2, 40, 75600, 128), dtype=jnp.bfloat16) + rng = jax.random.key(3) + value = jax.random.normal(rng,(2, 40, 75600, 128), dtype=jnp.bfloat16) + # query = jnp.ones((2, 4, 3024, 128)) + # key = jnp.ones((2, 4, 2048, 128)) + # value = jnp.ones((2, 4, 2048, 128)) + data=2 + fsdp=1 + tensor=4 + mesh_devices = mesh_utils.create_device_mesh((data, fsdp, tensor), allow_split_physical_axes=True) + mesh = Mesh(mesh_devices, ('data','fsdp','tensor')) + + for bq in BQ: + for bk in BKV: + for bk_compute in BKV_COMPUTE: + block_sizes = splash_attention_kernel.BlockSizes( + block_q=bq, + block_kv_compute=bk_compute, + block_kv=bk) + print(block_sizes) + for mask in Masking: + for attn in ["dense_padded","flash" ]: + if mask != Masking.FULL and attn == "dense_padded": + print("==========SKIP NON FULL MASK DENSE PADDED ATTN") + continue + with mesh, nn_partitioning.axis_rules(TENSOR_PARALLEL_AXIS_RULES): + print (f"==========CASE {bq} {bk} {bk_compute} mask {mask} attn {attn}==========") + print("==========COMPILE==========") + lhs = _tpu_flash_attention( + query, + key, + value, + heads=40, + mesh=mesh, + axis_names_q=axis_names_q, + axis_names_kv=axis_names_kv, + flash_block_sizes=block_sizes, + dtype=jnp.bfloat16, + attention_kernel=attn, + mask_type=mask + ) + jax.block_until_ready( + lhs + ) + rhs = _tpu_flash_attention( + query, + key, + value, + heads=40, + mesh=mesh, + axis_names_q=axis_names_q, + axis_names_kv=axis_names_kv, + flash_block_sizes=block_sizes, + dtype=jnp.bfloat16, + attention_kernel="flash", + mask_type=Masking.SEGMENT + ) + jax.block_until_ready( + rhs + ) + + allclose = jnp.allclose(lhs, rhs, rtol=1e-3, atol=1e-3) + mean_diff = jnp.mean(jnp.abs(lhs - rhs)) + print(f"==========All close {allclose} mean diff {mean_diff}==========") + start = time.perf_counter() + print("==========PROFILE==========") + jax.block_until_ready( + _tpu_flash_attention( + query, + key, + value, + heads=40, + mesh=mesh, + axis_names_q=axis_names_q, + axis_names_kv=axis_names_kv, + flash_block_sizes=block_sizes, + dtype=jnp.bfloat16, + attention_kernel=attn, + mask_type=mask + ) + ) + end = time.perf_counter() + print("==========RESULT========") + print(f"=========={end - start}s block {bq} {bk} {bk_compute} mask {mask} attn {attn}==========") + print("==========END========") + + +if __name__ == "__main__": + main() \ No 
newline at end of file diff --git a/src/maxdiffusion/common_types.py b/src/maxdiffusion/common_types.py index f03864da..724e2313 100644 --- a/src/maxdiffusion/common_types.py +++ b/src/maxdiffusion/common_types.py @@ -33,7 +33,11 @@ BlockSizes = splash_attention_kernel.BlockSizes AxisNames = tuple[str, ...] - +# Physical axis names for device meshes. +DATA = "data" +FSDP = "fsdp" +TENSOR = "tensor" +# Logical axis names for model parameters and activations. BATCH = "activation_batch" LENGTH = "activation_length" KV_LENGTH = "activation_kv_length" @@ -44,4 +48,32 @@ KEEP_2 = "activation_keep_2" CONV_OUT = "activation_conv_out_channels" +# For setting self/cross attention independently in splash kernel +SELF_ATTN_HEAD = "activation_self_attn_heads" +SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length" +SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length" +CROSS_ATTN_HEAD = "activation_cross_attn_heads" +CROSS_ATTN_Q_LENGTH = "activation_cross_attn_q_length" +CROSS_ATTN_KV_LENGTH = "activation_cross_attn_kv_length" + + WAN_MODEL = "Wan2.1" + +### Common axis rules for ring attention ### +RING_ATTENTION_AXIS_RULES = [ + [SELF_ATTN_HEAD, None], + [SELF_ATTN_Q_LENGTH, FSDP], + [SELF_ATTN_KV_LENGTH, FSDP], + [CROSS_ATTN_HEAD, None], + [CROSS_ATTN_Q_LENGTH, FSDP], + [CROSS_ATTN_KV_LENGTH, FSDP], +] + +SEQUENCE_PARALLEL_AXIS_RULES = [ + [SELF_ATTN_HEAD, None], + [SELF_ATTN_Q_LENGTH, FSDP], + [SELF_ATTN_KV_LENGTH, None], + [CROSS_ATTN_HEAD, None], + [CROSS_ATTN_Q_LENGTH, FSDP], + [CROSS_ATTN_KV_LENGTH, None], +] diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 8149c829..8201c0b8 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -57,18 +57,18 @@ jit_initializers: True from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring -flash_min_seq_length: 4096 +flash_min_seq_length: 0 dropout: 0.1 flash_block_sizes: { - "block_q" : 1024, - "block_kv_compute" : 256, - "block_kv" : 1024, - "block_q_dkv" : 1024, - "block_kv_dkv" : 1024, - "block_kv_dkv_compute" : 256, - "block_q_dq" : 1024, - "block_kv_dq" : 1024 + "block_q" : 3024, + "block_kv_compute" : 1024, + "block_kv" : 2048, + "block_q_dkv" : 3024, + "block_kv_dkv" : 2048, + "block_kv_dkv_compute" : 2048, + "block_q_dq" : 3024, + "block_kv_dq" : 2048 } # Use on v6e # flash_block_sizes: { @@ -77,10 +77,21 @@ flash_block_sizes: { # "block_kv" : 2048, # "block_q_dkv" : 3024, # "block_kv_dkv" : 2048, -# "block_kv_dkv_compute" : 2048, +# "block_kv_dkv_compute" : 1024, # "block_q_dq" : 3024, # "block_kv_dq" : 2048 # } +# Use on v5p +# flash_block_sizes: { +# "block_q" : 3024, +# "block_kv_compute" : 1024, +# "block_kv" : 2048, +# "block_q_dkv" : 1024, +# "block_kv_dkv" : 3072, +# "block_kv_dkv_compute" : 256, +# "block_q_dq" : 1024, +# "block_kv_dq" : 3072 +# } # GroupNorm groups norm_num_groups: 32 @@ -141,8 +152,9 @@ mesh_axes: ['data', 'fsdp', 'tensor'] logical_axis_rules: [ ['batch', 'data'], ['activation_batch', 'data'], + ['activation_self_attn_heads', ['fsdp', 'tensor']], + ['activation_cross_attn_q_length', ['fsdp', 'tensor']], ['activation_length', 'fsdp'], - ['activation_heads', 'tensor'], ['mlp','tensor'], ['embed','fsdp'], @@ -150,6 +162,7 @@ logical_axis_rules: [ ['norm', 'tensor'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], + ['conv_in', 'fsdp'], ['conv_out', 'fsdp'], ] data_sharding: [['data', 'fsdp', 'tensor']] @@ -271,7 +284,7 @@ flow_shift: 3.0 # Based 
on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 30 -fps: 24 +fps: 16 save_final_checkpoint: False # SDXL Lightning parameters diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 451b2829..a85b1ffb 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -62,6 +62,14 @@ def delete_file(file_path: str): jax.config.update("jax_use_shardy_partitioner", True) +jax.config.update("jax_default_prng_impl", "unsafe_rbg") + # TF allocates extraneous GPU memory when using TFDS data + # this leads to CUDA OOMs. WAR for now is to hide GPUs from TF + # tf.config.set_visible_devices([], "GPU") +if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""): + os.environ["LIBTPU_INIT_ARGS"] = ( + os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + ) def inference_generate_video(config, pipeline, filename_prefix=""): @@ -97,7 +105,6 @@ def inference_generate_video(config, pipeline, filename_prefix=""): def run(config, pipeline=None, filename_prefix=""): print("seed: ", config.seed) from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer - checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT") pipeline = checkpoint_loader.load_checkpoint() if pipeline is None: diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 6638e0f8..648e7fdb 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -495,14 +495,14 @@ def get_flash_block_sizes(config): flash_block_sizes = None if len(config.flash_block_sizes.keys()) > 0: flash_block_sizes = splash_attention_kernel.BlockSizes( - block_q=config.flash_block_sizes["block_q"], - block_kv_compute=config.flash_block_sizes["block_kv_compute"], - block_kv=config.flash_block_sizes["block_kv"], - block_q_dkv=config.flash_block_sizes["block_q_dkv"], - block_kv_dkv=config.flash_block_sizes["block_kv_dkv"], - block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"], - block_q_dq=config.flash_block_sizes["block_q_dq"], - block_kv_dq=config.flash_block_sizes["block_kv_dq"], + block_q=int(config.flash_block_sizes["block_q"]), + block_kv_compute=int(config.flash_block_sizes["block_kv_compute"]), + block_kv=int(config.flash_block_sizes["block_kv"]), + block_q_dkv=config.flash_block_sizes.get("block_q_dkv"), + block_kv_dkv=config.flash_block_sizes.get("block_kv_dkv"), + block_kv_dkv_compute=config.flash_block_sizes.get("block_kv_dkv_compute"), + block_q_dq=config.flash_block_sizes.get("block_q_dq"), + block_kv_dq=config.flash_block_sizes.get("block_kv_dq"), ) return flash_block_sizes diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 5df5f334..b4bb5ed5 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -45,6 +45,13 @@ EMBED = common_types.EMBED Quant = quantizations.AqtQuantization +SELF_ATTN_HEAD = common_types.SELF_ATTN_HEAD +SELF_ATTN_Q_LENGTH = common_types.SELF_ATTN_Q_LENGTH +SELF_ATTN_KV_LENGTH = common_types.SELF_ATTN_KV_LENGTH +CROSS_ATTN_HEAD = common_types.CROSS_ATTN_HEAD +CROSS_ATTN_Q_LENGTH = common_types.CROSS_ATTN_Q_LENGTH +CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH + def _maybe_aqt_einsum(quant: Quant): return jnp.einsum if quant is None else quant.einsum() @@ -184,7 +191,8 @@ def _tpu_flash_attention( kv_max_block_size = key.shape[1] else: kv_max_block_size = q_max_block_size - if 
flash_block_sizes: + # ensure that for cross attention we override the block sizes. + if flash_block_sizes and key.shape[1] == query.shape[1]: block_sizes = flash_block_sizes else: block_sizes = splash_attention_kernel.BlockSizes( @@ -439,7 +447,16 @@ def _apply_attention( ) elif attention_kernel == "flash": return _tpu_flash_attention( - query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype + query, + key * scale, + value, + heads, + mesh, + axis_names_q, + axis_names_kv, + flash_block_sizes, + dtype, + attention_kernel, ) elif attention_kernel == "ring": return _tpu_flash_attention( @@ -701,6 +718,7 @@ def __init__( precision: jax.lax.Precision = None, qkv_bias: bool = False, quant: Quant = None, + is_self_attention: bool = True, ): if attention_kernel == "cudnn_flash_te": raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}") @@ -717,6 +735,13 @@ def __init__( self.value_axis_names = value_axis_names self.out_axis_names = out_axis_names + if is_self_attention: + axis_names_q = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV) + axis_names_kv = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_KV_LENGTH, D_KV) + else: + axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV) + axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV) + self.attention_op = NNXAttentionOp( mesh=mesh, attention_kernel=attention_kernel, @@ -726,6 +751,8 @@ def __init__( use_memory_efficient_attention=use_memory_efficient_attention, split_head_dim=split_head_dim, float32_qk_product=False, + axis_names_q=axis_names_q, + axis_names_kv=axis_names_kv, flash_min_seq_length=flash_min_seq_length, flash_block_sizes=flash_block_sizes, dtype=dtype, diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 0226a859..77f35073 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -16,6 +16,7 @@ from typing import Tuple, List, Sequence, Union, Optional +import flax import jax import jax.numpy as jnp from flax import nnx @@ -27,7 +28,7 @@ BlockSizes = common_types.BlockSizes CACHE_T = 2 - +flax.config.update('flax_always_shard_variable', False) # Helper to ensure kernel_size, stride, padding are tuples of 3 integers def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]: diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 48ed7b8e..5652ae89 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -282,6 +282,7 @@ def __init__( precision=precision, attention_kernel=attention, dropout=dropout, + is_self_attention=True, ) # 1. Cross-attention @@ -300,6 +301,7 @@ def __init__( precision=precision, attention_kernel=attention, dropout=dropout, + is_self_attention=False, ) assert cross_attn_norm is True self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True) @@ -338,6 +340,7 @@ def __call__( encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None)) # 1. 
Self-attention + # with jax.named_scope("SelfAttention"): norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype) attn_output = self.attn1( hidden_states=norm_hidden_states, @@ -349,13 +352,18 @@ def __call__( hidden_states = (hidden_states + attn_output * gate_msa).astype(hidden_states.dtype) # 2. Cross-attention + # with jax.named_scope("CrossAttention"): norm_hidden_states = self.norm2(hidden_states) attn_output = self.attn2( - hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs + hidden_states=norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + deterministic=deterministic, + rngs=rngs, ) hidden_states = hidden_states + attn_output # 3. Feed-forward + # with jax.named_scope("FeedForward"): norm_hidden_states = (self.norm3(hidden_states) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype) ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs) hidden_states = (hidden_states + ff_output * c_gate_msa).astype(hidden_states.dtype) @@ -493,8 +501,8 @@ def __call__( hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1)) rotary_emb = self.rope(hidden_states) - - hidden_states = self.patch_embedding(hidden_states) + with jax.named_scope("PatchEmbedding"): + hidden_states = self.patch_embedding(hidden_states) hidden_states = jax.lax.collapse(hidden_states, 1, -1) temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( @@ -515,12 +523,13 @@ def scan_fn(carry, block): scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded ) initial_carry = (hidden_states, rngs) - final_carry, _ = nnx.scan( - rematted_block_forward, - length=self.num_layers, - in_axes=(nnx.Carry, 0), - out_axes=(nnx.Carry, 0), - )(initial_carry, self.blocks) + with jax.named_scope("Transformer"): + final_carry, _ = nnx.scan( + rematted_block_forward, + length=self.num_layers, + in_axes=(nnx.Carry, 0), + out_axes=(nnx.Carry, 0), + )(initial_carry, self.blocks) hidden_states, _ = final_carry diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index c78d8bae..6c01e5e1 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -535,17 +535,18 @@ def __call__( prompt = [prompt] batch_size = len(prompt) - - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - max_sequence_length=max_sequence_length, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) + + with jax.named_scope("Encode-Prompt"): + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + max_sequence_length=max_sequence_length, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) num_channel_latents = self.transformer.config.in_channels - if latents is None: + if latents is None: latents = self.prepare_latents( batch_size=batch_size, vae_scale_factor_temporal=self.vae_scale_factor_temporal, @@ -554,7 +555,7 @@ def __call__( width=width, num_frames=num_frames, num_channels_latents=num_channel_latents, - ) + ) # # fusion.18 data_sharding = NamedSharding(self.mesh, P()) # Using global_batch_size_to_train_on so not to create more config variables diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 3bb5bd13..14e7fcb3 100644 --- 
a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -27,7 +27,7 @@ from . import max_logging from . import max_utils from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH -from maxdiffusion.common_types import LENGTH, KV_LENGTH +from maxdiffusion.common_types import LENGTH, KV_LENGTH, RING_ATTENTION_AXIS_RULES def string_to_bool(s: str) -> bool: @@ -180,14 +180,22 @@ def user_init(raw_keys): raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"]) # Verify qkv is sharded across sequence. if raw_keys["attention"] == "ring": + max_logging.log("Using ring attention, adding sequence sharding to q and kv if not already present.") logical_axis_rules = list(raw_keys["logical_axis_rules"]) + max_logging.log(f"Initial logical axis rules: {logical_axis_rules}") + new_rules = [] q_seq_sharding = (LENGTH, "fsdp") kv_seq_sharding = (KV_LENGTH, "fsdp") if q_seq_sharding not in logical_axis_rules: logical_axis_rules.append(q_seq_sharding) if kv_seq_sharding not in logical_axis_rules: logical_axis_rules.append(kv_seq_sharding) - raw_keys["logical_axis_rules"] = tuple(logical_axis_rules) + for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES: + if ring_attention_axis_rule not in logical_axis_rules: + max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}") + new_rules.append(ring_attention_axis_rule) + raw_keys["logical_axis_rules"] = tuple(new_rules) + tuple(logical_axis_rules) + max_logging.log(f"Final logical axis rules: {raw_keys['logical_axis_rules']}") raw_keys["data_sharding"] = _lists_to_tuples(raw_keys["data_sharding"]) diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 3d1327c3..9602d17c 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -23,7 +23,7 @@ from absl.testing import absltest from flax import nnx from jax.sharding import Mesh - +from flax.linen import partitioning as nn_partitioning from .. 
import pyconfig from ..max_utils import (create_device_mesh, get_flash_block_sizes) from ..models.wan.transformers.transformer_wan import ( @@ -53,6 +53,18 @@ class WanTransformerTest(unittest.TestCase): def setUp(self): WanTransformerTest.dummy_data = {} + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, + ) + config = pyconfig.config + self.config = config + devices_array = create_device_mesh(config) + self.mesh = Mesh(devices_array, config.mesh_axes) + def test_rotary_pos_embed(self): batch_size = 1 @@ -70,18 +82,20 @@ def test_nnx_pixart_alpha_text_projection(self): key = jax.random.key(0) rngs = nnx.Rngs(key) dummy_caption = jnp.ones((1, 512, 4096)) - layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120) - dummy_output = layer(dummy_caption) - dummy_output.shape == (1, 512, 5120) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120) + dummy_output = layer(dummy_caption) + dummy_output.shape == (1, 512, 5120) def test_nnx_timestep_embedding(self): key = jax.random.key(0) rngs = nnx.Rngs(key) dummy_sample = jnp.ones((1, 256)) - layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120) - dummy_output = layer(dummy_sample) - assert dummy_output.shape == (1, 5120) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120) + dummy_output = layer(dummy_sample) + assert dummy_output.shape == (1, 5120) def test_fp32_layer_norm(self): key = jax.random.key(0) @@ -89,9 +103,10 @@ def test_fp32_layer_norm(self): batch_size = 1 dummy_hidden_states = jnp.ones((batch_size, 75600, 5120)) # expected same output shape with same dtype - layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False) - dummy_output = layer(dummy_hidden_states) - assert dummy_output.shape == dummy_hidden_states.shape + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False) + dummy_output = layer(dummy_hidden_states) + assert dummy_output.shape == dummy_hidden_states.shape @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_wan_time_text_embedding(self): @@ -102,20 +117,21 @@ def test_wan_time_text_embedding(self): time_freq_dim = 256 time_proj_dim = 30720 text_embed_dim = 4096 - layer = WanTimeTextImageEmbedding( - rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim - ) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + layer = WanTimeTextImageEmbedding( + rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim + ) - dummy_timestep = jnp.ones(batch_size) + dummy_timestep = jnp.ones(batch_size) - encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim) - dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape) - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer( - dummy_timestep, dummy_encoder_hidden_states - ) - assert temb.shape == (batch_size, dim) - assert timestep_proj.shape == (batch_size, time_proj_dim) - assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim) + encoder_hidden_states_shape = (batch_size, time_freq_dim * 
2, text_embed_dim) + dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape) + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer( + dummy_timestep, dummy_encoder_hidden_states + ) + assert temb.shape == (batch_size, dim) + assert timestep_proj.shape == (batch_size, time_proj_dim) + assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim) def test_wan_block(self): key = jax.random.key(0) @@ -163,20 +179,19 @@ def test_wan_block(self): dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim)) dummy_temb = jnp.ones((batch_size, 6, dim)) - - wan_block = WanTransformerBlock( - rngs=rngs, - dim=dim, - ffn_dim=ffn_dim, - num_heads=num_heads, - qk_norm=qk_norm, - cross_attn_norm=cross_attn_norm, - eps=eps, - attention="flash", - mesh=mesh, - flash_block_sizes=flash_block_sizes, - ) - with mesh: + with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + wan_block = WanTransformerBlock( + rngs=rngs, + dim=dim, + ffn_dim=ffn_dim, + num_heads=num_heads, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + attention="flash", + mesh=mesh, + flash_block_sizes=flash_block_sizes, + ) dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb) assert dummy_output.shape == dummy_hidden_states.shape @@ -209,40 +224,39 @@ def test_wan_attention(self): mesh = Mesh(devices_array, config.mesh_axes) batch_size = 1 query_dim = 5120 - attention = FlaxWanAttention( - rngs=rngs, - query_dim=query_dim, - heads=40, - dim_head=128, - attention_kernel="flash", - mesh=mesh, - flash_block_sizes=flash_block_sizes, - ) - - dummy_hidden_states_shape = (batch_size, 75600, query_dim) - - dummy_hidden_states = jnp.ones(dummy_hidden_states_shape) - dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape) - with mesh: - dummy_output = attention( - hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb - ) - assert dummy_output.shape == dummy_hidden_states_shape - - # dot product - try: + with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): attention = FlaxWanAttention( rngs=rngs, query_dim=query_dim, heads=40, dim_head=128, - attention_kernel="dot_product", - split_head_dim=True, + attention_kernel="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, ) - except NotImplementedError: - pass + dummy_hidden_states_shape = (batch_size, 75600, query_dim) + + dummy_hidden_states = jnp.ones(dummy_hidden_states_shape) + dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape) + dummy_output = attention( + hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb + ) + assert dummy_output.shape == dummy_hidden_states_shape + + # dot product + try: + attention = FlaxWanAttention( + rngs=rngs, + query_dim=query_dim, + heads=40, + dim_head=128, + attention_kernel="dot_product", + split_head_dim=True, + mesh=mesh, + flash_block_sizes=flash_block_sizes, + ) + except NotImplementedError: + pass @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_wan_model(self): @@ -272,7 +286,8 @@ def test_wan_model(self): mesh = Mesh(devices_array, config.mesh_axes) batch_size = 1 num_layers = 1 - wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers) + with nn_partitioning.axis_rules(config.logical_axis_rules): + wan_model = WanModel(rngs=rngs, 
attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers) dummy_timestep = jnp.ones((batch_size)) dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096)) diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index 66d8dce9..56fd0d9c 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -22,6 +22,7 @@ import jax import jax.numpy as jnp from flax import nnx +from flax.linen import partitioning as nn_partitioning from jax.sharding import Mesh from .. import pyconfig from ..max_utils import ( @@ -161,6 +162,17 @@ class WanVaeTest(unittest.TestCase): def setUp(self): WanVaeTest.dummy_data = {} + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, + ) + config = pyconfig.config + self.config = config + devices_array = create_device_mesh(config) + self.mesh = Mesh(devices_array, config.mesh_axes) def test_wanrms_norm(self): """Test against the Pytorch implementation""" @@ -210,12 +222,13 @@ def test_zero_padded_conv(self): output_torch = resample(input) assert output_torch.shape == (1, 96, 240, 360) - model = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2)) - dummy_input = jnp.ones(input_shape) - dummy_input = jnp.transpose(dummy_input, (0, 2, 3, 1)) - output = model(dummy_input) - output = jnp.transpose(output, (0, 3, 1, 2)) - assert output.shape == (1, 96, 240, 360) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + model = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2)) + dummy_input = jnp.ones(input_shape) + dummy_input = jnp.transpose(dummy_input, (0, 2, 3, 1)) + output = model(dummy_input) + output = jnp.transpose(output, (0, 3, 1, 2)) + assert output.shape == (1, 96, 240, 360) def test_wan_upsample(self): batch_size = 1 @@ -247,13 +260,13 @@ def test_wan_resample(self): torch_wan_resample = TorchWanResample(dim=dim, mode=mode) torch_output = torch_wan_resample(dummy_input) assert torch_output.shape == (batch, dim, t, h // 2, w // 2) - - wan_resample = WanResample(dim, mode=mode, rngs=rngs) - # channels is always last here - input_shape = (batch, t, h, w, dim) - dummy_input = jnp.ones(input_shape) - output = wan_resample(dummy_input) - assert output.shape == (batch, t, h // 2, w // 2, dim) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + wan_resample = WanResample(dim, mode=mode, rngs=rngs) + # channels is always last here + input_shape = (batch, t, h, w, dim) + dummy_input = jnp.ones(input_shape) + output = wan_resample(dummy_input) + assert output.shape == (batch, t, h // 2, w // 2, dim) def test_3d_conv(self): key = jax.random.key(0) @@ -284,28 +297,29 @@ def test_3d_conv(self): dummy_cache = jnp.zeros((batch_size, cache_depth, in_height, in_width, in_channels)) # Instantiate the module - causal_conv_layer = WanCausalConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=(kernel_d, kernel_h, kernel_w), - padding=(padding_d, padding_h, padding_w), - rngs=rngs, # Pass rngs for initialization, - mesh=mesh, - ) + with self.mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + causal_conv_layer = WanCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(kernel_d, kernel_h, kernel_w), + padding=(padding_d, padding_h, padding_w), + rngs=rngs, # Pass rngs for initialization, + mesh=mesh, + ) - # --- Test Case 1: No Cache --- - 
output_no_cache = causal_conv_layer(dummy_input) - assert output_no_cache.shape == (1, 10, 32, 32, 16) + # --- Test Case 1: No Cache --- + output_no_cache = causal_conv_layer(dummy_input) + assert output_no_cache.shape == (1, 10, 32, 32, 16) - # --- Test Case 2: With Cache --- - output_with_cache = causal_conv_layer(dummy_input, cache_x=dummy_cache) - assert output_with_cache.shape == (1, 10, 32, 32, 16) + # --- Test Case 2: With Cache --- + output_with_cache = causal_conv_layer(dummy_input, cache_x=dummy_cache) + assert output_with_cache.shape == (1, 10, 32, 32, 16) - # --- Test Case 3: With Cache larger than padding --- - larger_cache_depth = 4 # Larger than needed padding (2*padding_d = 2) - dummy_larger_cache = jnp.zeros((batch_size, larger_cache_depth, in_height, in_width, in_channels)) - output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache) - assert output_with_larger_cache.shape == (1, 10, 32, 32, 16) + # --- Test Case 3: With Cache larger than padding --- + larger_cache_depth = 4 # Larger than needed padding (2*padding_d = 2) + dummy_larger_cache = jnp.zeros((batch_size, larger_cache_depth, in_height, in_width, in_channels)) + output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache) + assert output_with_larger_cache.shape == (1, 10, 32, 32, 16) def test_wan_residual(self): key = jax.random.key(0) @@ -329,21 +343,20 @@ def test_wan_residual(self): dim = 96 input_shape = (batch, t, height, width, dim) expected_output_shape = (batch, t, height, width, dim) - - wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh) - dummy_input = jnp.ones(input_shape) - dummy_output = wan_residual_block(dummy_input) - assert dummy_output.shape == expected_output_shape - - # --- Test Case 1: different in/out dim --- - in_dim = 96 - out_dim = 196 - expected_output_shape = (batch, t, height, width, out_dim) - - wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh) - dummy_input = jnp.ones(input_shape) - dummy_output = wan_residual_block(dummy_input) - assert dummy_output.shape == expected_output_shape + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh) + dummy_input = jnp.ones(input_shape) + dummy_output = wan_residual_block(dummy_input) + assert dummy_output.shape == expected_output_shape + # --- Test Case 1: different in/out dim --- + in_dim = 96 + out_dim = 196 + expected_output_shape = (batch, t, height, width, out_dim) + + wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh) + dummy_input = jnp.ones(input_shape) + dummy_output = wan_residual_block(dummy_input) + assert dummy_output.shape == expected_output_shape def test_wan_attention(self): key = jax.random.key(0) @@ -354,10 +367,11 @@ def test_wan_attention(self): height = 60 width = 90 input_shape = (batch, t, height, width, dim) - wan_attention = WanAttentionBlock(dim=dim, rngs=rngs) - dummy_input = jnp.ones(input_shape) - output = wan_attention(dummy_input) - assert output.shape == input_shape + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + wan_attention = WanAttentionBlock(dim=dim, rngs=rngs) + dummy_input = jnp.ones(input_shape) + output = wan_attention(dummy_input) + assert output.shape == input_shape def test_wan_midblock(self): key = jax.random.key(0) @@ -378,10 +392,11 @@ def test_wan_midblock(self): height = 60 width = 90 
input_shape = (batch, t, height, width, dim) - wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=mesh) - dummy_input = jnp.ones(input_shape) - output = wan_midblock(dummy_input) - assert output.shape == input_shape + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=mesh) + dummy_input = jnp.ones(input_shape) + output = wan_midblock(dummy_input) + assert output.shape == input_shape def test_wan_decode(self): key = jax.random.key(0) @@ -402,30 +417,31 @@ def test_wan_decode(self): num_res_blocks = 2 attn_scales = [] temperal_downsample = [False, True, True] - wan_vae = AutoencoderKLWan( - rngs=rngs, - base_dim=dim, - z_dim=z_dim, - dim_mult=dim_mult, - num_res_blocks=num_res_blocks, - attn_scales=attn_scales, - temperal_downsample=temperal_downsample, - mesh=mesh, - ) - vae_cache = AutoencoderKLWanCache(wan_vae) - batch = 1 - t = 13 - channels = 16 - height = 60 - width = 90 - input_shape = (batch, t, height, width, channels) - input = jnp.ones(input_shape) - - latents_mean = jnp.array(wan_vae.latents_mean).reshape(1, 1, 1, 1, wan_vae.z_dim) - latents_std = 1.0 / jnp.array(wan_vae.latents_std).reshape(1, 1, 1, 1, wan_vae.z_dim) - input = input / latents_std + latents_mean - dummy_output = wan_vae.decode(input, feat_cache=vae_cache) - assert dummy_output.sample.shape == (batch, 49, 480, 720, 3) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + wan_vae = AutoencoderKLWan( + rngs=rngs, + base_dim=dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + mesh=mesh, + ) + vae_cache = AutoencoderKLWanCache(wan_vae) + batch = 1 + t = 13 + channels = 16 + height = 60 + width = 90 + input_shape = (batch, t, height, width, channels) + input = jnp.ones(input_shape) + + latents_mean = jnp.array(wan_vae.latents_mean).reshape(1, 1, 1, 1, wan_vae.z_dim) + latents_std = 1.0 / jnp.array(wan_vae.latents_std).reshape(1, 1, 1, 1, wan_vae.z_dim) + input = input / latents_std + latents_mean + dummy_output = wan_vae.decode(input, feat_cache=vae_cache) + assert dummy_output.sample.shape == (batch, 49, 480, 720, 3) def test_wan_encode(self): key = jax.random.key(0) @@ -446,26 +462,27 @@ def test_wan_encode(self): num_res_blocks = 2 attn_scales = [] temperal_downsample = [False, True, True] - wan_vae = AutoencoderKLWan( - rngs=rngs, - base_dim=dim, - z_dim=z_dim, - dim_mult=dim_mult, - num_res_blocks=num_res_blocks, - attn_scales=attn_scales, - temperal_downsample=temperal_downsample, - mesh=mesh, - ) - vae_cache = AutoencoderKLWanCache(wan_vae) - batch = 1 - channels = 3 - t = 49 - height = 480 - width = 720 - input_shape = (batch, channels, t, height, width) - input = jnp.ones(input_shape) - output = wan_vae.encode(input, feat_cache=vae_cache) - assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + wan_vae = AutoencoderKLWan( + rngs=rngs, + base_dim=dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + mesh=mesh, + ) + vae_cache = AutoencoderKLWanCache(wan_vae) + batch = 1 + channels = 3 + t = 49 + height = 480 + width = 720 + input_shape = (batch, channels, t, height, width) + input = jnp.ones(input_shape) + output = wan_vae.encode(input, feat_cache=vae_cache) + assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16) def test_load_checkpoint(self): def 
vae_encode(video, wan_vae, vae_cache, key): @@ -485,9 +502,9 @@ def vae_encode(video, wan_vae, vae_cache, key): config = pyconfig.config devices_array = create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - - wan_vae = AutoencoderKLWan.from_config(config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs, mesh=mesh) - vae_cache = AutoencoderKLWanCache(wan_vae) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + wan_vae = AutoencoderKLWan.from_config(config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs, mesh=mesh) + vae_cache = AutoencoderKLWanCache(wan_vae) video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" video = load_video(video_path)