
Commit

matching pv

python3 MaxText/train.py MaxText/configs/base.yml run_name=mattdavidow-train-base base_output_directory=gs://maxtext-experiments-multipod dataset_path=gs://max-datasets-rogue steps=5 enable_checkpointing=False global_parameter_scale=1 per_device_batch_size=4 int8_training=True use_dqdg=False fwd_int8=False dlhs_int8=False drhs_int8=False fwd_int8_qk=False dlhs_int8_qk=False drhs_int8_qk=False fwd_int8_pv=True dlhs_int8_pv=True drhs_int8_pv=False fwd_int8_logits=False dlhs_int8_logits=False drhs_int8_logits=False quantize_logits=False aqt_use_dummy_static_bound=False aqt_use_fwd_quant=False aqt_rng_type=jax.uniform
gobbleturk committed Nov 7, 2023
1 parent 49454de commit 480820c
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions MaxText/layers.py
@@ -177,7 +177,11 @@ def dot_product_attention(query: Array,
 aqt_dot_general = aqt.make_dot_general(aqt_cfg)
 context = aqt.Context(key=aqt_rng, train_step=None)
 aqt_dot_general = functools.partial(aqt_dot_general, context=context)
-attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key, _dot_general=aqt_dot_general)
+if cfg.fwd_int8_qk:
+  attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key, _dot_general=aqt_dot_general)
+else:
+  attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
+
 
 # Apply attention bias: masking, dropout, proximity bias, etc.
 if bias is not None:
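The hunk above relies on the `_dot_general` hook of `jnp.einsum` to swap the AQT contraction in or out per call site. The following is a minimal sketch of that mechanism outside MaxText: `lowered_precision_dot_general` is a hypothetical stand-in for the AQT dot_general, `use_custom` plays the role of `cfg.fwd_int8_qk`, and the shapes are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

def lowered_precision_dot_general(lhs, rhs, dimension_numbers, precision=None,
                                  preferred_element_type=None):
  # Hypothetical stand-in for an AQT dot_general: cast the operands down
  # before the contraction, then defer to the stock lax.dot_general.
  return lax.dot_general(lhs.astype(jnp.bfloat16), rhs.astype(jnp.bfloat16),
                         dimension_numbers, precision=precision,
                         preferred_element_type=preferred_element_type)

query = jnp.ones((2, 8, 4, 16))  # [batch, q_length, heads, head_dim]
key = jnp.ones((2, 8, 4, 16))    # [batch, kv_length, heads, head_dim]

use_custom = True  # plays the role of cfg.fwd_int8_qk
if use_custom:
  attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key,
                            _dot_general=lowered_precision_dot_general)
else:
  attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
print(attn_weights.shape)  # (2, 4, 8, 8)
```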
@@ -208,7 +212,11 @@ def dot_product_attention(query: Array,
 aqt_dot_general = aqt.make_dot_general(aqt_cfg)
 context = aqt.Context(key=aqt_rng, train_step=None)
 aqt_dot_general = functools.partial(aqt_dot_general, context=context)
-return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value, _dot_general=aqt_dot_general)
+if cfg.fwd_int8_pv:
+  return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value, _dot_general=aqt_dot_general)
+else:
+  return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
+
 
 
 dynamic_vector_slice_in_dim = jax.vmap(
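This hunk also shows why `functools.partial` precedes the einsum: the AQT dot_general takes an extra `context` argument, and pre-binding it yields a callable with the plain `(lhs, rhs, dimension_numbers, ...)` signature that the `_dot_general` hook expects. A rough sketch with a made-up context-taking dot_general (not the AQT library):

```python
import functools
import jax.numpy as jnp
from jax import lax

def dot_general_with_context(lhs, rhs, dimension_numbers, precision=None,
                             preferred_element_type=None, *, context=None):
  # `context` stands in for the aqt.Context built in the diff (an RNG key and
  # train step); a real AQT dot_general would consume it here.
  del context
  return lax.dot_general(lhs, rhs, dimension_numbers, precision=precision,
                         preferred_element_type=preferred_element_type)

# Bind the extra argument up front, leaving a plain dot_general-shaped callable.
bound_dot_general = functools.partial(dot_general_with_context,
                                      context={'key': 'hypothetical-rng'})

attn_weights = jnp.ones((2, 4, 8, 8))  # [batch, heads, q_length, kv_length]
value = jnp.ones((2, 8, 4, 16))        # [batch, kv_length, heads, head_dim]
out = jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value,
                 _dot_general=bound_dot_general)
print(out.shape)  # (2, 8, 4, 16)
```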
@@ -300,8 +308,11 @@ def __call__(self, inputs: Array) -> Array:
 aqt_cfg = get_aqt_cfg(cfg.fwd_int8, cfg.dlhs_int8, cfg.drhs_int8)
 aqt_dot_general = aqt.make_dot_general(aqt_cfg)
 context = aqt.Context(key=aqt_key, train_step=None)
+if cfg.fwd_int8:
+  return aqt_dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), context=context)
+else:
+  return lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
 
-return aqt_dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), context=context)
 
 def _convert_to_activation_function(
     fn_or_string: Union[str, Callable]) -> Callable:
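In the non-quantized branch added in this hunk, the dimension numbers `((axis, contract_ind), ((), ()))` tell `lax.dot_general` to contract the selected input axes against the kernel's leading axes with no batch dimensions. A small standalone sketch (the shapes and axis choices are illustrative, not MaxText's):

```python
import jax.numpy as jnp
from jax import lax

inputs = jnp.ones((2, 16, 512))  # [batch, length, emb_dim]
kernel = jnp.ones((512, 1024))   # [emb_dim, features_out]

axis = (2,)          # contract the last axis of `inputs`...
contract_ind = (0,)  # ...against the first axis of `kernel`
out = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
print(out.shape)  # (2, 16, 1024)
```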
@@ -727,13 +738,18 @@ def attend(self, query: Array) -> Array:
 if not self.config.int8_training:
   return maxtext_dot(query, jnp.asarray(self.embedding, dtype).T)
 else:
   cfg = self.config
   aqt_cfg = get_aqt_cfg(cfg.fwd_int8_logits, cfg.dlhs_int8_logits, cfg.drhs_int8_logits)
   aqt_dot_general = aqt.make_dot_general(aqt_cfg)
   aqt_key = self.make_rng('aqt')
   context = aqt.Context(key=aqt_key, train_step=None)
   aqt_dot_general = functools.partial(aqt_dot_general, context=context)
   dtype = jnp.float32 if query.dtype==jnp.float32 or self.embedding.dtype==jnp.float32 else jnp.bfloat16
-  return maxtext_dot(jnp.asarray(query, dtype), jnp.asarray(self.embedding, dtype).T, aqt_dot_general)
+  if cfg.fwd_int8_logits:
+    return maxtext_dot(jnp.asarray(query, dtype), jnp.asarray(self.embedding, dtype).T, aqt_dot_general)
+  else:
+    return maxtext_dot(query, jnp.asarray(self.embedding, dtype).T)
+
+
 class RelativePositionBiases(nn.Module):
   """Adds T5-style relative positional embeddings to the attention logits.
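The last hunk gates the weight-tied logits projection in `attend`. A rough sketch of the unquantized path follows, using `jnp.einsum` in place of MaxText's `maxtext_dot` helper and illustrative shapes: the contraction stays in float32 only when either operand is float32, otherwise it drops to bfloat16, mirroring the dtype rule in the diff. Which branch runs in the real code is controlled by `cfg.fwd_int8_logits`.

```python
import jax.numpy as jnp

query = jnp.ones((2, 8, 512), dtype=jnp.bfloat16)      # [batch, length, emb_dim]
embedding = jnp.ones((32000, 512), dtype=jnp.float32)  # [vocab_size, emb_dim]

# Keep the contraction in float32 only if either operand is float32.
dtype = (jnp.float32
         if query.dtype == jnp.float32 or embedding.dtype == jnp.float32
         else jnp.bfloat16)

# Same contraction as maxtext_dot(query, embedding.T): project each token
# representation onto every embedding row to get vocabulary logits.
logits = jnp.einsum('ble,ve->blv',
                    jnp.asarray(query, dtype),
                    jnp.asarray(embedding, dtype))
print(logits.shape, logits.dtype)  # (2, 8, 32000) float32
```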
