
some cleanups on docs and parallel_dims
ghstack-source-id: 5b9f7938775fb81e69ee83b98c6413354016feeb
Pull Request resolved: #729
tianyu-l committed Dec 12, 2024
1 parent 40a0873 commit 0186284
Showing 9 changed files with 22 additions and 31 deletions.
12 changes: 2 additions & 10 deletions README.md
@@ -18,7 +18,7 @@ Our guiding principles when building `torchtitan`:

[![Welcome to torchtitan!](assets/images/titan_play_video.png)](https://youtu.be/ee5DOEqD35I?si=_B94PbVv0V5ZnNKE "Welcome to torchtitan!")

-### Our torchtitan paper on arXiv
+### torchtitan paper on arXiv

[![arXiv](https://img.shields.io/badge/arXiv-2410.06511-b31b1b.svg?style=plastic)](https://arxiv.org/abs/2410.06511)

@@ -61,7 +61,7 @@ You may want to see how the model is defined or how parallelism techniques are a
7. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) and support for [custom datasets](docs/datasets.md)
8. Learning rate scheduler, meta-init, (optional) fused RMSNorm kernel
9. Loss, GPU memory, throughput (tokens/sec), and MFU displayed and logged via [Tensorboard or Weights & Biases](/docs/metrics.md)
-10. Debugging tools including CPU/GPU profiling, [memory profiling](docs/memory_profiler.md), [Flight Recorder](#debugging), etc.
+10. [Debugging tools](docs/debugging.md) including CPU/GPU profiling, memory profiling, Flight Recorder, etc.
11. All options easily configured via [toml files](train_configs/)

We report our [Performance](docs/performance.md) verified on 64/128 GPUs.
@@ -121,14 +121,6 @@ If your gpu count per node is not 8, adjust:

in the SBATCH command section.


-## Debugging
-### Troubleshooting jobs that timeout
-If you encounter jobs that timeout, you'll need to debug them to identify the root cause. To help with this process, we've enabled Flight Recorder, a tool that continuously collects diagnostic information about your jobs.
-When a job times out, Flight Recorder automatically generates dump files on every rank containing valuable debugging data. You can find these dump files in the `job.dump_folder` directory.
-To learn how to analyze and diagnose issues using these logs, follow our step-by-step tutorial [link](https://pytorch.org/tutorials/prototype/flight_recorder_tutorial.html).


## License

This code is made available under [BSD 3 license](./LICENSE). However you may have other legal obligations that govern your use of other content, such as the terms of service for third-party models, data, etc.
2 changes: 1 addition & 1 deletion docs/checkpoint.md
@@ -14,7 +14,7 @@ python3 scripts/convert_llama_to_dcp.py <input_dir> <output_dir>

This guide will walk you through the steps required to convert a checkpoint from torchtitan so that it can be loaded into torchtune.

-## Steps
+### Steps
1. ENABLE CHECKPOINTING
In your torchtitan training config, ensure that `enable_checkpoint` is set to True.
```
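For illustration, a minimal sketch of step 1 above: checking that checkpointing is switched on before launching a run. It assumes the `[checkpoint]` table layout used by the configs under `train_configs/`; the config path is a placeholder.

```python
# Sketch only: verify enable_checkpoint is set in a torchtitan TOML config.
# Assumes a [checkpoint] table with an enable_checkpoint key, as in
# train_configs/*.toml; the path below is a placeholder.
import tomllib  # Python 3.11+

config_path = "train_configs/debug_model.toml"
with open(config_path, "rb") as f:
    config = tomllib.load(f)

checkpoint_cfg = config.get("checkpoint", {})
if not checkpoint_cfg.get("enable_checkpoint", False):
    raise SystemExit(f"Set enable_checkpoint = true under [checkpoint] in {config_path}")
print("Checkpointing is enabled; checkpoints will be saved during training.")
```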
4 changes: 2 additions & 2 deletions docs/converging.md
@@ -1,6 +1,6 @@
This note clarifies the recommended practices to follow when testing the loss converging of a new feature.

-#### Disclaimers
+### Disclaimers
1. We assume the vanilla 1D FSDP to be “correct”, and would serve as the baseline for comparisons. The correctness of FSDP can be verified by comparing with DDP on small models, which has been widely adopted and believed to be correct.

2. The focus is on the correctness of new distributed training techniques. For a new model size / architecture, the demonstration of loss-converging is not in the scope of this note.
@@ -54,7 +54,7 @@ Remarks
| 2D (MN GPUs) <br> e.g. M=8 | FSDP N, CP M | to verify CP with a larger degree |


-#### Test results
+### Test results
(TBA)

[^1]: Model initialization in a sharded setting can hardly match that in a single-device setting (or a differently sharded setting), because each time a random operator is called, the underlying RNG state offset is advanced by a quantized amount, often not aligned with the amount of randomness needed, thus “wasting” different amount of randomness on differently sharded settings.
6 changes: 3 additions & 3 deletions docs/datasets.md
@@ -1,6 +1,6 @@
-# Custom Datasets in TorchTitan
+# Custom Datasets in torchtitan

-TorchTitan is designed to work seamlessly with most HuggingFace datasets. While we provide the C4 dataset for numerics and convergence testing, you can easily add support for your own datasets. Here's how to do it using Wikipedia as an example.
+`torchtitan` is designed to work seamlessly with most HuggingFace datasets. While we provide the C4 dataset for numerics and convergence testing, you can easily add support for your own datasets. Here's how to do it using Wikipedia as an example.

## Quick Start
Locate the dataset configuration file:
@@ -60,7 +60,7 @@ In your training configuration file (`.toml`), set your dataset:
dataset = "wikipedia"
```

-That's it! Your custom dataset is now ready to use with TorchTitan.
+That's it! Your custom dataset is now ready to use with `torchtitan`.

## Key Points
- The DatasetConfig contains all necessary components for a dataset:
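For illustration, a rough sketch of the registration pattern the datasets guide describes. The `DatasetConfig` field names and the `DATASETS` mapping here are assumptions for the example, not the actual torchtitan definitions; the Wikipedia config name is likewise illustrative.

```python
# Sketch of registering a custom HuggingFace dataset following the pattern above.
# Field names (path, loader, text_processor) are assumptions -- check torchtitan's
# dataset module for the real API.
from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class DatasetConfig:
    path: str                               # HuggingFace hub path (or local path)
    loader: Callable[[str], Any]            # loads the raw dataset
    text_processor: Callable[[dict], str]   # extracts the text from one sample

def load_wikipedia(path: str):
    from datasets import load_dataset
    # streaming avoids downloading the full dump up front
    return load_dataset(path, name="20220301.en", split="train", streaming=True)

def process_wikipedia(sample: dict) -> str:
    return sample["text"]

DATASETS = {
    "wikipedia": DatasetConfig(
        path="wikipedia",
        loader=load_wikipedia,
        text_processor=process_wikipedia,
    ),
}
```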
7 changes: 7 additions & 0 deletions docs/memory_profiler.md → docs/debugging.md
@@ -11,3 +11,10 @@ CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --profiling.

You can find the saved pickle files in your output folder.
To visualize a snapshot file, you can drag and drop it to <https://pytorch.org/memory_viz>. To learn more details on memory profiling, please visit this [tutorial](https://pytorch.org/blog/understanding-gpu-memory-1/).


+## Troubleshooting jobs that timeout
+
+If you encounter jobs that timeout, you'll need to debug them to identify the root cause. To help with this process, we've enabled Flight Recorder, a tool that continuously collects diagnostic information about your jobs.
+When a job times out, Flight Recorder automatically generates dump files on every rank containing valuable debugging data. You can find these dump files in the `job.dump_folder` directory.
+To learn how to analyze and diagnose issues using these logs, follow our step-by-step tutorial [link](https://pytorch.org/tutorials/prototype/flight_recorder_tutorial.html).
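For illustration, a small sketch that locates the per-rank Flight Recorder dumps after a timeout; the dump folder and file-name pattern are assumptions to adjust for your run, and the actual analysis is covered by the linked tutorial.

```python
# Sketch: list Flight Recorder dump files produced by a timed-out job.
# The folder comes from job.dump_folder in the training config; the glob pattern
# is an assumption -- adjust it to whatever your run actually wrote.
import glob
import os

dump_folder = "./outputs"  # placeholder for your job.dump_folder
dumps = sorted(glob.glob(os.path.join(dump_folder, "**", "*trace*"), recursive=True))

if not dumps:
    print(f"No Flight Recorder dumps found under {dump_folder}")
for path in dumps:
    # one dump per rank; feed these into the analysis flow from the tutorial
    print(path, os.path.getsize(path), "bytes")
```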
2 changes: 0 additions & 2 deletions docs/metrics.md
@@ -1,5 +1,3 @@
-# Metrics
-
We support automatically collecting metrics such as
1. High level system metrics such as MFU, average loss, max loss and words per second along with some
2. Memory metrics to measure max VRAM consumption and the number of OOMs
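As background on the MFU figure mentioned above, here is a back-of-the-envelope sketch using the common 6 × params FLOPs-per-token approximation; this is an illustration, not necessarily torchtitan's exact accounting, and the example numbers are made up.

```python
# Illustrative MFU estimate: achieved model FLOPs / aggregate peak hardware FLOPs.
# Uses the ~6 * num_params FLOPs-per-token rule of thumb (ignores the attention term).
def estimate_mfu(tokens_per_sec: float, num_params: float,
                 peak_flops_per_gpu: float, num_gpus: int) -> float:
    achieved_flops = 6.0 * num_params * tokens_per_sec
    return achieved_flops / (peak_flops_per_gpu * num_gpus)

# e.g. an 8B-parameter model at ~6500 tokens/sec per GPU on 8 GPUs with ~989 TFLOPS BF16 peak
mfu = estimate_mfu(tokens_per_sec=6500 * 8, num_params=8e9,
                   peak_flops_per_gpu=989e12, num_gpus=8)
print(f"MFU ~= {mfu:.1%}")
```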
2 changes: 1 addition & 1 deletion torchtitan/metrics.py
@@ -89,7 +89,7 @@ def reset_peak_stats(self):
def build_device_memory_monitor():
device_memory_monitor = DeviceMemoryMonitor(device_type)
logger.info(
f"{device_type.upper()} capacity: {device_memory_monitor.device_name}"
f"{device_type.upper()} capacity: {device_memory_monitor.device_name} "
f"with {device_memory_monitor.device_capacity_gib:.2f}GiB memory"
)
return device_memory_monitor
14 changes: 4 additions & 10 deletions torchtitan/parallelisms/parallel_dims.py
@@ -34,18 +34,12 @@ def _validate(self):
)
for d in (dp_replicate, cp, tp, pp):
assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
+assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."

-dp = dp_replicate * dp_shard
-if dp < 0:
-dp = self.world_size // (cp * tp * pp)
-self.dp_shard = dp_shard = dp // dp_replicate
-
-assert dp_replicate >= 1
-assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."
+if dp_shard < 0:
+self.dp_shard = dp_shard = self.world_size // (dp_replicate * cp * tp * pp)
+assert dp_shard >= 1
-assert cp >= 1, cp
-assert tp >= 1, tp
-assert pp >= 1, pp

assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, (
f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
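To make the reworked validation concrete, here is a standalone rendering of the logic added above: `dp_shard = -1` asks torchtitan to infer the FSDP degree from the world size and the other parallel degrees.

```python
# Standalone rendering of the validation logic in the new _validate (per the diff above).
def infer_dp_shard(world_size: int, dp_replicate: int, dp_shard: int,
                   cp: int, tp: int, pp: int) -> int:
    assert dp_shard == -1 or dp_shard >= 1, "dp_shard must be -1 or >= 1"
    if dp_shard < 0:
        # infer the FSDP degree from the remaining parallel dimensions
        dp_shard = world_size // (dp_replicate * cp * tp * pp)
    assert dp_shard >= 1
    assert dp_replicate * dp_shard * cp * tp * pp == world_size, (
        f"dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * cp({cp}) * "
        f"tp({tp}) * pp({pp}) != WORLD_SIZE({world_size})"
    )
    return dp_shard

# e.g. 64 GPUs with HSDP replicate degree 2 and TP 8: dp_shard is inferred as 4
print(infer_dp_shard(world_size=64, dp_replicate=2, dp_shard=-1, cp=1, tp=8, pp=1))  # -> 4
```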
4 changes: 2 additions & 2 deletions torchtitan/profiling.py
@@ -39,11 +39,11 @@ def trace_handler(prof):
if not os.path.exists(curr_trace_dir):
os.makedirs(curr_trace_dir, exist_ok=True)

logger.info(f"Dumping traces at step {prof.step_num}")
logger.info(f"Dumping profiler traces at step {prof.step_num}")
begin = time.monotonic()
prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")
logger.info(
f"Finished dumping traces in {time.monotonic() - begin:.2f} seconds"
f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds"
)

logger.info(f"Profiling active. Traces will be saved at {trace_dir}")
