diff --git a/recipes/configs/llama4/scout_17B_16E_dpo_full.yaml b/recipes/configs/llama4/scout_17B_16E_dpo_full.yaml
new file mode 100644
index 0000000000..89a5b039e7
--- /dev/null
+++ b/recipes/configs/llama4/scout_17B_16E_dpo_full.yaml
@@ -0,0 +1,106 @@
+# Config for multi-device DPO finetuning in full_dpo_distributed.py
+# using a Llama4 17Bx16E MoE model with 2D parallelism
+#
+# This config assumes that you've run the following command before launching:
+#   tune download meta-llama/Llama-4-Scout-17B-16E-Instruct
+#
+# To launch on 8 devices, run the following command from root:
+#   tune run --nproc_per_node 8 full_dpo_distributed --config llama4/scout_17B_16E_dpo_full
+#
+# You can add specific overrides through the command line. For example, to use a larger bsz:
+#   tune run --nproc_per_node 8 full_dpo_distributed --config llama4/scout_17B_16E_dpo_full batch_size=8
+#
+# This config is designed for 8xA100 or 16xH100 machines.
+
+output_dir: /tmp/torchtune/llama4_17Bx16E/dpo
+
+# Modeling arguments
+model:
+  _component_: torchtune.models.llama4.llama4_scout_17b_16e
+
+# 2D Parallelism configuration
+tensor_parallel_dim: 2 # For multi-node training we recommend tensor_parallel_dim: 8
+tensor_parallel_plan:
+  _component_: torchtune.models.llama4.decoder_only_tp_plan
+data_parallel_shard_dim: -1 # Will infer based on TP dim, effectively controls FSDP
+data_parallel_replicate_dim: 1
+
+tokenizer:
+  _component_: torchtune.models.llama4.llama4_transform
+  path: /tmp/Llama-4-Scout-17B-16E-Instruct/tokenizer.model
+  max_seq_len: null
+  max_num_tiles: 16
+
+# Base model checkpointer
+checkpointer:
+  _component_: torchtune.training.FullModelHFCheckpointer
+  checkpoint_dir: /tmp/Llama-4-Scout-17B-16E-Instruct
+  checkpoint_files:
+    filename_format: model-{}-of-{}.safetensors
+    max_filename: "00050"
+  recipe_checkpoint: null
+  output_dir: ${output_dir}
+  model_type: LLAMA4
+
+# Reference model checkpointer (for DPO)
+ref_checkpointer:
+  _component_: torchtune.training.FullModelHFCheckpointer
+  checkpoint_dir: /tmp/Llama-4-Scout-17B-16E-Instruct
+  checkpoint_files:
+    filename_format: model-{}-of-{}.safetensors
+    max_filename: "00050"
+  recipe_checkpoint: null
+  output_dir: ${output_dir}
+  model_type: LLAMA4
+
+resume_from_checkpoint: False
+
+# Dataset
+dataset:
+  _component_: torchtune.datasets.stack_exchange_paired_dataset
+  packed: False
+seed: null
+shuffle: True
+
+# Training arguments
+epochs: 1
+max_steps_per_epoch: null
+batch_size: 1
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+optimizer:
+  _component_: torch.optim.AdamW
+  lr: 5e-7 # Lower learning rate for DPO
+  fused: False
+loss:
+  _component_: torchtune.rlhf.loss.DPOLoss
+  beta: 0.1
+clip_grad_norm: 1.0
+
+# cuda, cpu, rocm, xpu...
+device: cuda + +# Memory management / performance +enable_activation_checkpointing: True +enable_activation_offloading: True +fsdp_cpu_offload: False # Set to False - keeping optimizer states on GPU +fsdp_reshard_after_forward: True +compile: False # torch.compile, set to true for perf/memory improvement + +# Reduced precision +dtype: bf16 + +# Log metrics during training +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir}/logs +log_every_n_steps: 1 +log_peak_memory_stats: True + +# Useful for understanding how to optimize memory and performance +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + +# Float8 training support +enable_fp8_training: False +fp8_recipe_name: null diff --git a/recipes/full_dpo_distributed.py b/recipes/full_dpo_distributed.py index 0d6a632de7..52db9cf3ec 100644 --- a/recipes/full_dpo_distributed.py +++ b/recipes/full_dpo_distributed.py @@ -14,7 +14,10 @@ from omegaconf import DictConfig, ListConfig from torch import nn from torch.distributed import destroy_process_group, init_process_group +from torch.distributed._tensor import DTensor +from torch.distributed.tensor.parallel import parallelize_module from torch.optim import Optimizer +from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp from torchdata.stateful_dataloader import StatefulDataLoader from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler from torchtune import config, modules, rlhf, training, utils @@ -23,8 +26,15 @@ from torchtune.recipe_interfaces import FTRecipeInterface from torchtune.rlhf import ChosenRejectedOutputs from torchtune.training import disable_dropout, DummyProfiler, PROFILER_KEY +from torchtune.training.checkpointing._checkpoint_client import ( + CheckpointClient, + TrainingProgress, +) from torchtune.training.lr_schedulers import get_lr -from torchtune.utils import get_world_size_and_rank +from torchtune.training.quantization import ( + convert_to_float8_training, + is_fp8_tensorwise_scaling, +) from tqdm import tqdm log = utils.get_logger("DEBUG") @@ -116,7 +126,8 @@ class FullDPORecipeDistributed(FTRecipeInterface): """ def __init__(self, cfg: DictConfig) -> None: - self._device = utils.get_device(device=cfg.device) + device_type = cfg.device + self._device = utils.get_device(device=device_type) self._dtype = training.get_dtype(cfg.dtype, device=self._device) if self._dtype == torch.float16: @@ -124,38 +135,62 @@ def __init__(self, cfg: DictConfig) -> None: "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." ) + # Set up the backend for distributed training (NCCL, GLOO, etc.) + self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False) + self.fsdp_cpu_offload = cfg.get("fsdp_cpu_offload", False) + self.distributed_backend = training.get_distributed_backend( + device_type, + offload_ops_to_cpu=self.fsdp_cpu_offload + or self._enable_async_checkpointing, + ) + init_process_group(self.distributed_backend) + + # Initialize distributed variables + self.world_size, self.rank = utils.get_world_size_and_rank() + self._is_rank_zero = self.rank == 0 + self.tp_plan = cfg.get("tensor_parallel_plan", None) + self.tp_degree = cfg.get("tensor_parallel_dim", 1) + if self.tp_degree > 1 and self.tp_plan is None: + raise ValueError( + "Tensor Parallel plan needs to be provided when tensor parallel is enabled." 
+ ) + data_shard = cfg.get("data_parallel_shard_dim", -1) # -1 means to infer + data_replicate = cfg.get("data_parallel_replicate_dim", 1) + + # Set up n-d device mesh + self.parallel_dims = training.ParallelDims( + dp_replicate=data_replicate, + dp_shard=data_shard, + tp=self.tp_degree, + world_size=self.world_size, + ) + self.world_mesh = self.parallel_dims.build_mesh(device_type=device_type) + if self.parallel_dims.dp_enabled: + dp_mesh = self.world_mesh["dp"] + self.dp_degree, self.dp_rank = ( + dp_mesh.size(), + dp_mesh.get_local_rank(), + ) + else: + self.dp_degree, self.dp_rank = 1, 0 + # logging attributes self._output_dir = cfg.output_dir self._log_every_n_steps = cfg.get("log_every_n_steps", 1) self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) - - if self._log_peak_memory_stats and self._device.type != "cuda": + if self._log_peak_memory_stats and device_type != "cuda": log.info( "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False." ) self._log_peak_memory_stats = False - self.world_size, self.rank = get_world_size_and_rank() - self._is_rank_zero = self.rank == 0 - # Training cfg self._resume_from_checkpoint = cfg.resume_from_checkpoint self._gradient_accumulation_steps = cfg.gradient_accumulation_steps - self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False) self._clip_grad_norm = cfg.get("clip_grad_norm", None) - - # Optimizer in backward is not compatible with gradient accumulation or gradient clipping - if self._optimizer_in_bwd: - if self._clip_grad_norm is not None: - raise RuntimeError( - "Gradient clipping is not supported with optimizer in bwd." - "Please set clip_grad_norm=None, or optimizer_in_bwd=False." - ) - if self._gradient_accumulation_steps > 1: - raise RuntimeError( - "Gradient accumulation is not supported with optimizer in bwd." - "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False." - ) + self._checkpoint_client = CheckpointClient(cfg) + self._enable_fp8_training = cfg.get("enable_fp8_training", False) + self._fp8_recipe_name = cfg.get("fp8_recipe_name", None) # activation checkpointing/offloading self._enable_activation_checkpointing = cfg.get( @@ -190,21 +225,6 @@ def __init__(self, cfg: DictConfig) -> None: self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 - def _load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: - """ - Extract the checkpoint state from file and validate. If resume_from_checkpoint - is True, this also includes the recipe state. - """ - self._checkpointer = config.instantiate( - cfg_checkpointer, - should_load_recipe_state=self._resume_from_checkpoint, - ) - checkpoint_dict = self._checkpointer.load_checkpoint() - - if self._resume_from_checkpoint: - self._update_recipe_state(checkpoint_dict) - return checkpoint_dict - def _load_ref_checkpoint(self, cfg_ref_checkpointer: DictConfig) -> Dict[str, Any]: """ Extract the reference model checkpoint state from file. @@ -260,6 +280,11 @@ def setup(self, cfg: DictConfig) -> None: Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. """ + if self.fsdp_cpu_offload: + # Utilize all available CPU cores for intra-op parallelism. 
This provides ~2x + # speed up when benchmarking fused AdamW on CPU + training.set_torch_num_threads() + if self._is_rank_zero: self._metric_logger = config.instantiate(cfg.metric_logger) @@ -267,34 +292,55 @@ def setup(self, cfg: DictConfig) -> None: self._metric_logger.log_config(cfg) # Load the base model - checkpoint_dict = self._load_checkpoint(cfg.checkpointer) + checkpoint_dict = self._checkpoint_client.load_base_checkpoint() ref_checkpoint_dict = self._load_ref_checkpoint(cfg.ref_checkpointer) + if self._resume_from_checkpoint: + # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously + # using the DistributedCheckpointer. + # Therefore the recipe needs to load the distributed checkpoint to restore the training + # progress. + if self._enable_async_checkpointing: + try: + checkpoint_dict = ( + self._checkpoint_client.load_distributed_checkpoint( + self._model, + self._optimizer, + ) + ) + except Exception as e: + log.warning( + f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint." + ) + + # Update the recipe state from the checkpoint state dict. + self._update_recipe_state(checkpoint_dict) + self._compile = cfg.get("compile", False) self._model = self._setup_model( cfg_model=cfg.model, + model_state_dict=checkpoint_dict[training.MODEL_KEY], + is_reference_model=False, enable_activation_checkpointing=self._enable_activation_checkpointing, enable_activation_offloading=self._enable_activation_offloading, - custom_sharded_layers=cfg.get("custom_sharded_layers", None), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), - model_state_dict=checkpoint_dict[training.MODEL_KEY], + custom_sharded_layers=cfg.get("custom_sharded_layers", None), ) - # TODO (@SalmanMohammadi) investigate TP for ref model - self._ref_model = self._setup_reference_model( + self._ref_model = self._setup_model( cfg_model=cfg.model, - fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), - reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), model_state_dict=ref_checkpoint_dict, - custom_sharded_layers=cfg.get("custom_sharded_layers", None), + is_reference_model=True, + fsdp_cpu_offload=False, + reshard_after_forward=False, + custom_sharded_layers=None, ) self._tokenizer = config.instantiate(cfg.tokenizer) self._optimizer = self._setup_optimizer( cfg_optimizer=cfg.optimizer, - optimizer_in_bwd=self._optimizer_in_bwd, opt_state_dict=( checkpoint_dict[training.OPT_KEY] if self._resume_from_checkpoint @@ -416,11 +462,12 @@ def _setup_profiler( def _setup_model( self, cfg_model: DictConfig, - enable_activation_checkpointing: bool, - enable_activation_offloading: bool, - fsdp_cpu_offload: bool, - reshard_after_forward: bool, model_state_dict: Dict[str, Any], + is_reference_model: bool = False, + enable_activation_checkpointing: bool = False, + enable_activation_offloading: bool = False, + fsdp_cpu_offload: bool = False, + reshard_after_forward: bool = True, custom_sharded_layers: Optional[List[str]] = None, ) -> nn.Module: """ @@ -429,11 +476,24 @@ def _setup_model( the right dtype b. 
All ranks calls ``load_state_dict`` without peaking CPU RAMs since full state dicts are loaded with ``torch.load(mmap=True)`` - """ + Args: + cfg_model: Model configuration + model_state_dict: Model state dictionary to load + is_reference_model: Whether this is a reference model (inference-only) + enable_activation_checkpointing: Whether to enable activation checkpointing + enable_activation_offloading: Whether to enable activation offloading + fsdp_cpu_offload: Whether to offload FSDP parameters to CPU (only for base model) + reshard_after_forward: Whether to reshard parameters after forward pass (only for base model) + custom_sharded_layers: Custom layers to shard (only for base model) + + Returns: + Initialized model + """ + model_type = "reference model" if is_reference_model else "model" utils.log_rank_zero( log, - "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...", + f"Distributed training is enabled. Instantiating {model_type} and loading checkpoint on Rank 0 ...", ) init_start = time.perf_counter() @@ -443,25 +503,70 @@ def _setup_model( if self._compile: training.compile_model(model, verbose=self._is_rank_zero) - # original activation checkpointing (full) - flip the condition above - if enable_activation_checkpointing: + # Only apply FP8 training to the base model, not the reference model + if not is_reference_model and self._enable_fp8_training: + # Requires https://github.com/pytorch/pytorch/pull/148922 + if torch.__version__ < "2.8.0.dev20250318": + raise RuntimeError( + "Float8 fine-tuning requires PyTorch 2.8.0.dev20250318 or later." + ) + if self.tp_plan is not None: + raise ValueError( + "FP8 training does not support tensor parallelism yet. " + "This will be enabled in the near future." + ) + model = convert_to_float8_training(model, self._fp8_recipe_name) + + # Apply tensor parallelism to the model (for both base and reference models) + if self.parallel_dims.tp_enabled: + if ( + not is_reference_model + and not self.parallel_dims.dp_enabled + and fsdp_cpu_offload + ): + raise ValueError( + "Tensor parallelism is not supported with FSDP CPU offloading when data parallelism is disabled." 
+ ) + # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel + model = training.prepare_mha_for_tp(model, self.world_mesh["tp"]) + if self.tp_plan is not None: + self.tp_plan = config.instantiate( + self.tp_plan, + model=model, + ) + parallelize_module( + model, + self.world_mesh["tp"], + parallelize_plan=self.tp_plan, + ) + + # Apply activation checkpointing if enabled (only for base model) + if not is_reference_model and enable_activation_checkpointing: training.set_activation_checkpointing( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) - # For FSDP sharding - fsdp_shard_conditions = [ - partial( - training.get_shard_conditions, - names_to_match=custom_sharded_layers, + # Apply Fully Sharded Data Parallelism to the model (only for base model) + if not is_reference_model and self.parallel_dims.dp_shard_enabled: + fsdp_shard_conditions = [ + partial( + training.get_shard_conditions, + names_to_match=custom_sharded_layers, + ) + ] + + if self.parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard") + else: + dp_mesh_dim_names = ("dp_shard",) + + training.shard_model( + model=model, + shard_conditions=fsdp_shard_conditions, + cpu_offload=fsdp_cpu_offload, + reshard_after_forward=reshard_after_forward, + dp_mesh=self.world_mesh[dp_mesh_dim_names], ) - ] - training.shard_model( - model=model, - shard_conditions=fsdp_shard_conditions, - cpu_offload=fsdp_cpu_offload, - reshard_after_forward=reshard_after_forward, - ) with training.set_default_dtype(self._dtype), self._device: for m in model.modules(): @@ -479,99 +584,18 @@ def _setup_model( cpu_offload=fsdp_cpu_offload, ) - # activation offloading - self.activations_handling_ctx = training.get_act_offloading_ctx_manager( - model, enable_activation_offloading - ) - - # Ensure no params and buffers are on meta device - training.validate_no_params_on_meta_device(model) - - utils.log_rank_zero( - log, - f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs", - ) - - if self._is_rank_zero: - memory_stats = training.get_memory_stats(device=self._device) - training.log_memory_stats(memory_stats) - - # disabling dropout if found - non-determinism leads to issues in e.g. comparing logprobs - # between ref policy and current policy - disable_dropout(model) - - # synchronize before training begins - torch.distributed.barrier() - - return model - - def _setup_reference_model( - self, - cfg_model: DictConfig, - fsdp_cpu_offload: bool, - reshard_after_forward: bool, - model_state_dict: Dict[str, Any], - custom_sharded_layers: Optional[List[str]] = None, - ) -> nn.Module: - """ - Similar to `self._setup_model`: - a. To minimize GPU peak memory, we initialize the model on meta device with - the right dtype - b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since - full state dicts are loaded with ``torch.load(mmap=True)`` - - Additionally, since the reference model is inference-only, we omit some training-specific - optimizations. - """ - - utils.log_rank_zero( - log, - "FSDP is enabled. 
Instantiating reference model and loading checkpoint on Rank 0 ...", - ) - init_start = time.perf_counter() - - with training.set_default_dtype(self._dtype), torch.device("meta"): - model = config.instantiate(cfg_model) - - if self._compile: - training.compile_model(model, verbose=self._is_rank_zero) - - # For FSDP sharding - fsdp_shard_conditions = [ - partial( - training.get_shard_conditions, - names_to_match=custom_sharded_layers, + # Set up activation offloading context for the base model + if not is_reference_model: + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading ) - ] - training.shard_model( - model=model, - shard_conditions=fsdp_shard_conditions, - cpu_offload=fsdp_cpu_offload, - reshard_after_forward=reshard_after_forward, - ) - - with training.set_default_dtype(self._dtype), self._device: - for m in model.modules(): - # RoPE is not covered in state dict - if hasattr(m, "rope_init"): - m.rope_init() - - # This method will convert the full model state dict into a sharded state - # dict and load into the model - training.load_from_full_model_state_dict( - model, - model_state_dict, - self._device, - strict=True, - cpu_offload=fsdp_cpu_offload, - ) # Ensure no params and buffers are on meta device training.validate_no_params_on_meta_device(model) utils.log_rank_zero( log, - f"Instantiating reference model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs", + f"Instantiating {model_type} and loading checkpoint took {time.perf_counter() - init_start:.2f} secs", ) if self._is_rank_zero: @@ -582,10 +606,11 @@ def _setup_reference_model( # between ref policy and current policy disable_dropout(model) - for p in model.parameters(): - p.requires_grad = False - - model.eval() + # For reference model, set to eval mode and disable gradients + if is_reference_model: + for p in model.parameters(): + p.requires_grad = False + model.eval() # synchronize before training begins torch.distributed.barrier() @@ -595,70 +620,55 @@ def _setup_reference_model( def _setup_optimizer( self, cfg_optimizer: DictConfig, - optimizer_in_bwd: bool = False, opt_state_dict: Optional[Dict[str, Any]] = None, - ) -> Optional[Optimizer]: - if optimizer_in_bwd: - # Maintain a dict of optims for every parameter. - optim_dict = { - param: config.instantiate(cfg_optimizer, [param]) - for param in self._model.parameters() - } - - # Register optimizer step hooks on the model to run optimizer in backward. - training.register_optim_in_bwd_hooks( - model=self._model, optim_dict=optim_dict - ) - # Create a wrapper for checkpoint save/load of optimizer states when running in backward. - self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper( - model=self._model, optim_dict=optim_dict + ) -> Optimizer: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + training.load_from_full_optimizer_state_dict( + self._model, + optimizer, + opt_state_dict, + self._device, ) - # Load optimizer states for each param. If optimizer states are being restored in an optimizer in - # backward run, these need to have been saved with the same setting. Cannot restore from runs that - # did not use optimizer in backward. 
- if opt_state_dict is not None: - for param in opt_state_dict.keys(): - try: - training.load_from_full_optimizer_state_dict( - self._model, - self._optim_ckpt_wrapper.optim_map[param], - opt_state_dict[param], - self._device, - ) - except BaseException as e: - raise RuntimeError( - "Failed loading in-backward optimizer checkpoints." - "Please make sure run being restored from was using in-backward optimizer." - ) from e - utils.log_rank_zero(log, "In-backward optimizers are set up.") - return None - else: - optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) - if opt_state_dict: - training.load_from_full_optimizer_state_dict( - self._model, - optimizer, - opt_state_dict, - self._device, - ) - utils.log_rank_zero(log, "Optimizer and loss are initialized.") - return optimizer + utils.log_rank_zero(log, "Optimizer and loss are initialized.") + return optimizer def _setup_lr_scheduler( self, - cfg_lr_scheduler: DictConfig, + cfg_lr_scheduler: Optional[DictConfig], num_training_steps: int, last_epoch: int, - ) -> Optimizer: + ) -> Optional[Optimizer]: + """ + Set up the learning rate scheduler based on the provided configuration. + + Args: + cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration. + num_training_steps (int): The total number of training steps. + last_epoch (int): The index of the last epoch. + + Returns: + lr_scheduler (Optional[Optimizer]): The learning rate scheduler. + """ + if cfg_lr_scheduler is None: + if self._is_rank_zero: + log.info( + "No learning rate scheduler configured. Using constant learning rate." + ) + return None + + # Instantiate the learning rate scheduler lr_scheduler = config.instantiate( cfg_lr_scheduler, self._optimizer, num_training_steps=num_training_steps, last_epoch=last_epoch, ) + if self._is_rank_zero: log.info("Learning rate scheduler is initialized.") + return lr_scheduler def _setup_data( @@ -683,7 +693,7 @@ def _setup_data( ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) sampler = StatefulDistributedSampler( - ds, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle + ds, num_replicas=self.dp_degree, rank=self.dp_rank, shuffle=shuffle ) dataloader = StatefulDataLoader( @@ -709,92 +719,21 @@ def save_checkpoint( epoch: int, ) -> None: """ - Checkpoint the state of the recipe. The constructed checkpoint state dict - contains the following information: - - Model weights with key training.MODEL_KEY - - Relevant recipe state if training is not complete - - Checkpointer will save the model weights and recipe state in - different checkpoint files. To correctly resume training from an intermediate checkpoint, - the model weights and recipe state must be provided. + Checkpoint the state of the recipe using the CheckpointClient. """ - # final dict passed onto the checkpointer - checkpoint_dict = {} - - intermediate_checkpoint = epoch + 1 < self.total_epochs - - if self._is_rank_zero: - log.info( - "Saving checkpoint. This may take some time. Retrieving full model state dict..." 
- ) - start = time.perf_counter() - - # To prevent GPU memory from spiking during checkpoint save, - # we consolidate the full model and optim state dicts on CPU for rank 0 - cpu_state_dict = training.gather_cpu_state_dict( - self._model, - self._is_rank_zero, - device=self._device, + self._checkpoint_client.save_checkpoint( + model=self._model, + optimizer=self._optimizer, + training_progress=TrainingProgress( + seed=self.seed, + epochs_run=self.epochs_run, + total_epochs=self.total_epochs, + max_steps_per_epoch=self.max_steps_per_epoch, + dataloader_state_dict=self._dataloader.state_dict(), + ), + epoch=epoch, ) - if self._is_rank_zero: - log.info( - f"Getting full model state dict took {time.perf_counter() - start:.2f} secs" - ) - - if intermediate_checkpoint: - start = time.perf_counter() - utils.log_rank_zero(log, "Getting optimizer state dict...") - if not self._optimizer_in_bwd: - opt_state_dict = training.get_full_optimizer_state_dict( - self._model, - self._optimizer, - self._is_rank_zero, - device=self._device, - ) - else: - opt_state_dict = {} - for param, opt in self._optim_ckpt_wrapper.optim_map.items(): - opt_state_dict[param] = training.get_full_optimizer_state_dict( - self._model, opt, self._is_rank_zero, device=self._device - ) - utils.log_rank_zero( - log, - f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs", - ) - else: - opt_state_dict = None - - # Now that we have the model and opt state dict, create the actual checkpoint dict - # to be sent to the checkpointer and ultimately written to file - - if self._is_rank_zero: - start = time.perf_counter() - checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict}) - - # if training is in-progress, checkpoint the optimizer state and recipe state - # as well. - if intermediate_checkpoint: - checkpoint_dict.update( - { - training.OPT_KEY: opt_state_dict, - training.SEED_KEY: self.seed, - training.EPOCHS_KEY: self.epochs_run, - training.TOTAL_EPOCHS_KEY: self.total_epochs, - training.MAX_STEPS_KEY: self.max_steps_per_epoch, - training.DATALOADER_KEY: self._dataloader.state_dict(), - } - ) - - self._checkpointer.save_checkpoint( - checkpoint_dict, - epoch=epoch, - intermediate_checkpoint=intermediate_checkpoint, - ) - log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs") - - torch.distributed.barrier() - def concatenated_forward( self, model: nn.Module, @@ -858,17 +797,17 @@ def train(self) -> None: t0 = time.perf_counter() # Running metrics - running_loss = 0 + running_loss = torch.tensor(0.0, device=self._device) running_metrics = { - "rewards/chosen": 0, - "rewards/rejected": 0, - "rewards/accuracies": 0, - "log_probs/chosen": 0, - "log_probs/rejected": 0, - "logits/chosen": 0, - "logits/rejected": 0, + "rewards/chosen": torch.tensor(0.0, device=self._device), + "rewards/rejected": torch.tensor(0.0, device=self._device), + "rewards/accuracies": torch.tensor(0.0, device=self._device), + "log_probs/chosen": torch.tensor(0.0, device=self._device), + "log_probs/rejected": torch.tensor(0.0, device=self._device), + "logits/chosen": torch.tensor(0.0, device=self._device), + "logits/rejected": torch.tensor(0.0, device=self._device), } - num_tokens = 0 + num_tokens = torch.tensor(0, device=self._device) self._profiler.start() # self.epochs_run should be non-zero when we're resuming from a checkpoint @@ -888,7 +827,7 @@ def train(self) -> None: break # batch is input_ids, labels - num_tokens += torch.tensor(batch[0].numel()) + num_tokens += torch.tensor(batch[0].numel(), 
device=self._device) policy_chosen_rejected_outputs = self.concatenated_forward( self._model, batch ) @@ -972,12 +911,25 @@ def train(self) -> None: grad_norm = torch.nn.utils.clip_grad_norm_( self._model.parameters(), max_norm=float(self._clip_grad_norm), - ).full_tensor() + ) + # If sharded, collect the DTensor here + if isinstance(grad_norm, DTensor): + grad_norm = grad_norm.full_tensor() self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) # Update the number of steps when the weights are updated self.global_step += 1 + # If float8 training is enabled, perform a single all-reduce to compute the + # scale for all float8 parameters efficiently instead of doing many small + # all-reduces for each parameter + if ( + self._enable_fp8_training + and is_fp8_tensorwise_scaling(self._fp8_recipe_name) + and self.dp_degree > 1 + ): + precompute_float8_dynamic_scale_for_fsdp(self._model) + # Step the learning rate scheduler if self._lr_scheduler is not None: self._lr_scheduler.step() @@ -992,51 +944,23 @@ def train(self) -> None: self.global_step % self._log_every_n_steps == 0 and self._is_rank_zero ): - time_per_step = time.perf_counter() - t0 - log_dict = { - "loss": loss_to_log, - "lr": get_lr( - ( - self._optimizer - if not self._optimizer_in_bwd - else self._optim_ckpt_wrapper - ), + self._log_training_metrics( + loss_to_log=loss_to_log, + running_metrics=running_metrics, + num_tokens=num_tokens, + time_per_step=time.perf_counter() - t0, + grad_norm=( + grad_norm if self._clip_grad_norm is not None else None ), - "tokens_per_second_per_gpu": num_tokens - / (time_per_step * self.world_size), - "rewards/chosen": running_metrics["rewards/chosen"].cpu(), - "rewards/rejected": running_metrics[ - "rewards/rejected" - ].cpu(), - "rewards/accuracies": running_metrics[ - "rewards/accuracies" - ].cpu(), - "rewards/margins": ( - running_metrics["rewards/chosen"] - - running_metrics["rewards/rejected"] - ).cpu(), - "log_probs/chosen": running_metrics[ - "log_probs/chosen" - ].cpu(), - "log_probs/rejected": running_metrics[ - "log_probs/rejected" - ].cpu(), - "logits/chosen": running_metrics["logits/chosen"].cpu(), - "logits/rejected": running_metrics["logits/rejected"].cpu(), - } - if self._log_peak_memory_stats: - log_dict.update( - training.get_memory_stats(device=self._device) - ) - self._metric_logger.log_dict( - log_dict, - step=self.global_step, ) # Reset running stats for the next step - running_loss = 0 - running_metrics = {key: 0 for key in running_metrics} - num_tokens = 0 + running_loss = torch.tensor(0.0, device=self._device) + running_metrics = { + key: torch.tensor(0.0, device=self._device) + for key in running_metrics + } + num_tokens = torch.tensor(0, device=self._device) t0 = time.perf_counter() @@ -1050,6 +974,42 @@ def train(self) -> None: self._profiler.stop() + def _log_training_metrics( + self, + loss_to_log: float, + running_metrics: Dict[str, torch.Tensor], + num_tokens: torch.Tensor, + time_per_step: float, + grad_norm: Optional[torch.Tensor] = None, + ) -> None: + """Log training metrics to the metric logger.""" + log_dict = { + "loss": loss_to_log, + "lr": get_lr(self._optimizer), + "tokens_per_second_per_gpu": num_tokens / (time_per_step * self.world_size), + "rewards/chosen": running_metrics["rewards/chosen"].cpu(), + "rewards/rejected": running_metrics["rewards/rejected"].cpu(), + "rewards/accuracies": running_metrics["rewards/accuracies"].cpu(), + "rewards/margins": ( + running_metrics["rewards/chosen"] - running_metrics["rewards/rejected"] + ).cpu(), + 
"log_probs/chosen": running_metrics["log_probs/chosen"].cpu(), + "log_probs/rejected": running_metrics["log_probs/rejected"].cpu(), + "logits/chosen": running_metrics["logits/chosen"].cpu(), + "logits/rejected": running_metrics["logits/rejected"].cpu(), + } + + if grad_norm is not None: + log_dict.update({"grad_norm": grad_norm}) + + if self._log_peak_memory_stats: + log_dict.update(training.get_memory_stats(device=self._device)) + + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + def cleanup(self) -> None: if self._is_rank_zero: self._metric_logger.close() @@ -1065,20 +1025,7 @@ def recipe_main(cfg: DictConfig) -> None: - Parameters specified in config (see available configs through ``tune ls``) - Overwritten by arguments from the command-line """ - if not training.is_distributed(): - raise RuntimeError( - "Distributed finetune recipe should be run via a distributed launcher." - "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" - ) - - init_process_group("cuda:nccl,cpu:gloo") - if cfg.get("fsdp_cpu_offload", False): - # Utilize all available CPU cores for intra-op parallelism. This provides ~2x - # speed up when benchmarking fused AdamW on CPU - training.set_torch_num_threads() - config.log_config(recipe_name="FullDPORecipeDistributed", cfg=cfg) - recipe = FullDPORecipeDistributed(cfg=cfg) recipe.setup(cfg=cfg) recipe.train()