Fix all devices occupation when applying tp to torch engine by updating device map (#1172)
grimoire authored Feb 28, 2024
1 parent 9d539af commit a5ff047
Showing 1 changed file with 44 additions and 1 deletion.
lmdeploy/pytorch/engine/model_agent.py
@@ -567,6 +567,32 @@ def raise_error(self, default_error: Exception):
        raise err


def _get_model_memory_usage(model: torch.nn.Module) -> int:
    """Get the model memory usage in bytes."""
    size = 0
    for _, param in model.named_parameters():
        size += param.element_size() * param.numel()
    for _, buf in model.named_buffers():
        size += buf.element_size() * buf.numel()
    return size
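
For illustration only (not part of the commit): with a single fp16 nn.Linear, the helper simply sums element_size() * numel() over all params and buffers.

# Hypothetical usage sketch; assumes the helper above is in scope.
import torch

layer = torch.nn.Linear(4096, 4096, dtype=torch.float16)
size = _get_model_memory_usage(layer)
# weight (4096 * 4096) + bias (4096) parameters, 2 bytes each in fp16
assert size == (4096 * 4096 + 4096) * 2
print(f'{size / 2**20:.1f} MiB')  # ~32.0 MiB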


def _create_device_map(model: torch.nn.Module,
                       world_size: int,
                       device_map: dict = None):
    """Distribute params and buffers round-robin across devices."""
    if device_map is None:
        device_map = dict()
    device_id = 0
    for name, _ in model.named_parameters():
        device_map[name] = device_id
        device_id = (device_id + 1) % world_size
    for name, _ in model.named_buffers():
        device_map[name] = device_id
        device_id = (device_id + 1) % world_size
    return device_map
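
For illustration only (not part of the commit): on a two-layer toy model with world_size=2, the round-robin walk alternates parameters between device 0 and device 1.

# Hypothetical usage sketch; assumes the helper above is in scope.
import torch

toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
print(_create_device_map(toy, world_size=2))
# -> {'0.weight': 0, '0.bias': 1, '1.weight': 0, '1.bias': 1}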


def _tp_build_model(
    rank: int,
    model_path: str,
@@ -585,6 +611,17 @@ def _tp_build_model(
    patched_model = None
    cache_engine = None

    def __get_device_map(model, device_map=None):
        """Choose where to preload the model weights."""
        import psutil
        model_size = _get_model_memory_usage(model)
        if psutil.virtual_memory().available < model_size:
            # Host RAM cannot hold the weights; load them directly onto
            # the GPUs described by device_map.
            logger.debug('Preload model on GPU.')
            return device_map
        else:
            # Host RAM is large enough; stage on CPU first and dispatch
            # to the GPUs afterwards.
            logger.debug('Preload model on CPU.')
            return 'cpu'

    def __load_params_and_buffers(param_mod, mod):
        """Load params and buffers."""
        for name, param in param_mod.named_parameters(recurse=False):
@@ -629,18 +666,24 @@ def _broadcast_config(cache_config):
    try:
        config = model_config.hf_config
        torch_dtype = model_config.dtype
        device_map = None
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                config,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code)
            if rank == 0:
                device_map = _create_device_map(model, world_size)
            _add_adapters(model, adapters)
            if rank == 0:
                # adapters would remove the weights of patched linears;
                # rebuild the device map so the new params are covered.
                device_map = _create_device_map(model, world_size, device_map)
        model.eval()
        model.config.use_cache = True

        if rank == 0:
            with LoadNoInit():
-                device_map = 'auto'
+                device_map = __get_device_map(model, device_map)
                param_model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    torch_dtype=torch_dtype,
