
Commit

06_12
RissyRan committed Jun 13, 2024
1 parent 120ec39 commit 7bf2f4a
Showing 3 changed files with 95 additions and 84 deletions.
2 changes: 1 addition & 1 deletion MaxText/configs/base.yml
@@ -240,7 +240,7 @@ adam_eps_root: 0. # A small constant applied to denominator inside the square root
adam_weight_decay: 0.1 # AdamW Weight decay

# Stack trace parameters
collect_stack_trace: True
collect_stack_trace: False
stack_trace_to_cloud: False # Uploads to cloud logging if True, else to the console if False.
stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds.
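
As an aside on how these three flags fit together, here is a minimal sketch of a periodic stack-trace collector built only on the Python standard library. It illustrates the semantics (collect on/off, cloud vs. console destination, collection interval) but is not MaxText's actual handler; the function name collect_stack_traces is made up for this example.

import logging
import sys
import threading
import traceback

def collect_stack_traces(collect_stack_trace: bool,
                         stack_trace_to_cloud: bool,
                         stack_trace_interval_seconds: float) -> None:
    """Periodically dump every thread's stack, mirroring the flags above."""
    if not collect_stack_trace:
        return

    def _dump():
        for thread_id, frame in sys._current_frames().items():
            trace = "".join(traceback.format_stack(frame))
            if stack_trace_to_cloud:
                # Stand-in for a cloud-logging client; here it just goes to logging.
                logging.info("Stack trace for thread %s:\n%s", thread_id, trace)
            else:
                print(f"Stack trace for thread {thread_id}:\n{trace}")
        timer = threading.Timer(stack_trace_interval_seconds, _dump)
        timer.daemon = True
        timer.start()

    _dump()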

148 changes: 79 additions & 69 deletions MaxText/layers/linears.py
@@ -30,6 +30,7 @@
from jax.ad_checkpoint import checkpoint_name
from jax.experimental import shard_map
import max_logging
from jax.sharding import PartitionSpec

try:
# from jax.experimental.pallas.ops.tpu import megablox as mblx
@@ -340,7 +341,9 @@ def permute(self, inputs, gate_logits, emb_dim):

# reshape inputs (batch, sequence, emb) to 2D
inputs_2d = jnp.reshape(inputs, (-1, emb_dim))
# print('inputs_2d', inputs_2d.shape)
weights, selected_experts = jax.lax.top_k(gate_logits, self.num_experts_per_tok)
# print('gate_logits', gate_logits.shape)
weights = jax.nn.softmax(weights.astype(self.weight_dtype), axis=-1).astype(self.dtype)
flatten_selected_experts = jnp.ravel(selected_experts)
sorted_selected_experts = jnp.argsort(flatten_selected_experts)
@@ -355,34 +358,41 @@ def permute(self, inputs, gate_logits, emb_dim):
def unpermute(self, intermediate, inputs, sorted_selected_experts, weights):
"""Unpermute tokens to original order and combine weights."""

# print("unpermute:...")
# print(f"intermediate: {intermediate.shape}")
# print(f"inputs: {inputs.shape}")
unsort_output = jnp.take(intermediate, indices=jnp.argsort(sorted_selected_experts), axis=0)
flatten_weights = jnp.ravel(weights)
combined_output = jnp.multiply(unsort_output, flatten_weights[:, None])
# print(f"combined_output: {combined_output.shape}")
groups = jnp.reshape(combined_output, (-1, self.num_experts_per_tok, combined_output.shape[1]))
return jnp.sum(groups, axis=1).reshape(inputs.shape).astype(self.dtype)
# print(f"groups: {groups.shape}")
return jnp.sum(groups, axis=1).reshape(-1, self.config.max_target_length, self.config.emb_dim).astype(self.dtype)

def call_gmm(self, inputs, group_sizes, mlp_activation, w0_kernel, w1_kernel, wo_kernel):
def call_gmm(self, inputs, gate_logits, config, w0_kernel, w1_kernel, wo_kernel):
# TODO(ranran): currently megablox works well on single host, and
# will add sharding properly to improve performance.
# kernel_axes = ('exp', 'embed', 'mlp')
# wo_kernel_axes = ('exp', 'mlp', 'embed')

tile_size = (self.config.tile_size_0, self.config.tile_size_1, self.config.tile_size_2)
# tile_size = None
# tile_size = (4096, 128, 128)
# tile_size = (512, 512, 512)
@functools.partial(
shard_map.shard_map,
mesh=self.mesh,
in_specs=(
(nn.logical_to_mesh_axes(('test', None))),
(nn.logical_to_mesh_axes((None, None, None))),
(nn.logical_to_mesh_axes((None,))),
),
out_specs=(nn.logical_to_mesh_axes(('test', None))),
check_rep=False,
)
# tile_size = (self.config.tile_size_0, self.config.tile_size_1, self.config.tile_size_2)
tile_size = None
# @functools.partial(
# shard_map.shard_map,
# mesh=self.mesh,
# in_specs=(
# (nn.logical_to_mesh_axes(('test', None))),
# (nn.logical_to_mesh_axes((None, None, None))),
# (nn.logical_to_mesh_axes((None,))),
# ),
# out_specs=(nn.logical_to_mesh_axes(('test', None))),
# check_rep=False,
# )
def gmm(inputs, kernel, group_sizes):
# print(f"inside")
# print(f"inputs: {inputs.shape}")
# print(f"kernel: {kernel.shape}")
# print(f"group_size: {group_sizes}")
hs_shape = inputs.shape
# pad length is the 1st dimension of tiling size in gmm call
pad_length = tile_size[0] if tile_size else 512
@@ -403,48 +413,43 @@ def gmm(inputs, kernel, group_sizes):
output = output[:hs_shape[0]]
return output

# from jax.sharding import PartitionSpec
# replicated_sharding = jax.sharding.NamedSharding(self.mesh, PartitionSpec(None))
# w0_kernel, w1_kernel, wo_kernel = jax.device_put((w0_kernel, w1_kernel, wo_kernel), device=replicated_sharding)

layer_w0 = gmm(inputs, w0_kernel, group_sizes)
layer_w1 = gmm(inputs, w1_kernel, group_sizes)
layer_act = _convert_to_activation_function(mlp_activation)(layer_w0)
intermediate_layer = jnp.multiply(layer_act, layer_w1)
output = gmm(intermediate_layer, wo_kernel, group_sizes)
return output
@functools.partial(
shard_map.shard_map,
mesh=self.mesh,
in_specs=(
(PartitionSpec('fsdp', None, None),
PartitionSpec('fsdp', None, None),
PartitionSpec(None, None, None),
PartitionSpec(None, None, None),
PartitionSpec(None, None, None),
)),
out_specs=PartitionSpec('fsdp', None, None),
check_rep=False,
)
def inner_fn(x, logits, w0, w1, wo):
x, sorted_selected_experts, weights, group_sizes = self.permute(x,logits,config.emb_dim)
# breakpoint()
layer_w0 = gmm(x, w0, group_sizes)
layer_w1 = gmm(x, w1, group_sizes)
layer_act = _convert_to_activation_function(config.mlp_activations[0])(layer_w0)
intermediate_layer = jnp.multiply(layer_act, layer_w1)
intermediate_output = gmm(intermediate_layer, wo, group_sizes)
# print(f"intermediate_output.shape: {intermediate_output.shape}")
# print(f"x.shape: {x.shape}")
output = self.unpermute(intermediate_output,
x,
sorted_selected_experts,
weights)
# print(f"unpermute: {output.shape}")
return output
# print(f"inner_fn inputs: {inputs.shape}")
return inner_fn(inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel)

# inputs: (batch * selected_exp * sequence, emb_dim) - (262144, 4096)
# w0_kernel: (num_exp, emb_dim, mlp) -> (8, 4096, 14336)
# w1_kernel: (num_exp, emb_dim, mlp)
# o_kernel: (num_exp, mlp, emb_dim) - > (8, 14336, 4096)

# @functools.partial(
# shard_map.shard_map,
# mesh=self.mesh,
# in_specs=(
# (nn.logical_to_mesh_axes(('test', None))),
# (nn.logical_to_mesh_axes((None, None, None))),
# (nn.logical_to_mesh_axes((None, None, None))),
# (nn.logical_to_mesh_axes((None, None, None))),
# (nn.logical_to_mesh_axes((None,))),
# ),
# out_specs=(nn.logical_to_mesh_axes(('test', None))),
# check_rep=False,
# )
# def inner_fn(x, w0, w1, wo, gs):
# tile_size = (4096, 128, 128)
# layer_w0 = gmm(x, w0, gs, tile_size)
# layer_w1 = gmm(x, w1, gs, tile_size)
# layer_act = _convert_to_activation_function(mlp_activation)(layer_w0)
# intermediate_layer = jnp.multiply(layer_act, layer_w1)
# output = gmm(intermediate_layer, wo, gs, (tile_size[0], tile_size[2], tile_size[1]))
# # breakpoint()
# return output

# output = inner_fn(inputs, w0_kernel, w1_kernel, wo_kernel, group_sizes)
# return output

@nn.compact
def __call__(self, inputs):
cfg = self.config
@@ -472,23 +477,28 @@ def __call__(self, inputs):

if cfg.megablox:
max_logging.log("Running MoE megablox implementation.")
sorted_hidden_states, sorted_selected_experts, weights, group_sizes = self.permute(inputs,
gate_logits,
cfg.emb_dim)
from jax.sharding import PartitionSpec
replicated_sharding = jax.sharding.NamedSharding(self.mesh, PartitionSpec(None))
w0_kernel, w1_kernel, wo_kernel = jax.device_put((w0_kernel, w1_kernel, wo_kernel), device=replicated_sharding)
return self.call_gmm(inputs, gate_logits, cfg, w0_kernel, w1_kernel, wo_kernel)
# sorted_hidden_states, sorted_selected_experts, weights, group_sizes = self.permute(inputs,
# gate_logits,
# cfg.emb_dim)
# from jax.sharding import PartitionSpec
# replicated_sharding = jax.sharding.NamedSharding(self.mesh, PartitionSpec(None))
# w0_kernel, w1_kernel, wo_kernel = jax.device_put((w0_kernel, w1_kernel, wo_kernel), device=replicated_sharding)

intermediate_output = self.call_gmm(sorted_hidden_states,
group_sizes,
cfg.mlp_activations[0],
w0_kernel,
w1_kernel,
wo_kernel)
output = self.unpermute(intermediate_output,
inputs,
sorted_selected_experts,
weights)
# print("before")
# print(f"sorted_hidden_states: {sorted_hidden_states.shape}")
# print(f"group_sizes: {group_sizes}")
# print(f"w0_kernel: {w0_kernel.shape}")
# intermediate_output = self.call_gmm(sorted_hidden_states,
# group_sizes,
# cfg.mlp_activations[0],
# w0_kernel,
# w1_kernel,
# wo_kernel)
# output = self.unpermute(intermediate_output,
# inputs,
# sorted_selected_experts,
# weights)
else:
max_logging.log("Running MoE matmul implementation.")
with jax.named_scope("wi_0"):
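
To make the new data flow in linears.py easier to follow, below is a condensed, self-contained sketch of the permute -> grouped matmul -> unpermute routing that call_gmm now wraps in shard_map. It uses plain jnp ops in place of the megablox gmm kernel; route_and_combine and the toy expert_weights tensor are inventions for this sketch, not the module's API.

import jax
import jax.numpy as jnp

def route_and_combine(inputs, gate_logits, expert_weights, num_experts_per_tok):
    """inputs: (batch, seq, emb); gate_logits: (batch, seq, num_experts);
    expert_weights: (num_experts, emb, emb) toy per-expert projections."""
    num_experts, emb, _ = expert_weights.shape
    tokens = jnp.reshape(inputs, (-1, emb))                               # flatten to (T, emb)
    top_w, top_idx = jax.lax.top_k(gate_logits, num_experts_per_tok)      # (batch, seq, k)
    top_w = jax.nn.softmax(top_w, axis=-1)
    flat_idx = jnp.ravel(top_idx)                                         # (T * k,), token-major
    order = jnp.argsort(flat_idx)                                         # group rows by expert id
    replicated = jnp.repeat(tokens, num_experts_per_tok, axis=0)          # one row per (token, expert)
    sorted_rows = jnp.take(replicated, order, axis=0)                     # "permute"
    group_sizes = jnp.bincount(flat_idx, length=num_experts)              # rows per expert, for gmm
    # Stand-in for the grouped matmul: gather each row's expert matrix and
    # apply it row-wise (fine for a toy example, wasteful at real sizes).
    per_row_w = expert_weights[jnp.sort(flat_idx)]
    expert_out = jnp.einsum('td,tde->te', sorted_rows, per_row_w)
    # "unpermute": restore token order, apply router weights, sum the k experts.
    unsorted = jnp.take(expert_out, jnp.argsort(order), axis=0)
    combined = unsorted * jnp.ravel(top_w)[:, None]
    grouped = jnp.reshape(combined, (-1, num_experts_per_tok, emb))
    return jnp.sum(grouped, axis=1).reshape(inputs.shape), group_sizes

# Toy usage: batch 2, sequence 8, emb 16, 4 experts, top-2 routing.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (2, 8, 16))
logits = jax.random.normal(k2, (2, 8, 4))
w = jax.random.normal(k3, (4, 16, 16))
out, sizes = route_and_combine(x, logits, w, num_experts_per_tok=2)
print(out.shape, sizes)  # (2, 8, 16) and the per-expert row counts
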
29 changes: 15 additions & 14 deletions MaxText/tests/moe_test.py
@@ -179,23 +179,23 @@ def get_moe_output(variables, hidden_states, cfg, mesh):
fsdp_sharding = jax.sharding.NamedSharding(mesh, PartitionSpec('fsdp'))
replicated_sharding = jax.sharding.NamedSharding(mesh, PartitionSpec(None))
# moe_variables = jax.device_put(moe_variables, device=fsdp_sharding)
# hidden_states = jax.device_put(hidden_states, device=fsdp_sharding)

hidden_states = nn.with_logical_constraint(
hidden_states, ('activation_batch', 'activation_length', 'activation_embed')
)
hidden_states = jax.device_put(hidden_states, device=fsdp_sharding)

#hidden_states = nn.with_logical_constraint(
# hidden_states, ('activation_batch', 'activation_length', 'activation_embed')
# )
print('hidden states shape', hidden_states.shape)
rng = jax.random.PRNGKey(40)
moe_variables = model.init(rng, jax.random.normal(rng, (int(cfg.per_device_batch_size),
cfg.max_target_length,
cfg.base_emb_dim)))
#moe_variables = model.init(rng, jax.random.normal(rng, (int(cfg.per_device_batch_size) * 4 ,
# cfg.max_target_length,
# cfg.base_emb_dim)))
moe_variables = jax.device_put(moe_variables, device=fsdp_sharding)
# breakpoint()
# jax.debug.visualize_array_sharding(moe_variables['params']['gate']['kernel'].value)

time.simple_timeit(jax.jit(model.apply), moe_variables, hidden_states, tries=10, task="matmul")
output = jax.jit(model.apply)(moe_variables, hidden_states)
# output = model.apply(moe_variables, hidden_states)
# output = jax.jit(model.apply)(moe_variables, hidden_states)
output = model.apply(moe_variables, hidden_states)
return output


@@ -204,6 +204,7 @@ class MoeTest(unittest.TestCase):
def setUp(self):
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

pyconfig.initialize(
[None, 'configs/base.yml'],
run_name='test',
@@ -214,7 +215,7 @@ def setUp(self):
moe_matmul=True,
megablox=True,
ici_fsdp_parallelism=4,
per_device_batch_size=16,
per_device_batch_size=8,
dataset_type='synthetic',
attention='flash',
max_target_length=4096,
@@ -223,7 +224,7 @@ def setUp(self):
self.cfg = pyconfig.config
self.rng = jax.random.PRNGKey(42)

self.hidden_states = jax.random.uniform(self.rng, (int(self.cfg.per_device_batch_size),
self.hidden_states = jax.random.uniform(self.rng, (int(self.cfg.per_device_batch_size) * 4,
self.cfg.max_target_length,
self.cfg.base_emb_dim), dtype=self.cfg.dtype)
# print(f"{self.hidden_states.shape}=")
@@ -235,8 +236,8 @@ def setUp(self):
def test_moe_block(self):
variables, expected_output = get_expected_output(self.rng, self.hidden_states, self.cfg)
actual_output = get_moe_output(variables, self.hidden_states, self.cfg, self.mesh)
# print("expected_output", expected_output)
# print("actual_output", actual_output)
print("expected_output", expected_output.shape)
print("actual_output", actual_output.shape)
# breakpoint()
self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False))
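
For reference, the sharding the updated test relies on can be reproduced in isolation with a few lines of JAX. This is a toy sketch assuming a 4-device host (to match ici_fsdp_parallelism=4) and much smaller shapes than the test's; only Mesh, NamedSharding, PartitionSpec, and device_put from the real setup are used.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over 4 devices and name its axis 'fsdp'.
devices = mesh_utils.create_device_mesh((4,))
mesh = Mesh(devices, axis_names=('fsdp',))

# (global batch, sequence, embed) at toy sizes; the test uses far larger shapes.
hidden_states = jnp.ones((8 * 4, 128, 64))

# PartitionSpec('fsdp') shards the leading batch dimension and replicates the rest,
# which is what placing the array with fsdp_sharding in the test achieves.
fsdp_sharding = NamedSharding(mesh, PartitionSpec('fsdp'))
hidden_states = jax.device_put(hidden_states, fsdp_sharding)

print(hidden_states.sharding)  # each device now holds an 8-example slice of the batch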
