From 9a965e6c048e4c4a2f6988e8d2303d08113385f3 Mon Sep 17 00:00:00 2001 From: ZhiyuLi-goog Date: Wed, 14 Aug 2024 23:58:31 -0700 Subject: [PATCH] update --- MaxText/attentions_test.py | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/MaxText/attentions_test.py b/MaxText/attentions_test.py index fc75aa5b5..bc3a62aae 100644 --- a/MaxText/attentions_test.py +++ b/MaxText/attentions_test.py @@ -46,6 +46,16 @@ def calculate_attention_forward_tflops_per_device(config): return attention_flops +def create_random_global_array(shape, sharding, dtype): + arr = np.random.normal(size=shape) + random_global_array = jax.make_array_from_single_device_arrays( + shape, + sharding, + [jax.device_put(arr[index], d) for d, index in sharding.addressable_devices_indices_map(shape).items()], + ).astype(dtype) + return random_global_array + + def get_train_iter(config, mesh): """Generates an infinite stream of random query, key, value batches.""" query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM) @@ -53,33 +63,28 @@ def get_train_iter(config, mesh): value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM) decoder_segment_ids_axis_names: AxisNames = (KV_BATCH, LENGTH) - rng = jax.random.PRNGKey(0) - while True: - query = jax.random.normal( - rng, + query = create_random_global_array( shape=(config.global_batch_size_to_train_on, config.max_target_length, config.num_query_heads, config.head_dim), + sharding=NamedSharding(mesh, nn.logical_to_mesh_axes(query_axis_names)), dtype=config.dtype, ) - key = jax.random.normal( - rng, + key = create_random_global_array( shape=(config.global_batch_size_to_train_on, config.max_target_length, config.num_kv_heads, config.head_dim), + sharding=NamedSharding(mesh, nn.logical_to_mesh_axes(key_axis_names)), dtype=config.dtype, ) - value = jax.random.normal( - rng, + value = create_random_global_array( shape=(config.global_batch_size_to_train_on, config.max_target_length, config.num_kv_heads, config.head_dim), + sharding=NamedSharding(mesh, nn.logical_to_mesh_axes(value_axis_names)), dtype=config.dtype, ) - decoder_segment_ids = jax.random.randint(rng, (config.global_batch_size_to_train_on, config.max_target_length), 0, config.vocab_size) - - # Shard the data across the mesh - query = jax.device_put(query, NamedSharding(mesh, nn.logical_to_mesh_axes(query_axis_names))) - key = jax.device_put(key, NamedSharding(mesh, nn.logical_to_mesh_axes(key_axis_names))) - value = jax.device_put(value, NamedSharding(mesh, nn.logical_to_mesh_axes(value_axis_names))) - decoder_segment_ids = jax.device_put(decoder_segment_ids, NamedSharding(mesh, nn.logical_to_mesh_axes(decoder_segment_ids_axis_names))) - + decoder_segment_ids = create_random_global_array( + shape=(config.global_batch_size_to_train_on, config.max_target_length), + sharding=NamedSharding(mesh, nn.logical_to_mesh_axes(decoder_segment_ids_axis_names)), + dtype=config.dtype, + ) yield query, key, value, decoder_segment_ids