Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiyuLi-goog committed Aug 15, 2024
1 parent d643ac0 commit 9a965e6
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions MaxText/attentions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,40 +46,45 @@ def calculate_attention_forward_tflops_per_device(config):
return attention_flops


def create_random_global_array(shape, sharding, dtype):
arr = np.random.normal(size=shape)
random_global_array = jax.make_array_from_single_device_arrays(
shape,
sharding,
[jax.device_put(arr[index], d) for d, index in sharding.addressable_devices_indices_map(shape).items()],
).astype(dtype)
return random_global_array


def get_train_iter(config, mesh):
"""Generates an infinite stream of random query, key, value batches."""
query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM)
key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM)
value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM)
decoder_segment_ids_axis_names: AxisNames = (KV_BATCH, LENGTH)

rng = jax.random.PRNGKey(0)

while True:
query = jax.random.normal(
rng,
query = create_random_global_array(
shape=(config.global_batch_size_to_train_on, config.max_target_length, config.num_query_heads, config.head_dim),
sharding=NamedSharding(mesh, nn.logical_to_mesh_axes(query_axis_names)),
dtype=config.dtype,
)
key = jax.random.normal(
rng,
key = create_random_global_array(
shape=(config.global_batch_size_to_train_on, config.max_target_length, config.num_kv_heads, config.head_dim),
sharding=NamedSharding(mesh, nn.logical_to_mesh_axes(key_axis_names)),
dtype=config.dtype,
)
value = jax.random.normal(
rng,
value = create_random_global_array(
shape=(config.global_batch_size_to_train_on, config.max_target_length, config.num_kv_heads, config.head_dim),
sharding=NamedSharding(mesh, nn.logical_to_mesh_axes(value_axis_names)),
dtype=config.dtype,
)
decoder_segment_ids = jax.random.randint(rng, (config.global_batch_size_to_train_on, config.max_target_length), 0, config.vocab_size)

# Shard the data across the mesh
query = jax.device_put(query, NamedSharding(mesh, nn.logical_to_mesh_axes(query_axis_names)))
key = jax.device_put(key, NamedSharding(mesh, nn.logical_to_mesh_axes(key_axis_names)))
value = jax.device_put(value, NamedSharding(mesh, nn.logical_to_mesh_axes(value_axis_names)))
decoder_segment_ids = jax.device_put(decoder_segment_ids, NamedSharding(mesh, nn.logical_to_mesh_axes(decoder_segment_ids_axis_names)))


decoder_segment_ids = create_random_global_array(
shape=(config.global_batch_size_to_train_on, config.max_target_length),
sharding=NamedSharding(mesh, nn.logical_to_mesh_axes(decoder_segment_ids_axis_names)),
dtype=config.dtype,
)
yield query, key, value, decoder_segment_ids


Expand Down

0 comments on commit 9a965e6

Please sign in to comment.