[Distributed] create model on meta device #1227

Merged · 1 commit · Oct 2, 2024

dist_run.py (26 changes: 14 additions & 12 deletions)

@@ -328,15 +328,26 @@ def main(args):
     config.stage_idx = pp_rank
     config.n_stages = pp_degree
 
-    with device:
+    with torch.device("meta"):
         # TODO: we should create model instead of Transformer
         model = Transformer(config)
 
+    # Distribute model on TP mesh
+    # (Surprisingly, this works even though model is on meta device and mesh is of
+    # cuda devices)
     model.distribute(tp_mesh)
     if rank == 0:
         logger.info(f"Model: {model}")
 
+    # Load weights
+    logger.info(f"Loading weights for {pp_rank=} on {device=}")
+    with CUDATrackTime() as timer:
+        _load_model_weights(model, distribution, device=device, model_config=config)
+
+    logger.info(
+        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+    )
 
     # Batch size. Since we push batches dynamically through the pipeline rather
     # than chunking them, this is effectively micro-batch size in pipeline
     # sense. Thus it is interchangeable with micro-batch size below.
@@ -352,17 +363,8 @@
     # lanes.
     # TODO: bump up the lane count
     pipeline_lanes = 1
-    model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)
-
-    # Load weights
-    logger.info(f"Loading weights for {pp_rank=} on {device=}")
-    with CUDATrackTime() as timer:
-        _load_model_weights(model, distribution, device=device, model_config=config)
-        model.to(device)
-
-    logger.info(
-        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
-    )
+    with device:
+        model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)
 
     # info on stage size and params
     stage_size = get_module_size(model)
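
For readers unfamiliar with the pattern, this follows PyTorch's standard meta-device initialization flow: construct the module under torch.device("meta") so parameters carry only shapes and dtypes, then materialize real storage when the checkpoint is loaded onto the target device. Below is a minimal, self-contained sketch of that flow; ToyModel and the random state dict are illustrative stand-ins for Transformer(config) and _load_model_weights(), not torchchat code.

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    # Stand-in for Transformer(config); any nn.Module behaves the same way here.
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Construct on the meta device: parameters get shape/dtype but no storage,
#    so building even a large model is cheap and allocates no GPU memory.
with torch.device("meta"):
    model = ToyModel()

# 2. Materialize by loading real weights. assign=True swaps the meta parameters
#    for the loaded tensors outright, so no separate model.to(device) is needed.
#    (Random tensors stand in for the real checkpoint here.)
loaded = {
    name: torch.randn(t.shape, dtype=t.dtype, device=device)
    for name, t in model.state_dict().items()
}
model.load_state_dict(loaded, assign=True)

# 3. Tensors created afterwards (e.g. the KV caches built by setup_caches) are
#    allocated under `with device:` so they land on the real device, not meta.
with device:
    out = model(torch.randn(2, 16))

Step 2 is presumably why the diff can drop model.to(device): each parameter is replaced by a tensor that already lives on the target device.
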
torchchat/distributed/dtensor_utils.py (17 changes: 5 additions & 12 deletions)

@@ -8,24 +8,17 @@
 logger = SingletonLogger.get_logger()
 
 
 
-def is_dtensor(tensor):
-    """Check if a tensor is a DTensor by class or has a placements attribute (not sure if we want to use attr check)"""
-    return isinstance(tensor, DTensor) or hasattr(tensor, "placements")
-
-
-def load_into_dtensor(weight_tensor, model_dtensor):
+def convert_to_dtensor(weight_tensor, dtensor_template):
     """Adjust a loaded tensor to match the shape/placement of the model DTensor and copy the data into it"""
-    weight_tensor = weight_tensor.to(model_dtensor.device)
-
-    if weight_tensor.shape != model_dtensor.shape:
+    if weight_tensor.shape != dtensor_template.shape:
         raise ValueError(
             f"Shape mismatch: weight tensor shape {weight_tensor.shape} "
-            f"doesn't match DTensor shape {model_dtensor.shape}"
+            f"doesn't match DTensor shape {dtensor_template.shape}"
        )
 
-    placements = model_dtensor.placements
-    mesh = model_dtensor.device_mesh
+    placements = dtensor_template.placements
+    mesh = dtensor_template.device_mesh
     mesh_dims = mesh.ndim
 
     for placement in placements:
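
The renamed helper no longer moves the incoming tensor onto a device itself; it only redistributes a full checkpoint tensor so that it matches an existing DTensor's mesh and placements. Conceptually it is close to the sketch below, which leans on distribute_tensor for brevity; this is an illustration of the idea rather than the actual implementation (which walks the placements and builds the local shard itself), and it assumes the process group and device mesh already exist via the template.

import torch
from torch.distributed._tensor import DTensor, distribute_tensor

def convert_to_dtensor_sketch(weight_tensor: torch.Tensor, dtensor_template: DTensor) -> DTensor:
    """Re-shard a full checkpoint tensor to match the template's mesh/placements."""
    if weight_tensor.shape != dtensor_template.shape:
        raise ValueError(
            f"Shape mismatch: {weight_tensor.shape} vs {dtensor_template.shape}"
        )
    # distribute_tensor scatters or replicates the full tensor across the
    # template's device mesh according to its placements; the resulting DTensor's
    # local shard lives on this rank's device, so no explicit .to() is required.
    return distribute_tensor(
        weight_tensor,
        device_mesh=dtensor_template.device_mesh,
        placements=dtensor_template.placements,
    )
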
torchchat/distributed/safetensor_utils.py (19 changes: 8 additions & 11 deletions)

@@ -13,8 +13,8 @@
 from torch.nn import Module
 from typing import Dict, Tuple, Set, Optional
 
-
-from torchchat.distributed.dtensor_utils import is_dtensor, load_into_dtensor
+from torch.distributed._tensor import DTensor
+from torchchat.distributed.dtensor_utils import convert_to_dtensor
 
 
 _DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
@@ -284,9 +284,7 @@ def update_state_dict(
             continue
 
         checkpoint_tensor = checkpoint[old_param]
-        stage_tensor = state_dict[param]
-
-        stage_is_dtensor = is_dtensor(stage_tensor)
+        model_tensor = state_dict[param]
 
         if "wq" in param:
             checkpoint_tensor = permute_weight_to_attn_heads(
@@ -297,17 +295,16 @@
                checkpoint_tensor, num_local_heads, head_dim, dim
            )
 
-        # Move checkpoint tensor to desired device
-        checkpoint_tensor = checkpoint_tensor.to(device)
 
         # here we need to check if the tensor is a DTensor and if so, adjust the
         # shape and placement to match the model DTensor.
-        if stage_is_dtensor:
-            model_tensor = load_into_dtensor(checkpoint_tensor, stage_tensor)
-            # logger.info(f"DTensor: Loaded {param} into {model_tensor=}")
-            state_dict[param] = model_tensor
+        if isinstance(model_tensor, DTensor):
+            state_dict[param] = convert_to_dtensor(checkpoint_tensor, model_tensor)
+            count_dtensors_loaded += 1
+
         else:
             # regular tensor, just update directly
+            checkpoint_tensor = checkpoint_tensor.to(device)
             state_dict[param] = checkpoint_tensor
 
         # ensure matching dtypes
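
Putting the two utilities together: update_state_dict now branches on whether the destination entry is a DTensor (a parameter sharded by model.distribute(tp_mesh)) or a plain tensor, and only the plain-tensor path needs the explicit .to(device). A compressed sketch of that loop follows; fill_state_dict is an illustrative name, and the real function also handles checkpoint key remapping and the attention-head weight permutation shown above.

from torch.distributed._tensor import DTensor
from torchchat.distributed.dtensor_utils import convert_to_dtensor

def fill_state_dict(state_dict, checkpoint, device):
    """Replace placeholder entries in state_dict with real checkpoint weights."""
    for param, model_tensor in state_dict.items():
        checkpoint_tensor = checkpoint[param]
        if isinstance(model_tensor, DTensor):
            # Full weight from the checkpoint -> DTensor matching the model's
            # mesh and placements; sharding places it on the right device.
            state_dict[param] = convert_to_dtensor(checkpoint_tensor, model_tensor)
        else:
            # Regular tensor: just move it onto this stage's device.
            state_dict[param] = checkpoint_tensor.to(device)
    return state_dict
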