
Commit 2e47e77
feat: support lazy loading the lora module for reducing the loading p…
thincal authored Aug 2, 2024
1 parent e594ce0 commit 2e47e77
Showing 3 changed files with 31 additions and 6 deletions.
server/lorax_server/adapters/lora.py (5 changes: 3 additions & 2 deletions)
@@ -17,6 +17,7 @@
     pad_rank,
     use_cutlass_shrink,
 )
+from lorax_server.utils.weights import load_module_weight
 
 if TYPE_CHECKING:
     from lorax_server.models.model import Model

@@ -166,10 +167,10 @@ def load(
             return None
 
         lora_a, lora_a_name = module_map[weight_name]["lora_A"]
-        lora_a = lora_a.to(base_device, model.dtype)
+        lora_a = load_module_weight(lora_a_name, lora_a, base_device, model.dtype)
 
         lora_b, lora_b_name = module_map[weight_name]["lora_B"]
-        lora_b = lora_b.to(base_device, model.dtype)
+        lora_b = load_module_weight(lora_b_name, lora_b, base_device, model.dtype)
 
         scale = get_scaling_factor(
             config.lora_alpha,
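For context, a minimal sketch of the call-site contract this change relies on (the module_map contents, device, and dtype below are illustrative assumptions, not the repository's actual data): after this commit, each entry holds either an in-memory tensor (eager path) or a safetensors filename (lazy path), and load_module_weight resolves both.

    import torch

    from lorax_server.utils.weights import load_module_weight

    # Hypothetical module_map entry; each value is (tensor_or_filename, tensor_name).
    module_map = {
        "model.layers.0.self_attn.q_proj": {
            # eager path: the tensor was already materialized by load_file
            "lora_A": (torch.randn(8, 4096),
                       "base_model.model.layers.0.self_attn.q_proj.lora_A.weight"),
            # lazy path: only the safetensors filename was recorded
            "lora_B": ("adapter_model.safetensors",
                       "base_model.model.layers.0.self_attn.q_proj.lora_B.weight"),
        }
    }

    lora_a, lora_a_name = module_map["model.layers.0.self_attn.q_proj"]["lora_A"]
    # One call covers both cases: a Tensor is moved and cast, a str is read from disk.
    lora_a = load_module_weight(lora_a_name, lora_a, torch.device("cpu"), torch.float16)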
server/lorax_server/utils/adapter.py (13 changes: 12 additions & 1 deletion)
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING, Set, Tuple
 
 from loguru import logger
+from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

@@ -78,6 +79,7 @@ def _load_and_merge(
            weight_names,
            api_token,
            trust_remote_code,
+           False,
        )
 
        adapters_to_merge.append((module_map, adapter_config))

@@ -136,6 +138,7 @@ def load_module_map(
     weight_names: Tuple[str],
     api_token: str,
     trust_remote_code: bool = False,
+    lazy_load_weights: bool = True,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
     # TODO(geoffrey): refactor this and merge parts of this function with
     # lorax_server/utils/adapter.py::create_merged_weight_files

@@ -157,7 +160,15 @@
     adapter_filenames = source.weight_files()
     adapter_weights = {}
     for filename in adapter_filenames:
-        adapter_weights.update(load_file(filename))
+        if lazy_load_weights:
+            result = {}
+            # fetch only the tensor names from the file; the weights load on demand later
+            with safe_open(filename, framework="pt") as f:
+                for k in f.keys():
+                    result[k] = filename
+            adapter_weights.update(result)
+        else:
+            adapter_weights.update(load_file(filename))
 
     # map the model weights to the relevant adapter weights (LoRA A and B matrices)
     module_map, adapter_weight_names = adapter_config.map_weights_for_model(adapter_weights, weight_names)
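As a standalone illustration of the lazy path added above (the adapter filename is an assumption for the sketch): safe_open reads only the safetensors header, so building the name-to-filename index touches no tensor data, while the eager load_file call materializes every tensor up front.

    from safetensors import safe_open
    from safetensors.torch import load_file

    filename = "adapter_model.safetensors"  # hypothetical adapter shard on disk

    # Lazy index: tensor name -> containing file, via a header-only read.
    lazy_index = {}
    with safe_open(filename, framework="pt") as f:
        for k in f.keys():
            lazy_index[k] = filename

    # Eager equivalent: loads every tensor into host memory immediately.
    eager_weights = load_file(filename)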
server/lorax_server/utils/weights.py (19 changes: 16 additions & 3 deletions)
@@ -267,9 +267,7 @@ def get_slice(self, tensor_name: str) -> torch.Tensor:
 
     def get_tensor(self, tensor_name: str) -> torch.Tensor:
         tensor = self.weights[tensor_name]
-        tensor = tensor.to(dtype=self.dtype)
-        tensor = tensor.to(device=self.device)
-        return tensor
+        return load_module_weight(tensor_name, tensor, self.device, self.dtype)
 
     def get_slice_shape(self, slice) -> torch.Size:
         return slice.shape

@@ -542,3 +540,18 @@ def download_weights(
         discard_names = []
     # Convert pytorch weights to safetensors
     utils.convert_files(local_pt_files, local_st_files, discard_names)
+
+
+def load_module_weight(name: str, module: Union[torch.Tensor, str], device, dtype):
+    if isinstance(module, torch.Tensor):
+        return module.to(device, dtype)
+
+    if isinstance(device, torch.device):
+        if device.type == "cuda":
+            device = device.index
+        elif device.type == "cpu":
+            device = "cpu"
+
+    # if lazy loading ran earlier, module holds only the filename; read the tensor from disk now
+    with safe_open(module, framework="pt", device=device) as f:
+        return f.get_tensor(name).to(dtype)
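A short usage sketch of the new helper (tensor name, file path, device, and dtype are illustrative assumptions): both representations of a weight go through the same call, and the lazy branch normalizes torch.device to the index-or-"cpu" form safe_open expects before reading a single tensor from disk.

    import torch

    from lorax_server.utils.weights import load_module_weight

    # Eager: the value is already a Tensor, so it is only moved and cast.
    w = torch.zeros(16, 16)
    w = load_module_weight("any_name", w, torch.device("cpu"), torch.float16)

    # Lazy: the value is a filename recorded by load_module_map; only this
    # one tensor is read from the file, directly onto the target device.
    w = load_module_weight(
        "base_model.model.layers.0.self_attn.q_proj.lora_A.weight",
        "adapter_model.safetensors",
        torch.device("cpu"),
        torch.float16,
    )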
