diff --git a/README.md b/README.md index 40a0b4a9..0c47e2d3 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ Our guiding principles when building `torchtitan`: [![Welcome to torchtitan!](assets/images/titan_play_video.png)](https://youtu.be/ee5DOEqD35I?si=_B94PbVv0V5ZnNKE "Welcome to torchtitan!") -### Our torchtitan paper on arXiv +### torchtitan paper on arXiv [![arXiv](https://img.shields.io/badge/arXiv-2410.06511-b31b1b.svg?style=plastic)](https://arxiv.org/abs/2410.06511) @@ -61,7 +61,7 @@ You may want to see how the model is defined or how parallelism techniques are a 7. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) and support for [custom datasets](docs/datasets.md) 8. Learning rate scheduler, meta-init, (optional) fused RMSNorm kernel 9. Loss, GPU memory, throughput (tokens/sec), and MFU displayed and logged via [Tensorboard or Weights & Biases](/docs/metrics.md) -10. Debugging tools including CPU/GPU profiling, [memory profiling](docs/memory_profiler.md), [Flight Recorder](#debugging), etc. +10. [Debugging tools](docs/debugging.md) including CPU/GPU profiling, memory profiling, Flight Recorder, etc. 11. All options easily configured via [toml files](train_configs/) We report our [Performance](docs/performance.md) verified on 64/128 GPUs. @@ -121,14 +121,6 @@ If your gpu count per node is not 8, adjust: in the SBATCH command section. - -## Debugging -### Troubleshooting jobs that timeout -If you encounter jobs that timeout, you'll need to debug them to identify the root cause. To help with this process, we've enabled Flight Recorder, a tool that continuously collects diagnostic information about your jobs. -When a job times out, Flight Recorder automatically generates dump files on every rank containing valuable debugging data. You can find these dump files in the `job.dump_folder` directory. -To learn how to analyze and diagnose issues using these logs, follow our step-by-step tutorial [link](https://pytorch.org/tutorials/prototype/flight_recorder_tutorial.html). - - ## License This code is made available under [BSD 3 license](./LICENSE). However you may have other legal obligations that govern your use of other content, such as the terms of service for third-party models, data, etc. diff --git a/docs/checkpoint.md b/docs/checkpoint.md index 72e6a021..a4f2ecc8 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -14,7 +14,7 @@ python3 scripts/convert_llama_to_dcp.py This guide will walk you through the steps required to convert a checkpoint from torchtitan so that it can be loaded into torchtune. -## Steps +### Steps 1. ENABLE CHECKPOINTING In your torchtitan training config, ensure that `enable_checkpoint` is set to True. ``` diff --git a/docs/converging.md b/docs/converging.md index 57a8450b..9a5abf58 100644 --- a/docs/converging.md +++ b/docs/converging.md @@ -1,6 +1,6 @@ This note clarifies the recommended practices to follow when testing the loss converging of a new feature. -#### Disclaimers +### Disclaimers 1. We assume the vanilla 1D FSDP to be “correct”, and would serve as the baseline for comparisons. The correctness of FSDP can be verified by comparing with DDP on small models, which has been widely adopted and believed to be correct. 2. The focus is on the correctness of new distributed training techniques. For a new model size / architecture, the demonstration of loss-converging is not in the scope of this note. @@ -54,7 +54,7 @@ Remarks | 2D (MN GPUs)
e.g. M=8 | FSDP N, CP M | to verify CP with a larger degree | -#### Test results +### Test results (TBA) [^1]: Model initialization in a sharded setting can hardly match that in a single-device setting (or a differently sharded setting), because each time a random operator is called, the underlying RNG state offset is advanced by a quantized amount, often not aligned with the amount of randomness needed, thus “wasting” different amount of randomness on differently sharded settings. diff --git a/docs/datasets.md b/docs/datasets.md index e13da2dd..06739353 100644 --- a/docs/datasets.md +++ b/docs/datasets.md @@ -1,6 +1,6 @@ -# Custom Datasets in TorchTitan +# Custom Datasets in torchtitan -TorchTitan is designed to work seamlessly with most HuggingFace datasets. While we provide the C4 dataset for numerics and convergence testing, you can easily add support for your own datasets. Here's how to do it using Wikipedia as an example. +`torchtitan` is designed to work seamlessly with most HuggingFace datasets. While we provide the C4 dataset for numerics and convergence testing, you can easily add support for your own datasets. Here's how to do it using Wikipedia as an example. ## Quick Start Locate the dataset configuration file: @@ -60,7 +60,7 @@ In your training configuration file (`.toml`), set your dataset: dataset = "wikipedia" ``` -That's it! Your custom dataset is now ready to use with TorchTitan. +That's it! Your custom dataset is now ready to use with `torchtitan`. ## Key Points - The DatasetConfig contains all necessary components for a dataset: diff --git a/docs/memory_profiler.md b/docs/debugging.md similarity index 61% rename from docs/memory_profiler.md rename to docs/debugging.md index d73ecaf9..4170c2af 100644 --- a/docs/memory_profiler.md +++ b/docs/debugging.md @@ -11,3 +11,10 @@ CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --profiling. You cab find the saved pickle files in your output folder. To visualize a snapshot file, you can drag and drop it to . To learn more details on memory profiling, please visit this [tutorial](https://pytorch.org/blog/understanding-gpu-memory-1/). + + +## Troubleshooting jobs that timeout + +If you encounter jobs that timeout, you'll need to debug them to identify the root cause. To help with this process, we've enabled Flight Recorder, a tool that continuously collects diagnostic information about your jobs. +When a job times out, Flight Recorder automatically generates dump files on every rank containing valuable debugging data. You can find these dump files in the `job.dump_folder` directory. +To learn how to analyze and diagnose issues using these logs, follow our step-by-step tutorial [link](https://pytorch.org/tutorials/prototype/flight_recorder_tutorial.html). diff --git a/docs/metrics.md b/docs/metrics.md index 568e3a87..9c46ebb5 100644 --- a/docs/metrics.md +++ b/docs/metrics.md @@ -1,5 +1,3 @@ -# Metrics - We support automatically collecting metrics such as 1. High level system metrics such as MFU, average loss, max loss and words per second along with some 2. Memory metrics to measure max VRAM consumption and the number of OOMs diff --git a/torchtitan/metrics.py b/torchtitan/metrics.py index 6409dcd6..a42c887d 100644 --- a/torchtitan/metrics.py +++ b/torchtitan/metrics.py @@ -89,7 +89,7 @@ def reset_peak_stats(self): def build_device_memory_monitor(): device_memory_monitor = DeviceMemoryMonitor(device_type) logger.info( - f"{device_type.upper()} capacity: {device_memory_monitor.device_name}" + f"{device_type.upper()} capacity: {device_memory_monitor.device_name} " f"with {device_memory_monitor.device_capacity_gib:.2f}GiB memory" ) return device_memory_monitor diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index 7419f32d..9af771a2 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -34,18 +34,12 @@ def _validate(self): ) for d in (dp_replicate, cp, tp, pp): assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard" - assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." - - dp = dp_replicate * dp_shard - if dp < 0: - dp = self.world_size // (cp * tp * pp) - self.dp_shard = dp_shard = dp // dp_replicate - assert dp_replicate >= 1 + assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." + if dp_shard < 0: + self.dp_shard = dp_shard = self.world_size // (dp_replicate * cp * tp * pp) assert dp_shard >= 1 - assert cp >= 1, cp - assert tp >= 1, tp - assert pp >= 1, pp + assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, ( f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * " f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" diff --git a/torchtitan/profiling.py b/torchtitan/profiling.py index efb0d25a..f85ae2b1 100644 --- a/torchtitan/profiling.py +++ b/torchtitan/profiling.py @@ -39,11 +39,11 @@ def trace_handler(prof): if not os.path.exists(curr_trace_dir): os.makedirs(curr_trace_dir, exist_ok=True) - logger.info(f"Dumping traces at step {prof.step_num}") + logger.info(f"Dumping profiler traces at step {prof.step_num}") begin = time.monotonic() prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json") logger.info( - f"Finished dumping traces in {time.monotonic() - begin:.2f} seconds" + f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds" ) logger.info(f"Profiling active. Traces will be saved at {trace_dir}")