[Distributed] create model on meta device (#1227)
kwen2501 authored Oct 2, 2024
1 parent 77bac00 commit 8fcb3ba
Showing 3 changed files with 27 additions and 35 deletions.
26 changes: 14 additions & 12 deletions dist_run.py
@@ -328,15 +328,26 @@ def main(args):
config.stage_idx = pp_rank
config.n_stages = pp_degree

with device:
with torch.device("meta"):
# TODO: we should create model instead of Transformer
model = Transformer(config)

# Distribute model on TP mesh
# (Surprisingly, this works even though the model is on the meta device and
# the mesh consists of CUDA devices)
model.distribute(tp_mesh)
if rank == 0:
logger.info(f"Model: {model}")

# Load weights
logger.info(f"Loading weights for {pp_rank=} on {device=}")
with CUDATrackTime() as timer:
_load_model_weights(model, distribution, device=device, model_config=config)

logger.info(
f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
)

# Batch size. Since we push batches dynamically through the pipeline rather
# than chunking them, this is effectively micro-batch size in pipeline
# sense. Thus it is interchangeable with micro-batch size below.
@@ -352,17 +363,8 @@ def main(args):
# lanes.
# TODO: bump up the lane count
pipeline_lanes = 1
model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)

# Load weights
logger.info(f"Loading weights for {pp_rank=} on {device=}")
with CUDATrackTime() as timer:
_load_model_weights(model, distribution, device=device, model_config=config)
model.to(device)

logger.info(
f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
)
with device:
model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)

# info on stage size and params
stage_size = get_module_size(model)
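
For context, the meta-device pattern this commit adopts is a standard PyTorch idiom: build the module structure without allocating real storage, distribute/shard it, and only materialize parameters on the real device once actual weights arrive. Below is a minimal, self-contained sketch of that idiom, using a plain nn.Linear as a stand-in for Transformer(config) and to_empty() as one possible materialization step; it is an illustration of the pattern, not what _load_model_weights does internally.

import torch
import torch.nn as nn

# Construct under the meta device: parameters exist but hold no data.
with torch.device("meta"):
    model = nn.Linear(4096, 4096)  # hypothetical stand-in for Transformer(config)

assert model.weight.is_meta

# Allocate uninitialized storage on the real device, then copy real values in
# (faked here with random tensors standing in for a checkpoint).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to_empty(device=device)
model.load_state_dict({
    "weight": torch.randn(4096, 4096, device=device),
    "bias": torch.randn(4096, device=device),
})

In the diff above, the same idea shows up as creating the Transformer under torch.device("meta"), distributing it on the TP mesh, loading the weights afterwards, and setting up the caches under with device: so they are allocated on the real device.
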
17 changes: 5 additions & 12 deletions torchchat/distributed/dtensor_utils.py
@@ -8,24 +8,17 @@
logger = SingletonLogger.get_logger()



def is_dtensor(tensor):
"""Check if a tensor is a DTensor by class or has a placements attribute (not sure if we want to use attr check)"""
return isinstance(tensor, DTensor) or hasattr(tensor, "placements")


def load_into_dtensor(weight_tensor, model_dtensor):
def convert_to_dtensor(weight_tensor, dtensor_template):
"""Adjust a loaded tensor to match the shape/placement of the model DTensor and copy the data into it"""
weight_tensor = weight_tensor.to(model_dtensor.device)

if weight_tensor.shape != model_dtensor.shape:
if weight_tensor.shape != dtensor_template.shape:
raise ValueError(
f"Shape mismatch: weight tensor shape {weight_tensor.shape} "
f"doesn't match DTensor shape {model_dtensor.shape}"
f"doesn't match DTensor shape {dtensor_template.shape}"
)

placements = model_dtensor.placements
mesh = model_dtensor.device_mesh
placements = dtensor_template.placements
mesh = dtensor_template.device_mesh
mesh_dims = mesh.ndim

for placement in placements:
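
The renamed convert_to_dtensor helper takes a full checkpoint tensor plus an existing DTensor as a template and reshards it to the template's mesh and placements (the loop body is collapsed above). For illustration, here is a hedged sketch of the same idea expressed with the public distribute_tensor API; this is an assumed alternative implementation, not the actual torchchat code.

import torch
from torch.distributed._tensor import DTensor, distribute_tensor

def convert_to_dtensor_sketch(weight_tensor: torch.Tensor,
                              dtensor_template: DTensor) -> DTensor:
    # Hypothetical alternative implementation, for illustration only.
    if weight_tensor.shape != dtensor_template.shape:
        raise ValueError(
            f"Shape mismatch: weight tensor shape {weight_tensor.shape} "
            f"doesn't match DTensor shape {dtensor_template.shape}"
        )
    # Shard the full tensor according to the template's device mesh and
    # placements; every rank in the mesh must call this collectively.
    return distribute_tensor(
        weight_tensor,
        device_mesh=dtensor_template.device_mesh,
        placements=dtensor_template.placements,
    )
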
19 changes: 8 additions & 11 deletions torchchat/distributed/safetensor_utils.py
@@ -13,8 +13,8 @@
from torch.nn import Module
from typing import Dict, Tuple, Set, Optional


from torchchat.distributed.dtensor_utils import is_dtensor, load_into_dtensor
from torch.distributed._tensor import DTensor
from torchchat.distributed.dtensor_utils import convert_to_dtensor


_DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
@@ -284,9 +284,7 @@ def update_state_dict(
continue

checkpoint_tensor = checkpoint[old_param]
stage_tensor = state_dict[param]

stage_is_dtensor = is_dtensor(stage_tensor)
model_tensor = state_dict[param]

if "wq" in param:
checkpoint_tensor = permute_weight_to_attn_heads(
@@ -297,17 +295,16 @@
checkpoint_tensor, num_local_heads, head_dim, dim
)

# Move checkpoint tensor to desired device
checkpoint_tensor = checkpoint_tensor.to(device)

# here we need to check if the tensor is a DTensor and if so, adjust the
# shape and placement to match the model DTensor.
if stage_is_dtensor:
model_tensor = load_into_dtensor(checkpoint_tensor, stage_tensor)
# logger.info(f"DTensor: Loaded {param} into {model_tensor=}")
state_dict[param] = model_tensor
if isinstance(model_tensor, DTensor):
state_dict[param] = convert_to_dtensor(checkpoint_tensor, model_tensor)
count_dtensors_loaded += 1

else:
# regular tensor, just update directly
checkpoint_tensor = checkpoint_tensor.to(device)
state_dict[param] = checkpoint_tensor

# ensure matching dtypes
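
After this change, the per-parameter branch in update_state_dict keys off the model tensor's type directly: if the destination is a DTensor, the full checkpoint tensor is converted to match it; otherwise it is simply moved to the target device. A rough sketch of that branch with hypothetical names (place_param and the injected to_dtensor callable are illustrative, not part of the repo):

from torch.distributed._tensor import DTensor

def place_param(state_dict, param, checkpoint_tensor, device, to_dtensor):
    # Hypothetical helper mirroring the branch above.
    model_tensor = state_dict[param]
    if isinstance(model_tensor, DTensor):
        # Sharded parameter: reshape/re-place the full checkpoint tensor so it
        # matches the DTensor's mesh and placements.
        state_dict[param] = to_dtensor(checkpoint_tensor, model_tensor)
    else:
        # Regular tensor: just move the checkpoint tensor to the target device.
        state_dict[param] = checkpoint_tensor.to(device)
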
