
[WIP] Refactor profiler #2540


Draft
wants to merge 1 commit into base: main
45 changes: 24 additions & 21 deletions recipes/configs/llama3_1/8B_full_single_device.yaml
@@ -87,25 +87,28 @@ log_peak_memory_stats: True

# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
enabled: False

#Output directory of trace artifacts
output_dir: ${output_dir}/profiling_outputs

#`torch.profiler.ProfilerActivity` types to trace
cpu: True
_component_: torchtune.training._profiler.TorchProfiler
cuda: True

#trace options passed to `torch.profiler.profile`
profile_memory: False
with_stack: False
record_shapes: True
with_flops: False

# `torch.profiler.schedule` options:
# wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
wait_steps: 5
warmup_steps: 3
active_steps: 2
num_cycles: 1
cpu: True
profile_memory: True
# enabled: False

# #Output directory of trace artifacts
# output_dir: ${output_dir}/profiling_outputs

# #`torch.profiler.ProfilerActivity` types to trace
# cpu: True
# cuda: True

# #trace options passed to `torch.profiler.profile`
# profile_memory: False
# with_stack: False
# record_shapes: True
# with_flops: False

# # `torch.profiler.schedule` options:
# # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
# wait_steps: 5
# warmup_steps: 3
# active_steps: 2
# num_cycles: 1
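
With this change the `profiler` section points straight at the new `torchtune.training._profiler.TorchProfiler` component and the old `enabled`/schedule keys are left commented out. A hedged sketch of how the recipe consumes such a section, assuming the `cpu`/`cuda`/`profile_memory` keyword arguments shown in the YAML above are accepted by the WIP component:

from omegaconf import OmegaConf
from torchtune import config

PROFILER_KEY = "profiler"  # same value as torchtune.training.PROFILER_KEY

cfg = OmegaConf.create(
    """
profiler:
  _component_: torchtune.training._profiler.TorchProfiler
  cpu: True
  cuda: True
  profile_memory: True
"""
)

# Instantiate whatever component the config names; the recipe later wraps the
# training loop in `with self.profiler as profiler:`.
profiler = config.instantiate(cfg.get(PROFILER_KEY))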
286 changes: 107 additions & 179 deletions recipes/full_finetune_single_device.py
@@ -22,7 +22,7 @@
from torchtune.data import padded_collate_packed
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY
from torchtune.training import PROFILER_KEY
from torchtune.training.lr_schedulers import get_lr

from tqdm import tqdm
@@ -338,82 +338,14 @@ def setup(self, cfg: DictConfig) -> None:
last_epoch=self.global_step - 1,
)

# Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
# if cfg is missing profiler key or if `cfg.profiler.enabled = False`
self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
# Set up profiler
self.profiler = config.instantiate(cfg.get(PROFILER_KEY))

# Used to ignore labels for loss computation
self.ignore_labels_cache = torch.full(
(cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device
)

def _setup_profiler(
self, cfg_profiler: Optional[DictConfig] = None
) -> Union[torch.profiler.profile, DummyProfiler]:
"""
Parses the `profiler` section of top-level `cfg` and sets up profiler

Args:
cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to
`recipe.main`). Default None.

Returns:
profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods
for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` when the profiler is not
enabled, so the instrumented training loop does not need to change when profiling is disabled.

The profiler config can be provided in configs under the `profiler` key with the following layout:

.. code-block:: yaml
profiler:
enabled: bool

#Output directory of trace artifacts
output_dir: str

#`torch.profiler.ProfilerActivity` types to trace
cpu: bool
cuda: bool

#Trace options
profile_memory: bool
with_stack: bool
record_shapes: bool
with_flops: bool

# `torch.profiler.schedule` options:
# wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
wait_steps: int
warmup_steps: int
active_steps: int
num_cycles: int
"""

# Missing profiler section in config, assume disabled
if cfg_profiler is None:
cfg_profiler = DictConfig({"enabled": False})

# Check that component is included and set correctly
if cfg_profiler.get("_component_", None) is None:
cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler"
else:
assert (
cfg_profiler.get("_component_")
== "torchtune.training.setup_torch_profiler"
), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`"

profiler, profiler_cfg = config.instantiate(cfg_profiler)

log.info(f" Profiler config after instantiation: {profiler_cfg}")

self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
if profiler_cfg["enabled"]:
self.profiler_wait_steps = profiler_cfg["wait_steps"]
self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
self.profiler_active_steps = profiler_cfg["active_steps"]

return profiler
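
The docstring above describes `DummyProfiler` as a nullcontext with no-op `start`, `stop`, and `step` methods, which is what let the training loop call profiler methods unconditionally when profiling was disabled. A minimal sketch of such a stand-in (illustration only, not the torchtune implementation being removed here):

import contextlib

class NoOpProfiler(contextlib.nullcontext):
    # No-op counterparts to torch.profiler.profile.start/stop/step so the
    # instrumented loop does not need to branch on whether profiling is enabled.
    def start(self):
        pass

    def stop(self):
        pass

    def step(self):
        pass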

def _setup_model(
self,
cfg_model: DictConfig,
@@ -667,121 +599,117 @@ def train(self) -> None:
self._optimizer.zero_grad()

# Initialize tokens count and running loss (for grad accumulation)
t0 = time.perf_counter()
running_loss = 0
num_tokens = 0

self._profiler.start()
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
pbar = tqdm(total=self._steps_per_epoch)
self._dataloader.sampler.set_epoch(curr_epoch)
for idx, batch in enumerate(self._dataloader):
# Start tracking CUDA memory for active steps for just the first epoch
if (
curr_epoch == 0
and self.profiler_profile_memory
and idx == self.profiler_wait_steps + self.profiler_warmup_steps
and self._device.type == "cuda"
):
torch.cuda.memory._record_memory_history()
utils.batch_to_device(batch, self._device)

# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens

# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
current_loss = self._loss_step(batch) * current_num_tokens
running_loss += current_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
if not self._optimizer_in_bwd:
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
max_norm=float(self._clip_grad_norm),
)
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Need to fix `lr_scheduler.step()` before `optimizer.step()` warning
if self._lr_scheduler is not None:
self._lr_scheduler.step()
self.global_step += 1

loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
)
with self.profiler as profiler:
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
pbar = tqdm(total=self._steps_per_epoch)
self._dataloader.sampler.set_epoch(curr_epoch)
for idx, batch in enumerate(self._dataloader):
# # Start tracking CUDA memory for active steps for just the first epoch
# if (
# curr_epoch == 0
# and self.profiler_profile_memory
# and idx == self.profiler_wait_steps + self.profiler_warmup_steps
# and self._device.type == "cuda"
# ):
# torch.cuda.memory._record_memory_history()
utils.batch_to_device(batch, self._device)

# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens

# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
current_loss = self._loss_step(batch) * current_num_tokens
running_loss += current_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
if not self._optimizer_in_bwd:
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
max_norm=float(self._clip_grad_norm),
)
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Need to fix `lr_scheduler.step()` before `optimizer.step()` warning
if self._lr_scheduler is not None:
self._lr_scheduler.step()
self.global_step += 1

loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
)

# Log per-step metrics
if self.global_step % self._log_every_n_steps == 0:
time_per_step = time.perf_counter() - t0
log_dict = {
"loss": loss_to_log,
# NOTE: for optim in backward, this assumes all optimizers have the same LR. This is currently
# true since we don't expose the ability to configure this yet.
"lr": get_lr(
(
self._optimizer
if not self._optimizer_in_bwd
else self._optim_ckpt_wrapper
# Log per-step metrics
if self.global_step % self._log_every_n_steps == 0:
log_dict = {
"loss": loss_to_log,
# NOTE: for optim in backward, this assumes all optimizers have the same LR. This is currently
# true since we don't expose the ability to configure this yet.
"lr": get_lr(
(
self._optimizer
if not self._optimizer_in_bwd
else self._optim_ckpt_wrapper
),
),
),
"tokens_per_second_per_gpu": num_tokens / time_per_step,
}
if self._device.type != "cpu" and self._log_peak_memory_stats:
log_dict.update(
training.get_memory_stats(device=self._device)
"tokens_per_second_per_gpu": num_tokens
/ profiler.step_time,
}
if (
self._device.type != "cpu"
and self._log_peak_memory_stats
):
log_dict.update(
training.get_memory_stats(device=self._device)
)
if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm})
self._metric_logger.log_dict(
log_dict,
step=self.global_step,
)
if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm})
self._metric_logger.log_dict(
log_dict,
step=self.global_step,
)

# Reset running stats for the next step
running_loss = 0
num_tokens = 0
t0 = time.perf_counter()

# Stop tracking CUDA memory now that active steps are complete
if (
curr_epoch == 0
and self.profiler_profile_memory
and idx
== self.profiler_wait_steps
+ self.profiler_warmup_steps
+ self.profiler_active_steps
and self._device.type == "cuda"
):
torch.cuda.memory._record_memory_history(enabled=None)

# Step the profiler
# Note we are stepping each batch, which might not include optimizer step in the trace
# if the schedule cycle doesn't align with gradient accumulation.
self._profiler.step()

# Check if we should stop training for this epoch
if (
(idx + 1) // self._gradient_accumulation_steps
) == self.max_steps_per_epoch:
break

self.epochs_run += 1
self.save_checkpoint(epoch=curr_epoch)

self._profiler.stop()
# Reset running stats for the next step
running_loss = 0
num_tokens = 0

# Stop tracking CUDA memory now that active steps are complete
# if (
# curr_epoch == 0
# and self.profiler_profile_memory
# and idx
# == self.profiler_wait_steps
# + self.profiler_warmup_steps
# + self.profiler_active_steps
# and self._device.type == "cuda"
# ):
# torch.cuda.memory._record_memory_history(enabled=None)

profiler.step()

# Check if we should stop training for this epoch
if (
(idx + 1) // self._gradient_accumulation_steps
) == self.max_steps_per_epoch:
break

self.epochs_run += 1
self.save_checkpoint(epoch=curr_epoch)

def cleanup(self) -> None:
self._metric_logger.close()
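Taken together with the YAML change, the rewritten loop expects the instantiated profiler to be a context manager exposing a per-step `step()` and a `step_time` used for the tokens-per-second metric (replacing the old `t0 = time.perf_counter()` bookkeeping). A hypothetical sketch of that interface, with the caveat that the actual `TorchProfiler` in this WIP may compute `step_time` over a different window:

import time
import torch

class ProfilerContextSketch:
    # Wraps torch.profiler.profile and records wall-clock time between step() calls.
    def __init__(self, cpu: bool = True, cuda: bool = True, profile_memory: bool = False):
        activities = []
        if cpu:
            activities.append(torch.profiler.ProfilerActivity.CPU)
        if cuda:
            activities.append(torch.profiler.ProfilerActivity.CUDA)
        self._prof = torch.profiler.profile(
            activities=activities, profile_memory=profile_memory
        )
        self.step_time = 0.0  # seconds since the previous step() call

    def __enter__(self):
        self._t0 = time.perf_counter()
        self._prof.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return self._prof.__exit__(exc_type, exc_val, exc_tb)

    def step(self):
        now = time.perf_counter()
        self.step_time = now - self._t0
        self._t0 = now
        self._prof.step()
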
12 changes: 6 additions & 6 deletions torchtune/training/__init__.py
@@ -28,13 +28,13 @@
from torchtune.training._grad_scaler import scale_grads
from torchtune.training._model_util import disable_dropout
from torchtune.training._profiler import (
DEFAULT_PROFILE_DIR,
DEFAULT_PROFILER_ACTIVITIES,
DEFAULT_SCHEDULE,
DEFAULT_TRACE_OPTS,
DummyProfiler,
# DEFAULT_PROFILE_DIR,
# DEFAULT_PROFILER_ACTIVITIES,
# DEFAULT_SCHEDULE,
# DEFAULT_TRACE_OPTS,
# DummyProfiler,
PROFILER_KEY,
setup_torch_profiler,
# setup_torch_profiler,
)
from torchtune.training.activations import apply_selective_activation_checkpointing
from torchtune.training.checkpointing import (