W&B wandb support #699

Merged
merged 19 commits into from
Dec 3, 2024
Conversation

msaroufim
Member

@msaroufim commented Nov 25, 2024

This PR adds experimental wandb support. I'm not sure this is "landable" considering you all use TensorBoard by default. Personally I vastly prefer wandb because I can share my training runs with a link and don't need to muck around with SSH tunneling, so I'm opening this since I'm using it myself. If there's interest, I can work to land it.

To use this, kick off a training run as usual with `CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh`, but also run `wandb login` first and paste in your token.
[Screenshot: 2024-11-25 at 12:16 PM]

Changes in logs will look like

[Screenshot: 2024-11-25 at 12:27 PM]

Also, only slightly related: the Llama 3 tokenizer is no longer available on Hugging Face, so I added download instructions for 3.1 and 3.2.

Detailed logs:

    [rank0]:2024-11-25 11:33:24,320 - root - INFO - Dumping traces at step 1000
    [rank0]:2024-11-25 11:33:24,576 - root - INFO - Finished dumping traces in 0.26 seconds
    [rank0]:2024-11-25 11:33:24,577 - root - INFO - Sleeping 2 seconds for other ranks to complete
    [rank0]:wandb: Run history: [per-metric sparkline charts omitted]
    [rank0]:wandb: Run summary:
    [rank0]:wandb:  loss_metrics/global_avg_loss 4.53519
    [rank0]:wandb:  loss_metrics/global_max_loss 4.99517
    [rank0]:wandb:  memory/max_active(%) 43.33611
    [rank0]:wandb:  memory/max_active(GiB) 41.17145
    [rank0]:wandb:  memory/max_reserved(%) 52.19301
    [rank0]:wandb:  memory/max_reserved(GiB) 49.58594
    [rank0]:wandb:  memory/num_alloc_retries 0
    [rank0]:wandb:  memory/num_ooms 0
    [rank0]:wandb:  mfu(%) 30.75216
    [rank0]:wandb:  step 1000
    [rank0]:wandb:  time_metrics/data_loading(%) 1.01461
    [rank0]:wandb:  time_metrics/data_loading(s) 0.01583
    [rank0]:wandb:  time_metrics/end_to_end(s) 1.55993
    [rank0]:wandb:  wps 5251.52034
    [rank0]:wandb: 🚀 View run skilled-glitter-1 at: https://wandb.ai/sahancpal-meta/torchtitan/runs/r1zqr75b

@facebook-github-bot added the "CLA Signed" label on Nov 25, 2024
@msaroufim changed the title from "[not for land] W&B wandb support" to "W&B wandb support" on Nov 25, 2024
"W&B logging requested but wandb package is not installed. Continuing without W&B logging."
)
enable_wandb = False
elif enable_wandb:


Just a nit: W&B logs its own init statement, so having one here might be overkill.

Member Author


yeah good point, w&b is actually quite loud with emojis

@nenomigami

I personally use wandb and have integrated it into my code. However, the watch functionality of wandb throws an error when applying FSDP2. Do you have any suggestions or ideas for resolving this issue?

Contributor

@fegin left a comment


I think it's a good idea to add support for wandb as it continues to gain popularity. Would it be possible to make wandb and tensorboard mutually exclusive? It's not worth logging to two places.

@msaroufim
Member Author

@fegin sure thing, I can look into that. Just to be clear on the gap to merge:

  1. Would w&b be an optional dependency or not?
  2. I can make the logging mutually exclusive.
  3. Remove some of the info messages.
  4. I can add a test, but I noticed tensorboard isn't tested either.
  5. Any thoughts on the config name in the toml? Good enough as is?

@fegin
Contributor

fegin commented Nov 27, 2024

> would w&b be an optional dependency or not?

I would suggest it be an optional dependency. While wandb is becoming increasingly popular, most users are still using tensorboard.

> I can make the logging mutually exclusive

Thanks, that would be good.

> I can add a test but noticed tensorboard isn't tested either

I think an e2e run command with a wandb screenshot would be good enough.

> Any thoughts on the config in the toml name, good enough as is?

The name is okay, but please check config_manager.py and add the config into it.
Also, please raise an exception if wandb is enabled but not available. An implicit fallback is not a good idea.
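The fail-fast behavior suggested here (raise instead of silently falling back) can be sketched with a small generic helper. `require_optional_dependency` is a hypothetical name for illustration, not code from this PR:

```python
import importlib


def require_optional_dependency(module_name: str, feature: str):
    """Import an optional dependency, raising a clear error instead of
    silently disabling a feature the user explicitly enabled."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise RuntimeError(
            f"{feature} is enabled in the job config but '{module_name}' is not "
            f"installed. Run `pip install {module_name}` or disable {feature}."
        ) from exc


# e.g. at logger construction time:
# wandb = require_optional_dependency("wandb", "W&B logging")
```

The point of routing the `ImportError` through `raise ... from exc` is that the user sees both the actionable message and the original import failure.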

Contributor

@tianyu-l left a comment


Thanks for helping add wandb!

README.md Outdated
Comment on lines 94 to 95:

    # Llama 3.2 tokenizer.model
    python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Llama-3.2-3B --hf_token=...
Contributor


This was indeed for Llama 2, not for Llama 3.2.
I think we can remove Llama 2 files if they are not helpful anymore.

Member Author


I can send a separate PR deprecating Llama 2.

Contributor


Sounds good. Can we revert this change for now, as torchtitan doesn't support Llama 3.2 atm?


        wandb_metrics["step"] = step
        wandb.log(wandb_metrics)

    def log_memory_stats(self, memory_stats: DeviceMemStats, step: int):
Contributor


this seems to not be called anywhere?

@msaroufim
Member Author

Made the checks mutually exclusive; if both are set, I default to wandb.

➜  torchtitan git:(msaroufim/wandb) ✗ CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh

+ NGPU=7
+ LOG_RANK=0
+ CONFIG_FILE=./train_configs/llama3_8b.toml
+ overrides=
+ '[' 0 -ne 0 ']'
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ torchrun --nproc_per_node=7 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 train.py --job.config_file ./train_configs/llama3_8b.toml
W1202 12:59:06.793000 2761933 site-packages/torch/distributed/run.py:793] 
W1202 12:59:06.793000 2761933 site-packages/torch/distributed/run.py:793] *****************************************
W1202 12:59:06.793000 2761933 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1202 12:59:06.793000 2761933 site-packages/torch/distributed/run.py:793] *****************************************
[rank0]:Both TensorBoard and WandB logging were enabled. Using WandB only.
[rank0]:2024-12-02 12:59:11,652 - root - INFO - Starting job: Llama 3 8B training
[rank0]:2024-12-02 12:59:11,653 - root - INFO - Deterministic training off
[rank0]:2024-12-02 12:59:12,672 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:2024-12-02 12:59:12,674 - root - INFO - <built-in method upper of str object at 0x7fd6af8e70b0> capacity: NVIDIA H100 (0) with 95.00GiB memory
[rank0]:2024-12-02 12:59:12,732 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:2024-12-02 12:59:12,732 - root - INFO - Building 1-D device mesh with ['dp'], [7]
[rank0]:2024-12-02 12:59:12,733 - root - INFO - Building tiktoken tokenizer locally from ./torchtitan/datasets/tokenizer/original/tokenizer.model
[rank0]:2024-12-02 12:59:12,874 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank0]:2024-12-02 12:59:12,875 - root - INFO - Preparing c4 dataset from allenai/c4
[rank0]:2024-12-02 12:59:22,127 - root - INFO - Building llama3 8B with ModelArgs(dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, norm_type='rmsnorm')
[rank0]:2024-12-02 12:59:22,265 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2024-12-02 12:59:22,266 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:2024-12-02 12:59:22,330 - root - INFO - Applied FSDP to the model
[rank0]:NCCL version 2.21.5+cuda12.4
[rank0]:/home/marksaroufim/.conda/envs/titan/lib/python3.10/site-packages/torch/nn/init.py:51: UserWarning: No PYTORCH_KERNEL_CACHE_PATH or HOME environment variable set! This disables kernel caching. (Triggered internally at ../aten/src/ATen/native/cuda/jit_utils.cpp:1426.)
[rank0]:  tensor.erfinv_()
[rank0]:2024-12-02 12:59:30,682 - root - INFO - <built-in method upper of str object at 0x7fd6af8e70b0> memory usage for model: 4.30GiB(4.52%)
[rank0]:wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[rank0]:wandb: Currently logged in as: formalsystem (sahancpal-meta). Use `wandb login --relogin` to force relogin
[rank0]:wandb: WARNING Path ./outputs/tb/20241202-1259/wandb/ wasn't writable, using system temp directory.
[rank0]:wandb: Tracking run with wandb version 0.18.7
[rank0]:wandb: Run data is saved locally in /tmp/wandb/run-20241202_125931-yjcx6og9
[rank0]:wandb: Run `wandb offline` to turn off syncing.
[rank0]:wandb: Syncing run decent-feather-10
[rank0]:wandb: ⭐️ View project at https://wandb.ai/sahancpal-meta/torchtitan
[rank0]:wandb: 🚀 View run at https://wandb.ai/sahancpal-meta/torchtitan/runs/yjcx6og9
[rank0]:2024-12-02 12:59:31,762 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 7, sequence length 8192, total steps 1000 (warmup 200)
[rank0]:2024-12-02 12:59:31,763 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:2024-12-02 12:59:39,887 - root - INFO - step:  1  loss: 12.2380  memory: 43.22GiB(45.49%)  wps: 1,008  mfu: 5.91%
[rank0]:2024-12-02 12:59:39,887 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-12-02 12:59:53,798 - root - INFO - step: 10  loss: 10.7523  memory: 51.77GiB(54.50%)  wps: 5,300  mfu: 31.04%
[rank0]:2024-12-02 13:00:09,305 - root - INFO - step: 20  loss:  9.0168  memory: 51.77GiB(54.50%)  wps: 5,285  mfu: 30.95%

Contributor

@tianyu-l left a comment


Would you please also add a short note on how to use W&B? Similar to
https://github.com/pytorch/torchtitan/blob/main/README.md?plain=1#L63
and
https://github.com/pytorch/torchtitan/blob/main/README.md?plain=1#L103

It might make sense to have a separate .md in docs/ containing both TB and W&B instructions.

Comment on lines 647 to 657:

    # Logging backend validations
    if hasattr(self.metrics, "enable_tensorboard") and hasattr(
        self.metrics, "enable_wandb"
    ):
        if self.metrics.enable_tensorboard and self.metrics.enable_wandb:
            logger.warning(
                "Both TensorBoard and WandB logging were enabled. Using WandB only."
            )
            # Modify the config to disable TensorBoard
            self.metrics.enable_tensorboard = False

Contributor


I have a slightly different opinion on this. I think we should let users control whether they want either or both. In rare cases they may find enabling both useful (e.g., monitoring on wandb but still keeping the TB logs).

Member Author


Alright, that's easy to fix and it's my preference as well, but tagging @fegin for visibility.

Contributor

@fegin commented Dec 2, 2024


Hmm, unless they show different information that is useful, I'm not sure why we need to enable both; for performance this is going to be bad (though probably not noticeable). But I don't have a strong opinion if people prefer enabling both.

Resolved review threads: torchtitan/metrics.py (4), train_configs/llama3_8b.toml

    - return MetricLogger(log_dir, tag, enable_tb)
    + return MetricLogger(log_dir, tag, enable_tb, enable_wandb, wandb_config)
Contributor


From an extensibility perspective, it would be nicer to have this pattern:

    if not should_log:
        return DummyLogger()
    if enable_wandb:
        return WandbLogger(wandb_config)
    if enable_tb:
        return TensorBoardLogger(log_dir, tag)
    raise NotImplementedError

With

from typing import Any, Dict

from torch.utils.tensorboard import SummaryWriter


class DummyLogger:
    """No-op logger used when metric logging is disabled."""

    def log(self, *args, **kwargs):
        pass

    def close(self):
        pass


class TensorBoardLogger:
    def __init__(self, log_dir, tag):
        self.tag = tag
        self.writer = SummaryWriter(log_dir, max_queue=1000)

    def log(self, metrics: Dict[str, Any], step: int):
        for k, v in metrics.items():
            tag = k if self.tag is None else f"{self.tag}/{k}"
            self.writer.add_scalar(tag, v, step)

    def close(self):
        self.writer.close()

Since each class has its own separate logic, it's easier for forks to maintain custom loggers with fewer merge conflicts. It also avoids optional types inside the classes (and the assertions they'd require).

You could then import wandb inside the WandbLogger class to avoid importing it during startup

Member Author

@msaroufim commented Dec 2, 2024


I like the way this code looks, but it goes back to the issue we're discussing above: whether we should allow multiple loggers in case people enable both wandb and tensorboard.

Right now, implicitly enabling both would use whichever is the topmost condition.

Also, more out of curiosity: are you deliberately not using inheritance here?

Contributor


My bad, I did not look at the existing comments extensively.

For that case, you (or a fork) could add a logger class that takes a list of loggers and simply iterates over them.

    loggers = []
    if should_log:
        if enable_wandb:
            loggers.append(WandbLogger(wandb_config))
        if enable_tb:
            loggers.append(TensorBoardLogger(log_dir, tag))
    return Loggers(loggers)

This removes the need for DummyLogger, iterating over an empty list does nothing.

However, this only works well if the loggers agree to expose the same interface, but my opinion is that this repo should stay simple and not overthink this. In my experience, users don't enable two loggers at the same time.
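For completeness, the fan-out wrapper sketched above could look like this. `MultiLogger` and the in-memory `ListLogger` are illustrative names for this discussion, not code from the PR:

```python
from typing import Any, Dict, List


class MultiLogger:
    """Fan a single log/close call out to any number of backend loggers
    that share the same duck-typed log(metrics, step)/close() interface."""

    def __init__(self, loggers: List[Any]):
        self.loggers = loggers  # an empty list makes this a no-op logger

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        for lg in self.loggers:
            lg.log(metrics, step)

    def close(self) -> None:
        for lg in self.loggers:
            lg.close()


class ListLogger:
    """Tiny in-memory backend, handy for testing the fan-out."""

    def __init__(self):
        self.records = []

    def log(self, metrics, step):
        self.records.append((step, dict(metrics)))

    def close(self):
        pass
```

Note that iterating over an empty list already gives no-op behavior, which is what makes the separate `DummyLogger` unnecessary under this design.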

Member Author


I think you've convinced me; two loggers seems clunky. Incorporated your feedback.

Contributor

@carmocca left a comment


The metrics.py structure LGTM now. Thanks!

Resolved review threads: docs/metrics.md, torchtitan/metrics.py (3), README.md
Contributor

@tianyu-l left a comment


Looks awesome! Thank you very much!
Had some final inline comments.

Resolved review threads: torchtitan/metrics.py (3)
msaroufim and others added 4 commits December 2, 2024 23:49
@msaroufim
Member Author

Thanks again for all the feedback. I had a final lint issue, which I've now fixed. Would you be OK if I merge this, or would you rather do so?

@fegin merged commit 6a6f755 into main on Dec 3, 2024
6 checks passed
    action="store_true",
    help="Whether to log metrics to TensorBoard",
    default=True,
Contributor


Was flipping the enable_color_printing default intentional? If so, `action="store_true"` no longer fits: the flag is now true by default, and a store_true action can only ever set it to true, so there is no way to disable it from the command line.
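This pitfall, and one possible fix using `argparse.BooleanOptionalAction` (available since Python 3.9), can be demonstrated in isolation; the flag names below are illustrative, not torchtitan's actual options:

```python
import argparse

parser = argparse.ArgumentParser()

# Broken combination: store_true can only ever write True, and the
# default is already True, so this option is effectively a no-op.
parser.add_argument("--enable-tensorboard", action="store_true", default=True)

# One fix: BooleanOptionalAction auto-generates a --no-* counterpart,
# so the user can pass either --color or --no-color.
parser.add_argument(
    "--color", action=argparse.BooleanOptionalAction, default=True
)

args = parser.parse_args(["--no-color"])
print(args.enable_tensorboard, args.color)  # True False
```

An alternative without the 3.9 dependency is a paired `--no-enable-tensorboard` flag with `action="store_false"` writing to the same dest.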
