diff --git a/dist_run.py b/dist_run.py index 3666bca89..f09261da4 100644 --- a/dist_run.py +++ b/dist_run.py @@ -328,15 +328,26 @@ def main(args): config.stage_idx = pp_rank config.n_stages = pp_degree - with device: + with torch.device("meta"): # TODO: we should create model instead of Transformer model = Transformer(config) # Distribute model on TP mesh + # (Surprisingly, this works even though model is on meta device and mesh is of + # cuda devices) model.distribute(tp_mesh) if rank == 0: logger.info(f"Model: {model}") + # Load weights + logger.info(f"Loading weights for {pp_rank=} on {device=}") + with CUDATrackTime() as timer: + _load_model_weights(model, distribution, device=device, model_config=config) + + logger.info( + f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" + ) + # Batch size. Since we push batches dynamically through the pipeline rather # than chunking them, this is effectively micro-batch size in pipeline # sense. Thus it is interchangeable with micro-batch size below. @@ -352,17 +363,8 @@ def main(args): # lanes. # TODO: bump up the lane count pipeline_lanes = 1 - model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes) - - # Load weights - logger.info(f"Loading weights for {pp_rank=} on {device=}") - with CUDATrackTime() as timer: - _load_model_weights(model, distribution, device=device, model_config=config) - model.to(device) - - logger.info( - f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" - ) + with device: + model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes) # info on stage size and params stage_size = get_module_size(model) diff --git a/torchchat/distributed/dtensor_utils.py b/torchchat/distributed/dtensor_utils.py index 9e57da428..1a6704caa 100644 --- a/torchchat/distributed/dtensor_utils.py +++ b/torchchat/distributed/dtensor_utils.py @@ -8,24 +8,17 @@ logger = SingletonLogger.get_logger() - -def is_dtensor(tensor): - """Check if a tensor is a DTensor by class or has a placements attribute (not sure if we want to use attr check)""" - return isinstance(tensor, DTensor) or hasattr(tensor, "placements") - - -def load_into_dtensor(weight_tensor, model_dtensor): +def convert_to_dtensor(weight_tensor, dtensor_template): """Adjust a loaded tensor to match the shape/placement of the model DTensor and copy the data into it""" - weight_tensor = weight_tensor.to(model_dtensor.device) - if weight_tensor.shape != model_dtensor.shape: + if weight_tensor.shape != dtensor_template.shape: raise ValueError( f"Shape mismatch: weight tensor shape {weight_tensor.shape} " - f"doesn't match DTensor shape {model_dtensor.shape}" + f"doesn't match DTensor shape {dtensor_template.shape}" ) - placements = model_dtensor.placements - mesh = model_dtensor.device_mesh + placements = dtensor_template.placements + mesh = dtensor_template.device_mesh mesh_dims = mesh.ndim for placement in placements: diff --git a/torchchat/distributed/safetensor_utils.py b/torchchat/distributed/safetensor_utils.py index 39eaee71b..01c5091b1 100644 --- a/torchchat/distributed/safetensor_utils.py +++ b/torchchat/distributed/safetensor_utils.py @@ -13,8 +13,8 @@ from torch.nn import Module from typing import Dict, Tuple, Set, Optional - -from torchchat.distributed.dtensor_utils import is_dtensor, load_into_dtensor +from torch.distributed._tensor import DTensor +from torchchat.distributed.dtensor_utils import convert_to_dtensor _DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json" @@ -284,9 +284,7 @@ def update_state_dict( continue checkpoint_tensor = checkpoint[old_param] - stage_tensor = state_dict[param] - - stage_is_dtensor = is_dtensor(stage_tensor) + model_tensor = state_dict[param] if "wq" in param: checkpoint_tensor = permute_weight_to_attn_heads( @@ -297,17 +295,16 @@ def update_state_dict( checkpoint_tensor, num_local_heads, head_dim, dim ) + # Move checkpoint tensor to desired device + checkpoint_tensor = checkpoint_tensor.to(device) + # here we need to check if the tensor is a DTensor and if so, adjust the # shape and placement to match the model DTensor. - if stage_is_dtensor: - model_tensor = load_into_dtensor(checkpoint_tensor, stage_tensor) - # logger.info(f"DTensor: Loaded {param} into {model_tensor=}") - state_dict[param] = model_tensor + if isinstance(model_tensor, DTensor): + state_dict[param] = convert_to_dtensor(checkpoint_tensor, model_tensor) count_dtensors_loaded += 1 - else: # regular tensor, just update directly - checkpoint_tensor = checkpoint_tensor.to(device) state_dict[param] = checkpoint_tensor # ensure matching dtypes