From 7d9e0e48eed5f0b29cc3905243d64b513af800f2 Mon Sep 17 00:00:00 2001 From: ZhiyuLi-goog Date: Thu, 15 Aug 2024 00:39:28 -0700 Subject: [PATCH] update --- MaxText/attentions_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/MaxText/attentions_test.py b/MaxText/attentions_test.py index 4c8e9fc10..7fda8eca2 100644 --- a/MaxText/attentions_test.py +++ b/MaxText/attentions_test.py @@ -17,6 +17,7 @@ from flax.linen import partitioning as nn_partitioning from layers import quantizations import math +import jax.numpy as jnp Mesh = jax.sharding.Mesh @@ -47,13 +48,12 @@ def calculate_attention_forward_tflops_per_device(config): def create_random_global_array(rng, global_shape, sharding, dtype): - - local_flat_tensor_shape = math.prod(global_shape) // jax.device_count() - local_flat_tensor = jax.random.normal(rng, shape=local_flat_tensor_shape, dtype=dtype) + local_tensor_shape = sharding.shard_shape(global_shape) + local_tensor = jax.random.normal(rng, shape=local_tensor_shape, dtype=jnp.float32) random_global_array = jax.make_array_from_single_device_arrays( global_shape, sharding, - [jax.device_put(local_flat_tensor, d) for d, index in sharding.addressable_devices_indices_map(global_shape).items()], + [jax.device_put(local_tensor, d) for d, index in sharding.addressable_devices_indices_map(global_shape).items()], ).astype(dtype) return random_global_array @@ -89,7 +89,7 @@ def get_train_iter(config, mesh): rng, global_shape=(config.global_batch_size_to_train_on, config.max_target_length), sharding=NamedSharding(mesh, nn.logical_to_mesh_axes(decoder_segment_ids_axis_names)), - dtype=config.dtype, + dtype=jnp.int32, ) yield query, key, value, decoder_segment_ids