From 837a1e30769260c400fe1f7b4b89eceb9a5c14cc Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 1 Apr 2024 16:13:06 -0700 Subject: [PATCH 1/7] WIP: default adapter --- server/lorax_server/models/flash_causal_lm.py | 6 ++++-- server/lorax_server/models/flash_gemma.py | 21 +------------------ server/lorax_server/models/flash_gpt2.py | 20 +----------------- server/lorax_server/models/flash_llama.py | 21 +------------------ server/lorax_server/models/flash_mistral.py | 21 +------------------ server/lorax_server/models/flash_mixtral.py | 21 +------------------ server/lorax_server/models/flash_phi.py | 21 +------------------ server/lorax_server/models/flash_qwen.py | 21 +------------------ server/lorax_server/models/flash_qwen2.py | 21 +------------------ server/lorax_server/models/model.py | 12 +++++++++++ server/lorax_server/server.py | 20 ++---------------- 11 files changed, 26 insertions(+), 179 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 2f53ffd6c..4ec97c5d6 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -32,6 +32,7 @@ from lorax_server.utils.graph import GraphCache from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata from lorax_server.utils.segments import SegmentConcatBuilder, find_segments +from lorax_server.utils.sources import HUB from lorax_server.utils.state import warmup_mode from lorax_server.utils.tokenizer import TokenizerManager @@ -731,7 +732,7 @@ def __init__( sliding_window: Optional[int] = None, compile: bool = False, adapter_id: str = BASE_MODEL_ADAPTER_ID, - dynamic_adapter_loading_enabled: bool = True, + adapter_source: str = HUB, ): global SLIDING_WINDOW global SLIDING_WINDOW_BLOCKS @@ -751,7 +752,8 @@ def __init__( world_size=world_size, sliding_window=sliding_window, adapter_id=adapter_id, - dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled, + adapter_source=adapter_source, + dynamic_adapter_loading_enabled=True, ) if sliding_window is not None: diff --git a/server/lorax_server/models/flash_gemma.py b/server/lorax_server/models/flash_gemma.py index c1e17aafd..0030907d9 100644 --- a/server/lorax_server/models/flash_gemma.py +++ b/server/lorax_server/models/flash_gemma.py @@ -12,7 +12,6 @@ GemmaConfig, ) from lorax_server.utils import ( - create_merged_weight_files, initialize_torch_distributed, weight_files, Weights, @@ -63,29 +62,11 @@ def __init__( torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - - # if adapter_id passed in as part of model instantiation, then we merge - # the adapter weights with the model weights. This also disables dynamic - # adapter loading, since the model is now itself initialized with an adapter. - merged_weight_filenames = None - dynamic_adapter_loading_enabled = True - if len(adapter_id) > 0: - logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") - # Need to pass the adapter source here - merged_weight_filenames = create_merged_weight_files( - adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source - ) - dynamic_adapter_loading_enabled = False - adapter_id = adapter_id - else: - adapter_id = BASE_MODEL_ADAPTER_ID - weights = Weights( filenames, device, dtype, process_group=self.process_group, - merged_weight_filenames=merged_weight_filenames ) if config.quantize in ["gptq", "awq", "eetq"]: @@ -107,7 +88,7 @@ def __init__( world_size=world_size, compile=compile, adapter_id=adapter_id, - dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled, + adapter_source=adapter_source, ) @property diff --git a/server/lorax_server/models/flash_gpt2.py b/server/lorax_server/models/flash_gpt2.py index 388452b18..cfedac509 100644 --- a/server/lorax_server/models/flash_gpt2.py +++ b/server/lorax_server/models/flash_gpt2.py @@ -20,7 +20,6 @@ ) from lorax_server.utils import ( compute_delta_weight, - create_merged_weight_files, get_start_stop_idxs_for_rank, initialize_torch_distributed, load_module_map, @@ -70,23 +69,6 @@ def __init__( torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - - # if adapter_id passed in as part of model instantiation, then we merge - # the adapter weights with the model weights. This also disables dynamic - # adapter loading, since the model is now itself initialized with an adapter. - merged_weight_filenames = None - dynamic_adapter_loading_enabled = True - if len(adapter_id) > 0: - logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") - # Need to pass the adapter source here - merged_weight_filenames = create_merged_weight_files( - adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source - ) - dynamic_adapter_loading_enabled = False - adapter_id = adapter_id - else: - adapter_id = BASE_MODEL_ADAPTER_ID - weights = Weights( filenames, device, @@ -114,7 +96,7 @@ def __init__( world_size=world_size, compile=compile, adapter_id=adapter_id, - dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled, + adapter_source=adapter_source, ) @property diff --git a/server/lorax_server/models/flash_llama.py b/server/lorax_server/models/flash_llama.py index 07632613b..82dde9199 100644 --- a/server/lorax_server/models/flash_llama.py +++ b/server/lorax_server/models/flash_llama.py @@ -13,7 +13,6 @@ LlamaConfig, ) from lorax_server.utils import ( - create_merged_weight_files, initialize_torch_distributed, weight_files, Weights, @@ -64,29 +63,11 @@ def __init__( torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - - # if adapter_id passed in as part of model instantiation, then we merge - # the adapter weights with the model weights. This also disables dynamic - # adapter loading, since the model is now itself initialized with an adapter. - merged_weight_filenames = None - dynamic_adapter_loading_enabled = True - if len(adapter_id) > 0: - logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") - # Need to pass the adapter source here - merged_weight_filenames = create_merged_weight_files( - adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source - ) - dynamic_adapter_loading_enabled = False - adapter_id = adapter_id - else: - adapter_id = BASE_MODEL_ADAPTER_ID - weights = Weights( filenames, device, dtype, process_group=self.process_group, - merged_weight_filenames=merged_weight_filenames ) if config.quantize in ["gptq", "awq", "eetq"]: @@ -108,7 +89,7 @@ def __init__( world_size=world_size, compile=compile, adapter_id=adapter_id, - dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled, + adapter_source=adapter_source, ) @property diff --git a/server/lorax_server/models/flash_mistral.py b/server/lorax_server/models/flash_mistral.py index 97df03804..b4fa228b5 100644 --- a/server/lorax_server/models/flash_mistral.py +++ b/server/lorax_server/models/flash_mistral.py @@ -12,7 +12,6 @@ MistralConfig, ) from lorax_server.utils import ( - create_merged_weight_files, initialize_torch_distributed, weight_files, Weights, @@ -61,29 +60,11 @@ def __init__( torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - - # if adapter_id passed in as part of model instantiation, then we merge - # the adapter weights with the model weights. This also disables dynamic - # adapter loading, since the model is now itself initialized with an adapter. - merged_weight_filenames = None - dynamic_adapter_loading_enabled = True - if len(adapter_id) > 0: - logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") - # Need to pass the adapter source here - merged_weight_filenames = create_merged_weight_files( - adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source - ) - dynamic_adapter_loading_enabled = False - adapter_id = adapter_id - else: - adapter_id = BASE_MODEL_ADAPTER_ID - weights = Weights( filenames, device, dtype, process_group=self.process_group, - merged_weight_filenames=merged_weight_filenames ) if config.quantize in ["gptq", "awq", "eetq"]: @@ -106,7 +87,7 @@ def __init__( sliding_window=config.sliding_window, compile=compile, adapter_id=adapter_id, - dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled, + adapter_source=adapter_source, ) @property diff --git a/server/lorax_server/models/flash_mixtral.py b/server/lorax_server/models/flash_mixtral.py index 76772ce90..5cd7babfe 100644 --- a/server/lorax_server/models/flash_mixtral.py +++ b/server/lorax_server/models/flash_mixtral.py @@ -31,7 +31,6 @@ MixtralConfig, ) from lorax_server.utils import ( - create_merged_weight_files, initialize_torch_distributed, weight_files, Weights, @@ -361,29 +360,11 @@ def __init__( torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - - # if adapter_id passed in as part of model instantiation, then we merge - # the adapter weights with the model weights. This also disables dynamic - # adapter loading, since the model is now itself initialized with an adapter. - merged_weight_filenames = None - dynamic_adapter_loading_enabled = True - if len(adapter_id) > 0: - logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") - # Need to pass the adapter source here - merged_weight_filenames = create_merged_weight_files( - adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source - ) - dynamic_adapter_loading_enabled = False - adapter_id = adapter_id - else: - adapter_id = BASE_MODEL_ADAPTER_ID - weights = Weights( filenames, device, dtype, process_group=self.process_group, - merged_weight_filenames=merged_weight_filenames ) if config.quantize in ["gptq", "awq", "eetq"]: @@ -406,7 +387,7 @@ def __init__( sliding_window=config.sliding_window, compile=compile, adapter_id=adapter_id, - dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled, + adapter_source=adapter_source, ) @property diff --git a/server/lorax_server/models/flash_phi.py b/server/lorax_server/models/flash_phi.py index ac0be9435..f7d0156d2 100644 --- a/server/lorax_server/models/flash_phi.py +++ b/server/lorax_server/models/flash_phi.py @@ -18,7 +18,6 @@ PhiConfig, ) from lorax_server.utils import ( - create_merged_weight_files, initialize_torch_distributed, weight_files, Weights, @@ -69,29 +68,11 @@ def __init__( torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - - # if adapter_id passed in as part of model instantiation, then we merge - # the adapter weights with the model weights. This also disables dynamic - # adapter loading, since the model is now itself initialized with an adapter. - merged_weight_filenames = None - dynamic_adapter_loading_enabled = True - if len(adapter_id) > 0: - logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") - # Need to pass the adapter source here - merged_weight_filenames = create_merged_weight_files( - adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source - ) - dynamic_adapter_loading_enabled = False - adapter_id = adapter_id - else: - adapter_id = BASE_MODEL_ADAPTER_ID - weights = Weights( filenames, device, dtype, process_group=self.process_group, - merged_weight_filenames=merged_weight_filenames ) if config.quantize in ["gptq", "awq", "eetq"]: @@ -114,7 +95,7 @@ def __init__( world_size=world_size, compile=compile, adapter_id=adapter_id, - dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled, + adapter_source=adapter_source, ) @property diff --git a/server/lorax_server/models/flash_qwen.py b/server/lorax_server/models/flash_qwen.py index 55e439fc6..01013b0eb 100644 --- a/server/lorax_server/models/flash_qwen.py +++ b/server/lorax_server/models/flash_qwen.py @@ -17,7 +17,6 @@ QwenConfig, ) from lorax_server.utils import ( - create_merged_weight_files, initialize_torch_distributed, weight_files, Weights, @@ -68,29 +67,11 @@ def __init__( torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - - # if adapter_id passed in as part of model instantiation, then we merge - # the adapter weights with the model weights. This also disables dynamic - # adapter loading, since the model is now itself initialized with an adapter. - merged_weight_filenames = None - dynamic_adapter_loading_enabled = True - if len(adapter_id) > 0: - logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") - # Need to pass the adapter source here - merged_weight_filenames = create_merged_weight_files( - adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source - ) - dynamic_adapter_loading_enabled = False - adapter_id = adapter_id - else: - adapter_id = BASE_MODEL_ADAPTER_ID - weights = Weights( filenames, device, dtype, process_group=self.process_group, - merged_weight_filenames=merged_weight_filenames ) if config.quantize in ["gptq", "awq", "eetq"]: @@ -113,7 +94,7 @@ def __init__( world_size=world_size, compile=compile, adapter_id=adapter_id, - dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled, + adapter_source=adapter_source, ) @property diff --git a/server/lorax_server/models/flash_qwen2.py b/server/lorax_server/models/flash_qwen2.py index 693b1e381..1b3161c41 100644 --- a/server/lorax_server/models/flash_qwen2.py +++ b/server/lorax_server/models/flash_qwen2.py @@ -19,7 +19,6 @@ FlashQwen2ForCausalLM, ) from lorax_server.utils import ( - create_merged_weight_files, initialize_torch_distributed, weight_files, Weights, @@ -69,29 +68,11 @@ def __init__( torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - - # if adapter_id passed in as part of model instantiation, then we merge - # the adapter weights with the model weights. This also disables dynamic - # adapter loading, since the model is now itself initialized with an adapter. - merged_weight_filenames = None - dynamic_adapter_loading_enabled = True - if len(adapter_id) > 0: - logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.") - # Need to pass the adapter source here - merged_weight_filenames = create_merged_weight_files( - adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source - ) - dynamic_adapter_loading_enabled = False - adapter_id = adapter_id - else: - adapter_id = BASE_MODEL_ADAPTER_ID - weights = Weights( filenames, device, dtype, process_group=self.process_group, - merged_weight_filenames=merged_weight_filenames ) if config.quantize in ["gptq", "awq", "eetq"]: @@ -116,7 +97,7 @@ def __init__( sliding_window=config.sliding_window, compile=compile, adapter_id=adapter_id, - dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled, + adapter_source=adapter_source, ) @property diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index 047bf8cd4..228b15fa8 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -7,12 +7,14 @@ from typing import Dict, List, Tuple, Optional, TypeVar, Type from transformers import PreTrainedTokenizerBase +from lorax_server.adapters.utils import download_adapter from lorax_server.models.types import Batch, GeneratedText from lorax_server.pb.generate_pb2 import AdapterParameters, AdapterSource, InfoResponse from lorax_server.utils.adapter import ( BASE_MODEL_ADAPTER_ID, load_and_merge_adapters, ) +from lorax_server.utils.sources import HUB from lorax_server.utils.tokenizer import TokenizerManager from lorax_server.adapters.weights import LayerAdapterWeights from lorax_server.utils.weights import shard_on_dim @@ -33,6 +35,7 @@ def __init__( world_size: int = 1, sliding_window: Optional[int] = None, adapter_id: str = BASE_MODEL_ADAPTER_ID, + adapter_source: str = HUB, dynamic_adapter_loading_enabled: bool = True, ): self.model_id = model_id @@ -59,6 +62,15 @@ def __init__( is not None ) + if adapter_id and adapter_id != BASE_MODEL_ADAPTER_ID: + download_adapter(adapter_id, adapter_source, api_token=None) + self.load_adapter( + AdapterParameters(adapter_ids=[adapter_id]), + adapter_source, + adapter_index=0, + api_token=None, + ) + self.check_initialized() @property diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 6aff3f372..0217d9d5b 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import List, Optional +from lorax_server.adapters.utils import download_adapter from lorax_server.cache import Cache from lorax_server.cli import _download_weights from lorax_server.interceptor import ExceptionInterceptor @@ -142,24 +143,7 @@ async def DownloadAdapter(self, request: generate_pb2.DownloadAdapterRequest, co logger.info("No adapter to download for base model. Skipping.") continue - if adapter_source == PBASE: - adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token) - adapter_source = S3 - - if adapter_source == HUB: - # Quick auth check on the repo against the token - HfApi(token=api_token).model_info(adapter_id, revision=None) - - # fail fast if ID is not an adapter (i.e. it is a full model) - source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token) - source.load_config() - - _download_weights( - adapter_id, source=adapter_source, api_token=api_token - ) - - # Calculate size of adapter to be loaded - adapter_bytes += source.get_weight_bytes() + adapter_bytes += download_adapter(adapter_id, adapter_source, api_token) adapter_memory_size = self.model.adapter_memory_size() if adapter_memory_size > 0: From 26f6d412708958ff8cf0f24930959c4daef3053c Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 1 Apr 2024 16:43:38 -0700 Subject: [PATCH 2/7] Fixed imports --- server/lorax_server/cli.py | 92 +---------- server/lorax_server/models/flash_gpt2.py | 3 - server/lorax_server/models/flash_mixtral.py | 1 + server/lorax_server/server.py | 1 - server/lorax_server/utils/__init__.py | 4 - server/lorax_server/utils/adapter.py | 170 +------------------- server/lorax_server/utils/sources/source.py | 4 +- server/lorax_server/utils/weights.py | 90 +++++++++++ 8 files changed, 98 insertions(+), 267 deletions(-) diff --git a/server/lorax_server/cli.py b/server/lorax_server/cli.py index 7ad076bbc..3a7ba4652 100644 --- a/server/lorax_server/cli.py +++ b/server/lorax_server/cli.py @@ -7,6 +7,8 @@ from typing import Optional from enum import Enum +from lorax_server.utils.weights import download_weights as _download_weights + app = typer.Typer() @@ -91,96 +93,6 @@ def serve( ) -def _download_weights( - model_id: str, - revision: Optional[str] = None, - extension: str = ".safetensors", - auto_convert: bool = True, - source: str = "hub", - api_token: Optional[str] = None, -): - # Import here after the logger is added to log potential import exceptions - from lorax_server import utils - from lorax_server.utils import sources - model_source = sources.get_model_source(source, model_id, revision, extension, api_token) - - # Test if files were already download - try: - model_source.weight_files() - logger.info("Files are already present on the host. " "Skipping download.") - return - # Local files not found - except (utils.LocalEntryNotFoundError, FileNotFoundError): - pass - - is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv( - "WEIGHTS_CACHE_OVERRIDE", None - ) is not None - - if not is_local_model: - # TODO: Combine into class that takes the source as input - # Try to download weights from the hub - try: - model_source.download_model_assets() - return - # No weights found on the hub with this extension - except utils.EntryNotFoundError as e: - # Check if we want to automatically convert to safetensors or if we can use .bin weights instead - if not extension == ".safetensors" or not auto_convert: - raise e - - # Try to see if there are local pytorch weights - try: - # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE - local_pt_files = model_source.weight_files(extension=".bin") - - # No local pytorch weights - except utils.LocalEntryNotFoundError: - if extension == ".safetensors": - logger.warning( - f"No safetensors weights found for model {model_id} at revision {revision}. " - f"Downloading PyTorch weights." - ) - - # Try to see if there are pytorch weights on the hub - pt_filenames = model_source.remote_weight_files(extension=".bin") - # Download pytorch weights - local_pt_files = model_source.download_weights(pt_filenames) - - if auto_convert: - logger.warning( - f"No safetensors weights found for model {model_id} at revision {revision}. " - f"Converting PyTorch weights to safetensors." - ) - - # Safetensors final filenames - local_st_files = [ - p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" - for p in local_pt_files - ] - try: - from transformers import AutoConfig - import transformers - - config_path = sources.get_config_path(model_id, source) - config = AutoConfig.from_pretrained( - config_path, - revision=revision, - ) - architecture = config.architectures[0] - - class_ = getattr(transformers, architecture) - - # Name for this varible depends on transformers version. - discard_names = getattr(class_, "_tied_weights_keys", []) - discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", [])) - - except Exception as e: - discard_names = [] - # Convert pytorch weights to safetensors - utils.convert_files(local_pt_files, local_st_files, discard_names) - - @app.command() def download_weights( model_id: str, diff --git a/server/lorax_server/models/flash_gpt2.py b/server/lorax_server/models/flash_gpt2.py index cfedac509..426e18fd8 100644 --- a/server/lorax_server/models/flash_gpt2.py +++ b/server/lorax_server/models/flash_gpt2.py @@ -19,10 +19,7 @@ LM_HEAD, ) from lorax_server.utils import ( - compute_delta_weight, - get_start_stop_idxs_for_rank, initialize_torch_distributed, - load_module_map, weight_files, Weights, ) diff --git a/server/lorax_server/models/flash_mixtral.py b/server/lorax_server/models/flash_mixtral.py index 5cd7babfe..6c89c2b81 100644 --- a/server/lorax_server/models/flash_mixtral.py +++ b/server/lorax_server/models/flash_mixtral.py @@ -211,6 +211,7 @@ def from_pb( max_length = max(max_length, input_length + max_new_tokens) adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) + print("!!! ADAPTER INDICES", adapter_indices) request_tokenizers = [ tokenizers.get_tokenizer(r.adapter_index, tokenizer) diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 0217d9d5b..5a66fcc14 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -14,7 +14,6 @@ from lorax_server.adapters.utils import download_adapter from lorax_server.cache import Cache -from lorax_server.cli import _download_weights from lorax_server.interceptor import ExceptionInterceptor from lorax_server.models import Model, get_model from lorax_server.pb import generate_pb2_grpc, generate_pb2 diff --git a/server/lorax_server/utils/__init__.py b/server/lorax_server/utils/__init__.py index 910ae613f..15f41a0ff 100644 --- a/server/lorax_server/utils/__init__.py +++ b/server/lorax_server/utils/__init__.py @@ -1,6 +1,4 @@ from lorax_server.utils.adapter import ( - compute_delta_weight, - create_merged_weight_files, load_module_map, ) from lorax_server.utils.convert import convert_file, convert_files @@ -33,8 +31,6 @@ ) __all__ = [ - "compute_delta_weight", - "create_merged_weight_files", "load_module_map", "convert_file", "convert_files", diff --git a/server/lorax_server/utils/adapter.py b/server/lorax_server/utils/adapter.py index 3c3b7b4c6..f73ab992f 100644 --- a/server/lorax_server/utils/adapter.py +++ b/server/lorax_server/utils/adapter.py @@ -1,24 +1,14 @@ from dataclasses import dataclass -import os -from collections import defaultdict from functools import lru_cache -from pathlib import Path -from typing import TYPE_CHECKING, List, Dict, Set, Tuple +from typing import TYPE_CHECKING, Set, Tuple import warnings -import torch -from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE -from loguru import logger -from peft.utils import transpose -from safetensors.torch import load_file, save_file +from safetensors.torch import load_file from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer -from tqdm import tqdm -from filelock import FileLock from lorax_server.pb import generate_pb2 -from lorax_server.utils.sources import get_model_source, get_config_path, weight_files +from lorax_server.utils.sources import get_model_source, get_config_path from lorax_server.utils.merges.strategies import merge_adapters -from lorax_server.adapters.lora import get_scaling_factor if TYPE_CHECKING: from lorax_server.adapters.config import AdapterConfig, ModuleMap @@ -145,157 +135,3 @@ def load_module_map( # map the model weights to the relevant adapter weights (LoRA A and B matrices) module_map, adapter_weight_names = adapter_config.map_weights_for_model(adapter_weights, weight_names) return module_map, adapter_config, adapter_weight_names, adapter_tokenizer - - -def compute_delta_weight( - lora_A: torch.Tensor, - lora_B: torch.Tensor, - fan_in_fan_out: bool, - alpha: float, - r: float, - uses_rslora: bool = False -) -> torch.Tensor: - """Computes the delta weight for a Linear layer given A and B LoRA matrices. - - TODO: add logic for other module types beyond Linear layers. - - Reference: https://github.com/huggingface/peft/blob/v0.4.0/src/peft/tuners/lora.py#L799-L806 - """ - scaling = get_scaling_factor(alpha, r, uses_rslora=uses_rslora) - delta_weight = transpose(lora_B @ lora_A, fan_in_fan_out) * scaling - return delta_weight - - -def merge_adapter_weights( - model_weights: Dict[str, torch.Tensor], - adapter_weights: Dict[str, torch.Tensor], - adapter_config: "AdapterConfig" -) -> Tuple[Dict[str, torch.Tensor], Set[str]]: - """ - Merges the adapter weights into the model weights. - - Args: - model_weights (Dict[str, torch.Tensor]): The weights of the base model. - adapter_weights (Dict[str, torch.Tensor]): The weights of the adapters. - adapter_config (AdapterConfig): The configuration for the adapter. - - Returns: - Tuple[Dict[str, torch.Tensor], Set[str]]: A tuple containing the merged weights and the set of processed adapter weight names. - """ - from lorax_server.adapters.lora import LoraConfig - - if not isinstance(adapter_config, LoraConfig): - raise ValueError(f"Unsupported adapter config type: {type(adapter_config)}") - - module_mapping = defaultdict(dict) - processed_adapter_weight_names = set() - - # map the original tensor names to their adapter counterparts - for weight_name in model_weights: - end_idx = weight_name.rfind(".weight") - key = weight_name[:end_idx] - for adapter_weight_name in adapter_weights: - if key in adapter_weight_name: - # example value: 'base_model.model.model.layers.10.self_attn.v_proj.lora_B.weight' - # matrix_type gets the second to last element in the module name, i.e. 'lora_B' - matrix_type = adapter_weight_name.split(".")[-2] - module_mapping[weight_name][matrix_type] = adapter_weight_name - processed_adapter_weight_names.add(adapter_weight_name) - - # merge adapter weights into model weights - merged_weights = {} - for weight_name, adapter_weight_names in tqdm( - module_mapping.items(), desc="Merging adapter weights", total=len(module_mapping)): - - # TODO: support adapter types beyond LoRA - # TODO: put this on GPU if it is available. This should greatly speedup compute_delta_weight - lora_A = adapter_weights[adapter_weight_names["lora_A"]] - lora_B = adapter_weights[adapter_weight_names["lora_B"]] - delta_weight = compute_delta_weight( - lora_A, - lora_B, - adapter_config.fan_in_fan_out, - adapter_config.lora_alpha, - adapter_config.r, - uses_rslora=adapter_config.use_rslora, - ) - - # transpose delta weight if necessary - # TODO(geoffrey): I believe this is required when using Conv1D layers (gpt2). - # We can likely take this out once we've switched to using Linear layers. - if (delta_weight.shape != model_weights[weight_name].shape and - delta_weight.T.shape == model_weights[weight_name].shape): - delta_weight = delta_weight.T - merged_weights[weight_name] = model_weights[weight_name] + delta_weight - return merged_weights, processed_adapter_weight_names - - -def create_merged_weight_files( - adapter_id: str, - model_id: str, - model_weight_filenames: List[Path], - adapter_source: str = "hub", -) -> List[Path]: - """Creates merged weight files for the given adapter ID and filenames.""" - api_token = None # TODO(travis): add support for API token - source = get_model_source(adapter_source, adapter_id, api_token=api_token) - adapter_filenames = source.weight_files() - - adapter_config = source.load_config() - if adapter_config.base_model_name_or_path != model_id: - expected_config = AutoConfig.from_pretrained(model_id) - model_config = AutoConfig.from_pretrained(adapter_config.base_model_name_or_path) - if model_config.architectures == expected_config.architectures: - warnings.warn( - f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. " - f"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead." - ) - else: - # TODO(travis): revisit this when we support clasification heads which will not use CausalLM - raise ValueError(f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. " - f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. " - f"Use --model-id '{adapter_config.base_model_name_or_path}' instead.") - - # load adapter weights from all shards (should have relatively small memory footprint) - adapter_weights = {} - for filename in adapter_filenames: - adapter_weights.update(load_file(filename)) - remaining_adapter_weight_names = set(adapter_weights.keys()) - - merged_weight_directory = Path(HUGGINGFACE_HUB_CACHE) / f"models--{adapter_id.replace('/', '--')}-merged" - # just grab the existing files if they already exist and return immediately - lock = FileLock(str(merged_weight_directory)+ ".lock") - with lock: - if merged_weight_directory.is_dir(): - logger.info(f"Merged weight directory {merged_weight_directory} exist, skipping merge computation.") - return weight_files(merged_weight_directory) - else: - logger.info("Merged weight files do not exist, computing merge.") - os.makedirs(merged_weight_directory) - - merged_weight_filenames = [] - for i, filename in enumerate(model_weight_filenames): - logger.info( - f"Merging adapter weights into model weights in " - f"{filename} ({i+1} / {len(model_weight_filenames)})..." - ) - model_weights = load_file(filename) - merged_weights, processed_adapter_weight_names = merge_adapter_weights( - model_weights, adapter_weights, adapter_config) - - merged_adapter_filename = Path(merged_weight_directory, os.path.basename(filename)) - save_file(merged_weights, merged_adapter_filename) - logger.debug(f"Saved merged weights into {merged_adapter_filename}") - - merged_weight_filenames.append(merged_adapter_filename) - remaining_adapter_weight_names = remaining_adapter_weight_names.difference( - processed_adapter_weight_names) - - if len(remaining_adapter_weight_names) > 0: - logger.warning("WARNING: The following lora weights were not merged into the model weights:") - for lora_name in remaining_adapter_weight_names: - logger.warning("\t" + lora_name) - - logger.info( - f"Finished merging adapter weights. Merged weight files saved to: {merged_weight_directory}") - return merged_weight_filenames diff --git a/server/lorax_server/utils/sources/source.py b/server/lorax_server/utils/sources/source.py index 14867d97f..4ce1081dc 100644 --- a/server/lorax_server/utils/sources/source.py +++ b/server/lorax_server/utils/sources/source.py @@ -4,7 +4,6 @@ from typing import Optional, List from pathlib import Path -from lorax_server.adapters import load_adapter_config from lorax_server.adapters.config import AdapterConfig @@ -132,7 +131,8 @@ def get_weight_bytes(self) -> int: return total_size def load_config(self) -> AdapterConfig: + from lorax_server.adapters import load_adapter_config + config_path = self.download_file("config.json", ignore_errors=True) adapter_config_path = self.download_file("adapter_config.json", ignore_errors=True) return load_adapter_config(config_path, adapter_config_path, self.api_token) - diff --git a/server/lorax_server/utils/weights.py b/server/lorax_server/utils/weights.py index b786751b8..fa6ab19a5 100644 --- a/server/lorax_server/utils/weights.py +++ b/server/lorax_server/utils/weights.py @@ -365,3 +365,93 @@ def shard_on_dim(t: torch.Tensor, dim: int, process_group: torch.distributed.Pro raise NotImplementedError("Let's make that generic when needed") return tensor + + +def download_weights( + model_id: str, + revision: Optional[str] = None, + extension: str = ".safetensors", + auto_convert: bool = True, + source: str = "hub", + api_token: Optional[str] = None, +): + # Import here after the logger is added to log potential import exceptions + from lorax_server import utils + from lorax_server.utils import sources + model_source = sources.get_model_source(source, model_id, revision, extension, api_token) + + # Test if files were already download + try: + model_source.weight_files() + logger.info("Files are already present on the host. " "Skipping download.") + return + # Local files not found + except (utils.LocalEntryNotFoundError, FileNotFoundError): + pass + + is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv( + "WEIGHTS_CACHE_OVERRIDE", None + ) is not None + + if not is_local_model: + # TODO: Combine into class that takes the source as input + # Try to download weights from the hub + try: + model_source.download_model_assets() + return + # No weights found on the hub with this extension + except utils.EntryNotFoundError as e: + # Check if we want to automatically convert to safetensors or if we can use .bin weights instead + if not extension == ".safetensors" or not auto_convert: + raise e + + # Try to see if there are local pytorch weights + try: + # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE + local_pt_files = model_source.weight_files(extension=".bin") + + # No local pytorch weights + except utils.LocalEntryNotFoundError: + if extension == ".safetensors": + logger.warning( + f"No safetensors weights found for model {model_id} at revision {revision}. " + f"Downloading PyTorch weights." + ) + + # Try to see if there are pytorch weights on the hub + pt_filenames = model_source.remote_weight_files(extension=".bin") + # Download pytorch weights + local_pt_files = model_source.download_weights(pt_filenames) + + if auto_convert: + logger.warning( + f"No safetensors weights found for model {model_id} at revision {revision}. " + f"Converting PyTorch weights to safetensors." + ) + + # Safetensors final filenames + local_st_files = [ + p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" + for p in local_pt_files + ] + try: + from transformers import AutoConfig + import transformers + + config_path = sources.get_config_path(model_id, source) + config = AutoConfig.from_pretrained( + config_path, + revision=revision, + ) + architecture = config.architectures[0] + + class_ = getattr(transformers, architecture) + + # Name for this varible depends on transformers version. + discard_names = getattr(class_, "_tied_weights_keys", []) + discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", [])) + + except Exception as e: + discard_names = [] + # Convert pytorch weights to safetensors + utils.convert_files(local_pt_files, local_st_files, discard_names) From 74d002e6a9f44d47c8baedc3867bbdb534db2e05 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 1 Apr 2024 16:55:35 -0700 Subject: [PATCH 3/7] Fix base model --- router/src/infer.rs | 4 ++-- server/lorax_server/models/flash_causal_lm.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index c950e311f..6bfd6295f 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -1,5 +1,5 @@ /// Batching and inference logic -use crate::adapter::{extract_adapter_params, Adapter}; +use crate::adapter::{extract_adapter_params, Adapter, BASE_MODEL_ADAPTER_ID}; use crate::queue::AdapterEvent; use crate::scheduler::AdapterScheduler; use crate::validation::{Validation, ValidationError}; @@ -71,7 +71,7 @@ impl Infer { // Initialize with base model adapter (empty) mapping to index 0 let adapter_to_index = Arc::new(Mutex::new(HashMap::from([( AdapterParameters { - adapter_ids: vec!["".to_string()], + adapter_ids: vec![BASE_MODEL_ADAPTER_ID.to_string()], ..Default::default() }, 0, diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 4ec97c5d6..1de051999 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -271,6 +271,7 @@ def from_pb( max_length = max(max_length, input_length + max_new_tokens) adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) + print("!!! ADAPTER INDICES", adapter_indices) request_tokenizers = [ tokenizers.get_tokenizer(r.adapter_index, tokenizer) From 5a8687a75ca89e8e96ca2475f37296a30444643b Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 1 Apr 2024 17:01:12 -0700 Subject: [PATCH 4/7] No debug --- server/lorax_server/models/flash_causal_lm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 1de051999..4ec97c5d6 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -271,7 +271,6 @@ def from_pb( max_length = max(max_length, input_length + max_new_tokens) adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) - print("!!! ADAPTER INDICES", adapter_indices) request_tokenizers = [ tokenizers.get_tokenizer(r.adapter_index, tokenizer) From 24409d1463f66642128d98ef9344b616a6d96787 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 1 Apr 2024 21:40:32 -0700 Subject: [PATCH 5/7] Temp disable flake8 --- .github/workflows/server_tests.yaml | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/workflows/server_tests.yaml b/.github/workflows/server_tests.yaml index 5fdc2b1b2..78498bc01 100644 --- a/.github/workflows/server_tests.yaml +++ b/.github/workflows/server_tests.yaml @@ -31,12 +31,13 @@ jobs: echo "files=$(git diff --name-only --diff-filter=ACMRT ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -E '*.py$' | tr '\n' ' ')" echo "files=$(git diff --name-only --diff-filter=ACMRT ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -E '*.py$' | tr '\n' ' ')" >> $GITHUB_OUTPUT - - name: Run flake8 on changed files - if: steps.changed_files.outputs.files != '' - run: | - pip install flake8 - echo running linter on: ${{ steps.changed_files.outputs.files }} - flake8 ${{ steps.changed_files.outputs.files }} + # TODO(travis): reenable after running this on the entire codebase + # - name: Run flake8 on changed files + # if: steps.changed_files.outputs.files != '' + # run: | + # pip install flake8 + # echo running linter on: ${{ steps.changed_files.outputs.files }} + # flake8 ${{ steps.changed_files.outputs.files }} - name: Install Protoc uses: arduino/setup-protoc@v1 From 514eae623e158573aa6d89061c2d915934657ded Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 1 Apr 2024 21:46:39 -0700 Subject: [PATCH 6/7] Added missing file --- server/lorax_server/adapters/utils.py | 31 +++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 server/lorax_server/adapters/utils.py diff --git a/server/lorax_server/adapters/utils.py b/server/lorax_server/adapters/utils.py new file mode 100644 index 000000000..88aea4a17 --- /dev/null +++ b/server/lorax_server/adapters/utils.py @@ -0,0 +1,31 @@ +from typing import Optional + +from huggingface_hub import HfApi + +from lorax_server.utils.sources import HUB, PBASE, S3, get_model_source, map_pbase_model_id_to_s3 +from lorax_server.utils.weights import download_weights + + +def download_adapter( + adapter_id: str, + adapter_source: str, + api_token: Optional[str] = None, +) -> int: + if adapter_source == PBASE: + adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token) + adapter_source = S3 + + if adapter_source == HUB: + # Quick auth check on the repo against the token + HfApi(token=api_token).model_info(adapter_id, revision=None) + + # fail fast if ID is not an adapter (i.e. it is a full model) + source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token) + source.load_config() + + download_weights( + adapter_id, source=adapter_source, api_token=api_token + ) + + # Calculate size of adapter to be loaded + return source.get_weight_bytes() From 957d17eac18b58e4c9056a50d611a635196ba12f Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 1 Apr 2024 21:51:57 -0700 Subject: [PATCH 7/7] Removed unused tests --- server/tests/utils/test_adapter.py | 44 ------------------------------ 1 file changed, 44 deletions(-) delete mode 100644 server/tests/utils/test_adapter.py diff --git a/server/tests/utils/test_adapter.py b/server/tests/utils/test_adapter.py deleted file mode 100644 index c7b2f48b9..000000000 --- a/server/tests/utils/test_adapter.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch - -from lorax_server.adapters.lora import LoraConfig -from lorax_server.utils.adapter import merge_adapter_weights - - -def test_merge_adapter_weights(): - W_0 = torch.tensor([ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9] - ]) - model_weights = { - "model.layers.10.self_attn.q_proj.weight": W_0 - } - - A = torch.tensor([ - [1, 2, 3], - [4, 5, 6] - ]) - B = torch.tensor([ - [1, 2], - [3, 4], - [5, 6] - ]) - adapter_weights = { - "base_model.model.model.layers.10.self_attn.q_proj.lora_A.weight": A, - "base_model.model.model.layers.10.self_attn.q_proj.lora_B.weight": B - } - - W_expected = torch.tensor([ - [ 5.5000, 8.0000, 10.5000], - [13.5000, 18.0000, 22.5000], - [21.5000, 28.0000, 34.5000] - ]) - adapter_config = LoraConfig(base_model_name_or_path="", r=2, target_modules=None, lora_alpha=1, fan_in_fan_out=False, use_rslora=False) - merged_weights, processed_adapter_weight_names = merge_adapter_weights(model_weights, adapter_weights, adapter_config) - - assert len(merged_weights) == 1 - assert merged_weights["model.layers.10.self_attn.q_proj.weight"].equal(W_expected) - - assert len(processed_adapter_weight_names) == 2 - assert "base_model.model.model.layers.10.self_attn.q_proj.lora_A.weight" in processed_adapter_weight_names - assert "base_model.model.model.layers.10.self_attn.q_proj.lora_B.weight" in processed_adapter_weight_names \ No newline at end of file