W&B wandb support #699

Merged: 19 commits, Dec 3, 2024
Changes from 3 commits
1 change: 1 addition & 0 deletions .ci/docker/requirements.txt
@@ -6,3 +6,4 @@ sentencepiece
tiktoken
blobfile
tabulate
wandb
10 changes: 5 additions & 5 deletions README.md
@@ -76,7 +76,7 @@ We report our [Performance](docs/performance.md) verified on 64/128 GPUs.
git clone https://github.com/pytorch/torchtitan
cd torchtitan
pip install -r requirements.txt
-pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118
+pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall # or cu118
```

### Downloading a tokenizer
@@ -88,11 +88,11 @@ Once you have confirmed access, you can run the following command to download the
```bash
# Get your HF token from https://huggingface.co/settings/tokens

-# Llama 3 or 3.1 tokenizer.model
-python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --tokenizer_path "original" --hf_token=...
+# Llama 3.1 tokenizer.model
+python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3.1-8B --tokenizer_path "original" --hf_token=...

-# Llama 2 tokenizer.model
-python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Llama-2-13b-hf --hf_token=...
+# Llama 3.2 tokenizer.model
+python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Llama-3.2-3B --hf_token=...
Contributor:

This was indeed for Llama 2, not for Llama 3.2.
I think we can remove the Llama 2 files if they are no longer helpful.

Member Author:

I can send a separate PR deprecating Llama 2.

Contributor:

Sounds good. Can we revert this change for now, as torchtitan doesn't support Llama 3.2 atm?


```

### Start a training run
30 changes: 30 additions & 0 deletions torchtitan/config_manager.py
@@ -148,6 +148,25 @@ def __init__(self):
""",
)

# WandB configs
self.parser.add_argument(
"--metrics.enable_wandb",
action="store_true",
help="Whether to log metrics to Weights & Biases",
)
self.parser.add_argument(
"--metrics.wandb_config.project",
type=str,
default="torchtitan",
help="Project name for WandB logging",
)
self.parser.add_argument(
"--metrics.wandb_config.entity",
type=str,
default=None,
help="Team/entity name for WandB logging",
)

# model configs
self.parser.add_argument(
"--model.name",
@@ -625,6 +644,17 @@ def _validate_config(self) -> None:
assert self.model.flavor
assert self.model.tokenizer_path

# Logging backend validations
if hasattr(self.metrics, "enable_tensorboard") and hasattr(
self.metrics, "enable_wandb"
):
if self.metrics.enable_tensorboard and self.metrics.enable_wandb:
logger.warning(
"Both TensorBoard and WandB logging were enabled. Using WandB only."
)
# Modify the config to disable TensorBoard
self.metrics.enable_tensorboard = False

Contributor:

I have a slightly different opinion on this. I think we should let the user control whether they want to use either or both. In rare cases, they may find enabling both useful (e.g., monitoring on wandb while still keeping the TB logs).

Member Author:

Alright, that's easy to fix, and it's my preference as well; tagging @fegin for visibility.

Contributor (@fegin, Dec 2, 2024):

Hmm, unless they show different information that is useful, I'm not sure why we need to enable both; for performance this would be bad (though probably not noticeable). But I don't have a strong opinion if people prefer to enable both.

def parse_args_from_command_line(
self, args_list
) -> Tuple[argparse.Namespace, argparse.Namespace]:
90 changes: 68 additions & 22 deletions torchtitan/metrics.py
@@ -16,6 +16,13 @@
from torchtitan.parallelisms import ParallelDims
from torchtitan.utils import device_module, device_type

# Optional wandb import
try:
import wandb

WANDB_AVAILABLE = True
except ImportError:
WANDB_AVAILABLE = False

# named tuple for passing device memory stats for logging
DeviceMemStats = namedtuple(
@@ -93,26 +100,55 @@ def build_device_memory_monitor():
f"{device_type.upper()} capacity: {device_memory_monitor.device_name} ({device_memory_monitor.device_index}) "
f"with {device_memory_monitor.device_capacity_gib:.2f}GiB memory"
)

return device_memory_monitor


class MetricLogger:
def __init__(self, log_dir, tag, enable_tb):
def __init__(self, log_dir, tag, enable_tb, enable_wandb=False, wandb_config=None):
self.tag = tag
self.writer: Optional[SummaryWriter] = None
if enable_tb:
self.use_wandb = False

if enable_wandb and WANDB_AVAILABLE:
self.use_wandb = True
if wandb.run is None:
project_name = (
wandb_config.get("project", "torchtitan")
if wandb_config
else "torchtitan"
)
wandb.init(
project=project_name,
config=wandb_config,
dir=log_dir,
)
logger.debug("WandB logging enabled")
elif enable_tb:
self.writer = SummaryWriter(log_dir, max_queue=1000)
logger.info(f"TensorBoard logging enabled. Logs will be saved at {log_dir}")
else:
logger.warning("Neither TensorBoard nor WandB logging is enabled.")

def log(self, metrics: Dict[str, Any], step: int):
if self.writer is not None:
"""Log metrics to the configured backend."""
if self.use_wandb:
wandb_metrics = {
(k if self.tag is None else f"{self.tag}/{k}"): v
for k, v in metrics.items()
}
wandb_metrics["step"] = step
wandb.log(wandb_metrics)
elif self.writer is not None:
for k, v in metrics.items():
tag = k if self.tag is None else f"{self.tag}/{k}"
self.writer.add_scalar(tag, v, step)

def close(self):
"""Clean up logging resources."""
if self.writer is not None:
self.writer.close()
if self.use_wandb and wandb.run is not None:
wandb.finish()


def _get_metrics_rank(parallel_dims: ParallelDims) -> int:
@@ -126,35 +162,45 @@ def _get_metrics_rank(parallel_dims: ParallelDims) -> int:
metrics_log_rank = (world_size // pp_size) * (pp_size - 1)
else:
metrics_log_rank = 0

return metrics_log_rank


def build_metric_logger(
job_config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None
):
"""
Args:
job_config: Configuration object containing metrics settings.
parallel_dims: Used to determine the rank to log metrics from if 'tb_config.rank_0_only=True'.
tag: Optional tag to prefix all metrics.

Returns:
MetricLogger instance configured based on the provided settings

parallel_dims is used to determine the rank to log metrics from if 'tb_config.rank_0_only=True'.
In that case, `_get_metrics_rank` will be used to calculate which rank acts as 'rank 0'. This is
intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline
parallelism is enabled, without forcing logging from all ranks to capture loss information.
"""
dump_dir = job_config.job.dump_folder
tb_config = job_config.metrics
save_tb_folder = tb_config.save_tb_folder
# since we don't have run id, use current minute as the identifier
datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)

enable_tb = tb_config.enable_tensorboard
if enable_tb:
logger.info(
f"Metrics logging active. Tensorboard logs will be saved at {log_dir}"
)
if tb_config.rank_0_only:
enable_tb = torch.distributed.get_rank() == _get_metrics_rank(parallel_dims)
else:
rank_str = f"rank_{torch.distributed.get_rank()}"
log_dir = os.path.join(log_dir, rank_str)
metrics_config = job_config.metrics
log_dir = os.path.join(
dump_dir, metrics_config.save_tb_folder, datetime.now().strftime("%Y%m%d-%H%M")
)

enable_tb = metrics_config.enable_tensorboard
enable_wandb = metrics_config.enable_wandb
wandb_config = (
metrics_config.wandb_config if hasattr(metrics_config, "wandb_config") else None
)

if metrics_config.rank_0_only:
metrics_rank = _get_metrics_rank(parallel_dims)
is_metrics_rank = torch.distributed.get_rank() == metrics_rank
enable_tb = enable_tb and is_metrics_rank
enable_wandb = enable_wandb and is_metrics_rank
else:
rank_str = f"rank_{torch.distributed.get_rank()}"
log_dir = os.path.join(log_dir, rank_str)

return MetricLogger(log_dir, tag, enable_tb)
return MetricLogger(log_dir, tag, enable_tb, enable_wandb, wandb_config)
Contributor:

From an extensibility perspective, it would be nicer to have this pattern:

    if not should_log:
        return DummyLogger()
    if enable_wandb:
        return WandbLogger(wandb_config)
    if enable_tb:
        return TensorBoardLogger(log_dir, tag)
    raise NotImplementedError

With

class DummyLogger:
    def log(self, *args, **kwargs):
        pass

    def close(self):
        pass


class TensorBoardLogger:
    def __init__(self, log_dir, tag):
        self.tag = tag
        self.writer = SummaryWriter(log_dir, max_queue=1000)

    def log(self, metrics: Dict[str, Any], step: int):
        for k, v in metrics.items():
            tag = k if self.tag is None else f"{self.tag}/{k}"
            self.writer.add_scalar(tag, v, step)

    def close(self):
        self.writer.close()

Since each class has its own separate logic, it's easier for forks to have their custom loggers with reduced conflicts. It also avoids the problem of optional types inside the classes (avoids assertions).

You could then import wandb inside the WandbLogger class to avoid importing it during startup
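
Such a deferred-import `WandbLogger` might look like the following sketch. This is an assumption about the eventual shape, not the PR's final code; `prefix_metrics` is a hypothetical helper factored out to mirror the tag-prefixing already done in this PR's `MetricLogger.log`:

```python
from typing import Any, Dict, Optional


def prefix_metrics(metrics: Dict[str, Any], tag: Optional[str]) -> Dict[str, Any]:
    """Prefix each metric key with the tag, matching the TensorBoard path behavior."""
    return {(k if tag is None else f"{tag}/{k}"): v for k, v in metrics.items()}


class WandbLogger:
    def __init__(self, wandb_config: Optional[Dict[str, Any]] = None, tag: Optional[str] = None):
        # Deferred import: wandb is only paid for when this logger is constructed.
        import wandb

        self._wandb = wandb
        self.tag = tag
        if wandb.run is None:
            wandb.init(
                project=(wandb_config or {}).get("project", "torchtitan"),
                config=wandb_config,
            )

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        payload = prefix_metrics(metrics, self.tag)
        payload["step"] = step
        self._wandb.log(payload)

    def close(self) -> None:
        if self._wandb.run is not None:
            self._wandb.finish()
```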

Member Author (@msaroufim, Dec 2, 2024):

I like the way this code looks, but it goes back to the issue we're discussing above: whether we should support multiple loggers when people enable both wandb and TensorBoard.

Right now, enabling both would implicitly use whichever backend matches the topmost condition.

Also, out of curiosity: are you deliberately not using inheritance here?

Contributor:

My bad, I did not look at the existing comments extensively.

For that case, you (or a fork) could add a logger class that takes a list of loggers and simply iterates over them.

    loggers = []
    if should_log:
        if enable_wandb:
            loggers.append(WandbLogger(wandb_config))
        if enable_tb:
            loggers.append(TensorBoardLogger(log_dir, tag))
    return Loggers(loggers)

This removes the need for DummyLogger, iterating over an empty list does nothing.

However, this only works well if the loggers agree to expose the same interface under the hood; my opinion is that this repo should stay simple and not overthink this. In my experience, users don't enable two loggers at the same time.
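
The composite described above could be sketched as follows, under the assumption that every backend exposes the same `log`/`close` interface; the `_Recorder` stub here is hypothetical, used only to demonstrate the fan-out:

```python
from typing import Any, Dict, List


class Loggers:
    """Fan out log/close calls to any number of backend loggers."""

    def __init__(self, loggers: List[Any]):
        self.loggers = loggers

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        for backend in self.loggers:
            backend.log(metrics, step)

    def close(self) -> None:
        for backend in self.loggers:
            backend.close()


class _Recorder:
    """Stand-in backend that records calls instead of writing anywhere."""

    def __init__(self):
        self.calls = []

    def log(self, metrics, step):
        self.calls.append((metrics, step))

    def close(self):
        pass


a, b = _Recorder(), _Recorder()
combined = Loggers([a, b])
combined.log({"loss": 0.5}, step=10)
combined.close()
print(a.calls == b.calls == [({"loss": 0.5}, 10)])  # prints True
```

An empty list also behaves as the no-op case, which is why the reviewer notes it removes the need for a `DummyLogger`.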

Member Author:

I think you've convinced me; two loggers seems clunky. Incorporated your feedback.

2 changes: 2 additions & 0 deletions train_configs/llama3_8b.toml
@@ -14,6 +14,8 @@ profile_freq = 100
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"
enable_wandb = true
wandb_config = { project = "torchtitan", entity = "your-team" }

[model]
name = "llama3"