Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiyuLi-goog committed Aug 15, 2024
1 parent f5f459d commit 7d9e0e4
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions MaxText/attentions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from flax.linen import partitioning as nn_partitioning
from layers import quantizations
import math
import jax.numpy as jnp


Mesh = jax.sharding.Mesh
Expand Down Expand Up @@ -47,13 +48,12 @@ def calculate_attention_forward_tflops_per_device(config):


def create_random_global_array(rng, global_shape, sharding, dtype):
  """Build a sharded global jax.Array filled with random normal values.

  Each process materializes only the per-device shard (shape given by
  ``sharding.shard_shape``) instead of the full global tensor, then the
  shards are assembled into a single global array.

  Args:
    rng: JAX PRNG key used to draw the random values.
    global_shape: Tuple of ints — the logical (unsharded) array shape.
    sharding: A ``jax.sharding.Sharding`` describing the array layout.
    dtype: Target dtype of the returned array.

  Returns:
    A ``jax.Array`` of shape ``global_shape`` with the requested ``sharding``
    and ``dtype``.
  """
  # Shape of one device's shard under this sharding — avoids allocating the
  # full global tensor on every host.
  local_tensor_shape = sharding.shard_shape(global_shape)
  # Sample in float32 and cast at the end: jax.random.normal requires a
  # floating dtype, while callers may request e.g. int32 (segment ids).
  local_tensor = jax.random.normal(rng, shape=local_tensor_shape, dtype=jnp.float32)
  # NOTE(review): every shard receives identical values since the same rng is
  # reused for each device — acceptable for test data, not for real init.
  random_global_array = jax.make_array_from_single_device_arrays(
      global_shape,
      sharding,
      [
          jax.device_put(local_tensor, d)
          for d in sharding.addressable_devices_indices_map(global_shape)
      ],
  ).astype(dtype)
  return random_global_array

Expand Down Expand Up @@ -89,7 +89,7 @@ def get_train_iter(config, mesh):
rng,
global_shape=(config.global_batch_size_to_train_on, config.max_target_length),
sharding=NamedSharding(mesh, nn.logical_to_mesh_axes(decoder_segment_ids_axis_names)),
dtype=config.dtype,
dtype=jnp.int32,
)
yield query, key, value, decoder_segment_ids

Expand Down

0 comments on commit 7d9e0e4

Please sign in to comment.