
Commit

06_12
RissyRan committed Jun 12, 2024
1 parent 5c1d1bb commit 120ec39
Showing 4 changed files with 79 additions and 25 deletions.
MaxText/configs/base.yml (8 changes: 4 additions & 4 deletions)
@@ -240,7 +240,7 @@ adam_eps_root: 0. # A small constant applied to denominator inside the square ro
adam_weight_decay: 0.1 # AdamW Weight decay

# Stack trace parameters
- collect_stack_trace: False
+ collect_stack_trace: True
stack_trace_to_cloud: False # Uploads to cloud logging if True, else to the console if False.
stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds.
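For readers unfamiliar with these settings: collect_stack_trace turns on periodic stack-trace dumps, stack_trace_to_cloud chooses cloud logging versus the console, and stack_trace_interval_seconds sets the period. A minimal, self-contained sketch of that behavior follows; it is purely illustrative and not MaxText's actual collector (the cloud-upload path is omitted).

import faulthandler
import sys
import threading

def start_stack_trace_collection(interval_seconds=600):
    """Toy periodic stack-trace dump; MaxText's real collector may differ."""
    def _dump():
        # stack_trace_to_cloud: False -> write to the console (stderr) rather than cloud logging.
        faulthandler.dump_traceback(file=sys.stderr, all_threads=True)
        threading.Timer(interval_seconds, _dump).start()
    _dump()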

@@ -304,6 +304,6 @@ enable_checkpoint_standard_logger: False
# Single-controller
enable_single_controller: False

- tile_size_0: 4096
- tile_size_1: 256
- tile_size_2: 256
+ tile_size_0: 512
+ tile_size_1: 512
+ tile_size_2: 512
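The tile_size_* values are not consumed anywhere in this diff; as background on what matmul tile sizes mean, here is a minimal, self-contained blocked-matmul sketch in plain NumPy. It is illustrative only and is not MaxText's megablox kernel.

import numpy as np

def blocked_matmul(a, b, tile_m=512, tile_k=512, tile_n=512):
    """Toy blocked matmul: accumulate a @ b one (tile_m, tile_k) x (tile_k, tile_n) block at a time."""
    m, k = a.shape
    _, n = b.shape
    out = np.zeros((m, n), dtype=np.result_type(a, b))
    for i in range(0, m, tile_m):
        for j in range(0, n, tile_n):
            for p in range(0, k, tile_k):
                # NumPy slicing handles ragged edge tiles automatically.
                out[i:i + tile_m, j:j + tile_n] += a[i:i + tile_m, p:p + tile_k] @ b[p:p + tile_k, j:j + tile_n]
    return out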
MaxText/tests/moe_test.py (6 changes: 3 additions & 3 deletions)
@@ -106,7 +106,7 @@ def get_expected_output(rng, hidden_states, cfg):

# print("get_expected_output variables", variables)
# breakpoint()
- time.simple_timeit(jax.jit(model.apply), variables, hidden_states, tries=10, task="loop")
+ # time.simple_timeit(jax.jit(model.apply), variables, hidden_states, tries=10, task="loop")

output = jax.jit(model.apply)(variables, hidden_states)
return variables, output
@@ -191,7 +191,7 @@ def get_moe_output(variables, hidden_states, cfg, mesh):
cfg.base_emb_dim)))
moe_variables = jax.device_put(moe_variables, device=fsdp_sharding)
# breakpoint()
- jax.debug.visualize_array_sharding(moe_variables['params']['gate']['kernel'].value)
+ # jax.debug.visualize_array_sharding(moe_variables['params']['gate']['kernel'].value)

time.simple_timeit(jax.jit(model.apply), moe_variables, hidden_states, tries=10, task="matmul")
output = jax.jit(model.apply)(moe_variables, hidden_states)
@@ -214,7 +214,7 @@ def setUp(self):
moe_matmul=True,
megablox=True,
ici_fsdp_parallelism=4,
- per_device_batch_size=4,
+ per_device_batch_size=16,
dataset_type='synthetic',
attention='flash',
max_target_length=4096,
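The time.simple_timeit helper used in moe_test.py above is not part of this diff; as a rough sketch of what such a micro-benchmark typically does (an assumption about its behavior, not the actual MaxText helper):

import datetime
import jax

def simple_timeit(f, *args, tries=10, task=None):
    """Average wall-clock time of f(*args) over `tries` runs, excluding a warm-up run."""
    jax.block_until_ready(f(*args))  # warm-up so jit compilation time is not measured
    elapsed_ms = []
    for _ in range(tries):
        start = datetime.datetime.now()
        jax.block_until_ready(f(*args))
        elapsed_ms.append((datetime.datetime.now() - start).total_seconds() * 1000)
    avg_ms = sum(elapsed_ms) / len(elapsed_ms)
    print(f"{task}: {avg_ms:.3f} ms on average over {tries} tries")
    return avg_ms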
MaxText/tests/train_smoke_test_moe.py (54 changes: 54 additions & 0 deletions)
@@ -0,0 +1,54 @@
"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

""" Smoke test """
import os
import unittest
from train import main as train_main
from absl.testing import absltest


class Train(unittest.TestCase):
"""Smoke test G3 only"""

def test_tiny_config(self):
test_tmpdir = os.environ.get("TEST_TMPDIR")
train_main([
None,
"third_party/py/maxtext/configs/base.yml",
f"base_output_directory=gs://runner-maxtext-logs",
"run_name=runner_test",
r"dataset_path=gs://maxtext-dataset",
r"tokenizer_path=gs://ranran-multipod-dev/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral",
"per_device_batch_size=8",
"max_target_length=4096",
"dataset_type=synthetic",
"skip_first_n_steps_for_profiler=5",
"steps=10",
"dtype=bfloat16",
"weight_dtype=bfloat16",
"enable_checkpointing=False",
"model_name=mixtral-test",
"ici_fsdp_parallelism=4",
"moe_matmul=True",
"megablox=True",
"attention=flash",
"profiler=xplane",
])


if __name__ == "__main__":
absltest.main()
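The argv convention above (a placeholder for argv[0], then the config path, then key=value overrides) can presumably be reused for quick manual runs; a minimal sketch under that assumption, with a hypothetical local output directory in place of the GCS bucket and illustrative override values:

from train import main as train_main

train_main([
    None,  # argv[0] slot, ignored by the config parser
    "third_party/py/maxtext/configs/base.yml",
    "base_output_directory=/tmp/maxtext-logs",  # hypothetical local path
    "run_name=moe_adhoc_check",
    "dataset_type=synthetic",
    "model_name=mixtral-test",
    "enable_checkpointing=False",
    "steps=5",
])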
MaxText/train.py (36 changes: 18 additions & 18 deletions)
@@ -120,11 +120,11 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step
raise ValueError(f"When writing metrics, {_buffered_step=} was none")
write_metrics_to_tensorboard(writer, _buffered_metrics, _buffered_step, config)

- if config.metrics_file:
-   max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)
+ # if config.metrics_file:
+ #   max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)

- if config.gcs_metrics and jax.process_index() == 0:
-   running_gcs_metrics = max_utils.write_metrics_for_gcs(_buffered_metrics, _buffered_step, config, running_gcs_metrics)
+ # if config.gcs_metrics and jax.process_index() == 0:
+ #   running_gcs_metrics = max_utils.write_metrics_for_gcs(_buffered_metrics, _buffered_step, config, running_gcs_metrics)

_buffered_step = step
_buffered_metrics = metrics
@@ -133,11 +133,11 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step
def write_metrics_to_tensorboard(writer, metrics, step, config):
"""Writes metrics to tensorboard"""
with jax.spmd_mode("allow_all"):
- if jax.process_index() == 0:
-   for metric_name in metrics.get("scalar", []):
-     writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
-   for metric_name in metrics.get("scalars", []):
-     writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)
+ # if jax.process_index() == 0:
+ #   for metric_name in metrics.get("scalar", []):
+ #     writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
+ #   for metric_name in metrics.get("scalars", []):
+ #     writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)

full_log = step % config.log_period == 0

@@ -147,9 +147,9 @@ def write_metrics_to_tensorboard(writer, metrics, step, config):
f"loss: {metrics['scalar']['learning/loss']:.3f}"
)

- if full_log and jax.process_index() == 0:
-   max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")
-   writer.flush()
+ # if full_log and jax.process_index() == 0:
+ #   max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")
+ #   writer.flush()


def save_checkpoint(checkpoint_manager, step, state, dataset_type="c4", data_iterator=None):
@@ -339,7 +339,7 @@ def setup_mesh_and_model(config):
"""

init_rng = random.PRNGKey(config.init_weights_seed)
- writer = max_utils.initialize_summary_writer(config)
+ # writer = max_utils.initialize_summary_writer(config)
logger = checkpointing.setup_checkpoint_logger(config)
checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
config.checkpoint_dir,
@@ -358,7 +358,7 @@
model = Transformer(config, mesh, quant=quant)
learning_rate_schedule = max_utils.create_learning_rate_schedule(config)
tx = optimizers.get_optimizer(config, learning_rate_schedule)
- return init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx
+ return init_rng, None, checkpoint_manager, mesh, model, learning_rate_schedule, tx


def setup_train_loop(config):
@@ -451,9 +451,9 @@ def train_loop(config, state=None):
per_device_tflops, _, _ = maxtext_utils.calculate_tflops_training_per_device(config)

# Write train config params, num model params, and XLA flags to tensorboard
- max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), writer)
- max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], writer)
- max_utils.add_config_to_summary_writer(config, writer)
+ # max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), writer)
+ # max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], writer)
+ # max_utils.add_config_to_summary_writer(config, writer)

# Define the compilation of functional_train, either by loading the compiled version or wrapping a new one in a jit
if config.compiled_trainstep_file != "":
@@ -543,7 +543,7 @@ def train_loop(config, state=None):
if checkpoint_manager is not None:
checkpoint_manager.wait_until_finished()
write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, config.steps - 1, config) # final step metrics
- max_utils.close_summary_writer(writer)
+ # max_utils.close_summary_writer(writer)
record_goodput(recorder, config, job_end=True)
return state

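For context on the summary-writer calls disabled throughout train.py above, the usual TensorBoard write pattern looks roughly like the sketch below. It uses a tensorboardX-style SummaryWriter; whether max_utils.initialize_summary_writer wraps exactly this API is an assumption on my part.

import jax
import numpy as np
from tensorboardX import SummaryWriter  # assumption: a tensorboardX-style writer

def write_scalars(writer, metrics, step):
    """Write scalar metrics from the first host only, mirroring the jax.process_index() == 0 guard."""
    if jax.process_index() == 0:
        for name, value in metrics.get("scalar", {}).items():
            writer.add_scalar(name, np.array(value), step)
        writer.flush()

# Usage sketch:
# writer = SummaryWriter(logdir="/tmp/tensorboard")
# write_scalars(writer, {"scalar": {"learning/loss": 2.31}}, step=0)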
