W&B wandb support #699
Changes from 3 commits
```diff
@@ -6,3 +6,4 @@ sentencepiece
 tiktoken
 blobfile
 tabulate
+wandb
```
```diff
@@ -148,6 +148,25 @@ def __init__(self):
             """,
         )
 
+        # WandB configs
+        self.parser.add_argument(
+            "--metrics.enable_wandb",
+            action="store_true",
+            help="Whether to log metrics to Weights & Biases",
+        )
+        self.parser.add_argument(
+            "--metrics.wandb_config.project",
+            type=str,
+            default="torchtitan",
+            help="Project name for WandB logging",
+        )
+        self.parser.add_argument(
+            "--metrics.wandb_config.entity",
+            type=str,
+            default=None,
+            help="Team/entity name for WandB logging",
+        )
+
         # model configs
         self.parser.add_argument(
             "--model.name",
```
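As an aside (not part of the PR), here is a minimal, self-contained sketch of how the dotted flags added above behave with plain argparse. The flag names, defaults, and help strings come from the diff; the stand-alone parser and the sample values passed to parse_args are hypothetical and only illustrate that argparse keeps the dotted option names intact.

```python
import argparse

# Hypothetical stand-alone parser mirroring the flags added in this hunk;
# torchtitan's real JobConfig wires these into its own config handling.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--metrics.enable_wandb",
    action="store_true",
    help="Whether to log metrics to Weights & Biases",
)
parser.add_argument(
    "--metrics.wandb_config.project",
    type=str,
    default="torchtitan",
    help="Project name for WandB logging",
)
parser.add_argument(
    "--metrics.wandb_config.entity",
    type=str,
    default=None,
    help="Team/entity name for WandB logging",
)

# Example invocation (values are made up).
args = parser.parse_args(
    ["--metrics.enable_wandb", "--metrics.wandb_config.project", "my-experiments"]
)

# argparse stores the dotted option names verbatim, so they are read back with getattr.
print(getattr(args, "metrics.enable_wandb"))          # True
print(getattr(args, "metrics.wandb_config.project"))  # my-experiments
print(getattr(args, "metrics.wandb_config.entity"))   # None
```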
```diff
@@ -625,6 +644,17 @@ def _validate_config(self) -> None:
         assert self.model.flavor
         assert self.model.tokenizer_path
 
+        # Logging backend validations
+        if hasattr(self.metrics, "enable_tensorboard") and hasattr(
+            self.metrics, "enable_wandb"
+        ):
+            if self.metrics.enable_tensorboard and self.metrics.enable_wandb:
+                logger.warning(
+                    "Both TensorBoard and WandB logging were enabled. Using WandB only."
+                )
+                # Modify the config to disable TensorBoard
+                self.metrics.enable_tensorboard = False
+
```
Review thread:

Reviewer: I have a slightly different opinion on this. I think we should let the user control whether they want to use either or both. In rare cases, they may find enabling both useful (e.g. monitoring on wandb, but still keeping the TB logs).

Author: Alright, that's easy to fix and it's my preference as well, but @fegin for visibility as well.

Reviewer: Hmm, unless they are showing different information that can be useful, I'm not sure why we need to enable both, and for performance this is going to be bad (though probably not noticeably). But I don't have a strong opinion if people prefer to enable both.
```diff
 
     def parse_args_from_command_line(
         self, args_list
     ) -> Tuple[argparse.Namespace, argparse.Namespace]:
```
```diff
@@ -16,6 +16,13 @@
 from torchtitan.parallelisms import ParallelDims
 from torchtitan.utils import device_module, device_type
 
+# Optional wandb import
+try:
+    import wandb
+
+    WANDB_AVAILABLE = True
+except ImportError:
+    WANDB_AVAILABLE = False
 
 # named tuple for passing device memory stats for logging
 DeviceMemStats = namedtuple(
```
```diff
@@ -93,26 +100,55 @@ def build_device_memory_monitor():
         f"{device_type.upper} capacity: {device_memory_monitor.device_name} ({device_memory_monitor.device_index}) "
         f"with {device_memory_monitor.device_capacity_gib:.2f}GiB memory"
     )
 
     return device_memory_monitor
 
 
 class MetricLogger:
-    def __init__(self, log_dir, tag, enable_tb):
+    def __init__(self, log_dir, tag, enable_tb, enable_wandb=False, wandb_config=None):
         self.tag = tag
         self.writer: Optional[SummaryWriter] = None
-        if enable_tb:
+        self.use_wandb = False
+
+        if enable_wandb and WANDB_AVAILABLE:
+            self.use_wandb = True
+            if wandb.run is None:
+                project_name = (
+                    wandb_config.get("project", "torchtitan")
+                    if wandb_config
+                    else "torchtitan"
+                )
+
+                wandb.init(
+                    project=project_name,
+                    config=wandb_config,
+                    dir=log_dir,
+                )
+            logger.debug("WandB logging enabled")
+        elif enable_tb:
             self.writer = SummaryWriter(log_dir, max_queue=1000)
+            logger.info(f"TensorBoard logging enabled. Logs will be saved at {log_dir}")
+        else:
+            logger.warning("Neither TensorBoard nor WandB logging is enabled.")
 
     def log(self, metrics: Dict[str, Any], step: int):
-        if self.writer is not None:
+        """Log metrics to the configured backend."""
+        if self.use_wandb:
+            wandb_metrics = {
+                (k if self.tag is None else f"{self.tag}/{k}"): v
+                for k, v in metrics.items()
+            }
+            wandb_metrics["step"] = step
+            wandb.log(wandb_metrics)
+        elif self.writer is not None:
             for k, v in metrics.items():
                 tag = k if self.tag is None else f"{self.tag}/{k}"
                 self.writer.add_scalar(tag, v, step)
 
     def close(self):
+        """Clean up logging resources."""
         if self.writer is not None:
             self.writer.close()
+        if self.use_wandb and wandb.run is not None:
+            wandb.finish()
 
 
 def _get_metrics_rank(parallel_dims: ParallelDims) -> int:
```
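For orientation (again, not part of the diff), here is a minimal sketch of driving the MetricLogger above directly, assuming the wandb package is installed (so WANDB_AVAILABLE is True) and that a W&B login or offline mode is already configured. The log directory, project name, and metric values are arbitrary examples.

```python
import os

# Illustrative only: exercises the MetricLogger defined in this PR.
os.makedirs("./outputs/metrics", exist_ok=True)

metric_logger = MetricLogger(
    log_dir="./outputs/metrics",
    tag=None,
    enable_tb=False,
    enable_wandb=True,
    wandb_config={"project": "torchtitan"},
)
metric_logger.log({"loss": 2.31, "lr": 3e-4}, step=10)
metric_logger.close()
```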
```diff
@@ -126,35 +162,45 @@ def _get_metrics_rank(parallel_dims: ParallelDims) -> int:
         metrics_log_rank = (world_size // pp_size) * (pp_size - 1)
     else:
         metrics_log_rank = 0
 
     return metrics_log_rank
 
 
 def build_metric_logger(
     job_config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None
 ):
     """
+    Args:
+        job_config: Configuration object containing metrics settings.
+        parallel_dims: Used to determine the rank to log metrics from if 'tb_config.rank_0_only=True'.
+        tag: Optional tag to prefix all metrics.
+
+    Returns:
+        MetricLogger instance configured based on the provided settings
+
     parallel_dims is used to determine the rank to log metrics from if 'tb_config.rank_0_only=True'.
     In that case, `_get_metrics_rank` will be used to calculate which rank acts as 'rank 0'. This is
     intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline
     parallelism is enabled, without forcing logging from all ranks to capture loss information.
     """
     dump_dir = job_config.job.dump_folder
-    tb_config = job_config.metrics
-    save_tb_folder = tb_config.save_tb_folder
-    # since we don't have run id, use current minute as the identifier
-    datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
-    log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)
-
-    enable_tb = tb_config.enable_tensorboard
-    if enable_tb:
-        logger.info(
-            f"Metrics logging active. Tensorboard logs will be saved at {log_dir}"
-        )
-        if tb_config.rank_0_only:
-            enable_tb = torch.distributed.get_rank() == _get_metrics_rank(parallel_dims)
-        else:
-            rank_str = f"rank_{torch.distributed.get_rank()}"
-            log_dir = os.path.join(log_dir, rank_str)
+    metrics_config = job_config.metrics
+    log_dir = os.path.join(
+        dump_dir, metrics_config.save_tb_folder, datetime.now().strftime("%Y%m%d-%H%M")
+    )
+
+    enable_tb = metrics_config.enable_tensorboard
+    enable_wandb = metrics_config.enable_wandb
+    wandb_config = (
+        metrics_config.wandb_config if hasattr(metrics_config, "wandb_config") else None
+    )
+
+    if metrics_config.rank_0_only:
+        metrics_rank = _get_metrics_rank(parallel_dims)
+        is_metrics_rank = torch.distributed.get_rank() == metrics_rank
+        enable_tb = enable_tb and is_metrics_rank
+        enable_wandb = enable_wandb and is_metrics_rank
+    else:
+        rank_str = f"rank_{torch.distributed.get_rank()}"
+        log_dir = os.path.join(log_dir, rank_str)
 
-    return MetricLogger(log_dir, tag, enable_tb)
+    return MetricLogger(log_dir, tag, enable_tb, enable_wandb, wandb_config)
```
Review thread:

Reviewer: From an extensibility perspective, it would be nicer to have this pattern:

```python
if not should_log:
    return DummyLogger()
if enable_wandb:
    return WandbLogger(wandb_config)
if enable_tb:
    return TensorBoardLogger(log_dir, tag)
raise NotImplementedError
```

With:

```python
class DummyLogger:
    def log(self, *args, **kwargs):
        pass

    def close(self):
        pass


class TensorBoardLogger:
    def __init__(self, log_dir, tag):
        self.tag = tag
        self.writer = SummaryWriter(log_dir, max_queue=1000)

    def log(self, metrics: Dict[str, Any], step: int):
        for k, v in metrics.items():
            tag = k if self.tag is None else f"{self.tag}/{k}"
            self.writer.add_scalar(tag, v, step)

    def close(self):
        self.writer.close()
```

Since each class has its own separate logic, it's easier for forks to have their custom loggers with reduced conflicts. It also avoids the problem of optional types inside the classes (avoids assertions). You could then import […]

Author: I like the way this code looks, but it goes back to the same issue we're discussing above around whether we should have multiple loggers in case people enable both wandb and TensorBoard. Right now, implicitly enabling both would enable whatever is the topmost condition. Also, more out of curiosity: are you deliberately not using inheritance here?

Reviewer: My bad, I did not look at the existing comments extensively. For that case, you (or a fork) could add a logger class that takes a list of loggers and simply iterates over them:

```python
loggers = []
if should_log:
    if enable_wandb:
        loggers.append(WandbLogger(wandb_config))
    if enable_tb:
        loggers.append(TensorBoardLogger(log_dir, tag))
return Loggers(loggers)
```

This removes the need for […]. However, this only works well if the loggers agree to expose the same interface under the hood, but my opinion is that this repo should stay simple and not think too much about this. From my previous experience, users don't enable two loggers at the same time.

Author: I think you've convinced me, two loggers seems clunky - incorporated your feedback.
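For readers following this thread, here is a hedged sketch of the two pieces the snippets above reference but do not define: a WandbLogger with the same log/close interface, and the Loggers fan-out wrapper described in the last suggestion. This is illustrative only, not code from the PR; it reuses only wandb calls that appear elsewhere in the diff (wandb.init, wandb.log, wandb.finish), and all class and parameter names are assumptions.

```python
from typing import Any, Dict, List, Optional

import wandb  # assumed installed when this backend is selected


class WandbLogger:
    """Illustrative W&B backend matching the log/close interface used above."""

    def __init__(self, wandb_config: Optional[Dict[str, Any]] = None, tag: Optional[str] = None):
        self.tag = tag
        if wandb.run is None:
            wandb.init(
                project=(wandb_config or {}).get("project", "torchtitan"),
                config=wandb_config,
            )

    def log(self, metrics: Dict[str, Any], step: int):
        # Prefix keys with the tag, mirroring the TensorBoard path in this PR.
        prefixed = {
            (k if self.tag is None else f"{self.tag}/{k}"): v for k, v in metrics.items()
        }
        prefixed["step"] = step
        wandb.log(prefixed)

    def close(self):
        if wandb.run is not None:
            wandb.finish()


class Loggers:
    """Fan-out wrapper: forwards every call to each configured backend."""

    def __init__(self, loggers: List[Any]):
        self.loggers = loggers

    def log(self, metrics: Dict[str, Any], step: int):
        for lg in self.loggers:
            lg.log(metrics, step)

    def close(self):
        for lg in self.loggers:
            lg.close()
```

With this wrapper, passing an empty list yields a no-op logger, so the should_log=False case is covered without a dedicated dummy class.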
Review thread:

Reviewer: This was indeed for Llama 2, not for Llama 3.2. I think we can remove the Llama 2 files if they are not helpful anymore.

Author: I can send a separate PR deprecating Llama 2.

Reviewer: Sounds good. Can we revert this change for now, as torchtitan doesn't support Llama 3.2 at the moment?

Author: #712