Initializing server with an adapter sets it as the default #370

Merged · 7 commits · Apr 2, 2024
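In short: when the server is initialized with an adapter (via adapter_id), that adapter is now registered as the default adapter, index 0, that requests fall back to, instead of being merged into the base model weights; as a consequence, dynamic adapter loading stays enabled even when a default adapter is set. A minimal sketch of the intended behavior from the client side, assuming the LoRAX Python client (the endpoint, launcher flag, and adapter IDs below are illustrative):

```python
# Sketch only: assumes a LoRAX server started with a default adapter
# (e.g. via an --adapter-id launcher flag); IDs are illustrative.
from lorax import Client

client = Client("http://127.0.0.1:8080")

# No adapter_id in the request: with this PR, the request is served by the
# default adapter the server was initialized with (previously that adapter
# would have been merged into the base weights instead).
print(client.generate("What is deep learning?").generated_text)

# Dynamic loading stays enabled, so other adapters can still be requested.
print(
    client.generate(
        "What is deep learning?", adapter_id="some-org/another-adapter"
    ).generated_text
)
```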
13 changes: 7 additions & 6 deletions .github/workflows/server_tests.yaml
@@ -31,12 +31,13 @@ jobs:
echo "files=$(git diff --name-only --diff-filter=ACMRT ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -E '*.py$' | tr '\n' ' ')"
echo "files=$(git diff --name-only --diff-filter=ACMRT ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -E '*.py$' | tr '\n' ' ')" >> $GITHUB_OUTPUT

- name: Run flake8 on changed files
if: steps.changed_files.outputs.files != ''
run: |
pip install flake8
echo running linter on: ${{ steps.changed_files.outputs.files }}
flake8 ${{ steps.changed_files.outputs.files }}
# TODO(travis): reenable after running this on the entire codebase
# - name: Run flake8 on changed files
# if: steps.changed_files.outputs.files != ''
# run: |
# pip install flake8
# echo running linter on: ${{ steps.changed_files.outputs.files }}
# flake8 ${{ steps.changed_files.outputs.files }}

- name: Install Protoc
uses: arduino/setup-protoc@v1
4 changes: 2 additions & 2 deletions router/src/infer.rs
@@ -1,5 +1,5 @@
 /// Batching and inference logic
-use crate::adapter::{extract_adapter_params, Adapter};
+use crate::adapter::{extract_adapter_params, Adapter, BASE_MODEL_ADAPTER_ID};
 use crate::queue::AdapterEvent;
 use crate::scheduler::AdapterScheduler;
 use crate::validation::{Validation, ValidationError};
@@ -71,7 +71,7 @@ impl Infer {
         // Initialize with base model adapter (empty) mapping to index 0
         let adapter_to_index = Arc::new(Mutex::new(HashMap::from([(
             AdapterParameters {
-                adapter_ids: vec!["".to_string()],
+                adapter_ids: vec![BASE_MODEL_ADAPTER_ID.to_string()],
                 ..Default::default()
             },
             0,
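For intuition: the router seeds its adapter-to-index map so that index 0 always corresponds to the adapter the server was initialized with, and this change swaps the bare "" literal for the named BASE_MODEL_ADAPTER_ID constant. A rough Python analogy of that bookkeeping (illustrative only, not the router's actual code):

```python
# Illustrative analogy of the router's adapter_to_index map: index 0 is
# reserved at startup; later adapters get fresh indices on first use.
# BASE_MODEL_ADAPTER_ID is assumed here to be the empty string, per the
# "base model adapter (empty)" comment in the diff above.
BASE_MODEL_ADAPTER_ID = ""

adapter_to_index = {BASE_MODEL_ADAPTER_ID: 0}

def index_for(adapter_id: str) -> int:
    # Assign the next free index the first time an adapter is seen.
    return adapter_to_index.setdefault(adapter_id, len(adapter_to_index))

assert index_for(BASE_MODEL_ADAPTER_ID) == 0   # default / base adapter
assert index_for("my-org/my-adapter") == 1     # dynamically loaded adapter
```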
31 changes: 31 additions & 0 deletions server/lorax_server/adapters/utils.py
@@ -0,0 +1,31 @@
+from typing import Optional
+
+from huggingface_hub import HfApi
+
+from lorax_server.utils.sources import HUB, PBASE, S3, get_model_source, map_pbase_model_id_to_s3
+from lorax_server.utils.weights import download_weights
+
+
+def download_adapter(
+    adapter_id: str,
+    adapter_source: str,
+    api_token: Optional[str] = None,
+) -> int:
+    if adapter_source == PBASE:
+        adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token)
+        adapter_source = S3
+
+    if adapter_source == HUB:
+        # Quick auth check on the repo against the token
+        HfApi(token=api_token).model_info(adapter_id, revision=None)
+
+    # fail fast if ID is not an adapter (i.e. it is a full model)
+    source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token)
+    source.load_config()
+
+    download_weights(
+        adapter_id, source=adapter_source, api_token=api_token
+    )
+
+    # Calculate size of adapter to be loaded
+    return source.get_weight_bytes()
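A hedged sketch of how this new helper might be called (the adapter ID is a placeholder; the HUB constant comes from the same sources module the helper itself uses):

```python
# Sketch only: pre-downloads a LoRA adapter and reports its size.
# "my-org/my-lora-adapter" is illustrative, not a real repo.
from lorax_server.adapters.utils import download_adapter
from lorax_server.utils.sources import HUB

size_bytes = download_adapter("my-org/my-lora-adapter", HUB, api_token=None)
print(f"adapter weights occupy {size_bytes / 2**20:.1f} MiB")
```

Note the fail-fast behavior: because source.load_config() runs before any weights are downloaded, passing a full model ID where an adapter ID is expected errors out early.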
92 changes: 2 additions & 90 deletions server/lorax_server/cli.py
@@ -7,6 +7,8 @@
 from typing import Optional
 from enum import Enum

+from lorax_server.utils.weights import download_weights as _download_weights
+

 app = typer.Typer()

@@ -91,96 +93,6 @@ def serve(
     )


-def _download_weights(
-    model_id: str,
-    revision: Optional[str] = None,
-    extension: str = ".safetensors",
-    auto_convert: bool = True,
-    source: str = "hub",
-    api_token: Optional[str] = None,
-):
-    # Import here after the logger is added to log potential import exceptions
-    from lorax_server import utils
-    from lorax_server.utils import sources
-    model_source = sources.get_model_source(source, model_id, revision, extension, api_token)
-
-    # Test if files were already download
-    try:
-        model_source.weight_files()
-        logger.info("Files are already present on the host. " "Skipping download.")
-        return
-    # Local files not found
-    except (utils.LocalEntryNotFoundError, FileNotFoundError):
-        pass
-
-    is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
-        "WEIGHTS_CACHE_OVERRIDE", None
-    ) is not None
-
-    if not is_local_model:
-        # TODO: Combine into class that takes the source as input
-        # Try to download weights from the hub
-        try:
-            model_source.download_model_assets()
-            return
-        # No weights found on the hub with this extension
-        except utils.EntryNotFoundError as e:
-            # Check if we want to automatically convert to safetensors or if we can use .bin weights instead
-            if not extension == ".safetensors" or not auto_convert:
-                raise e
-
-    # Try to see if there are local pytorch weights
-    try:
-        # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
-        local_pt_files = model_source.weight_files(extension=".bin")
-
-    # No local pytorch weights
-    except utils.LocalEntryNotFoundError:
-        if extension == ".safetensors":
-            logger.warning(
-                f"No safetensors weights found for model {model_id} at revision {revision}. "
-                f"Downloading PyTorch weights."
-            )
-
-        # Try to see if there are pytorch weights on the hub
-        pt_filenames = model_source.remote_weight_files(extension=".bin")
-        # Download pytorch weights
-        local_pt_files = model_source.download_weights(pt_filenames)
-
-    if auto_convert:
-        logger.warning(
-            f"No safetensors weights found for model {model_id} at revision {revision}. "
-            f"Converting PyTorch weights to safetensors."
-        )
-
-        # Safetensors final filenames
-        local_st_files = [
-            p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
-            for p in local_pt_files
-        ]
-        try:
-            from transformers import AutoConfig
-            import transformers
-
-            config_path = sources.get_config_path(model_id, source)
-            config = AutoConfig.from_pretrained(
-                config_path,
-                revision=revision,
-            )
-            architecture = config.architectures[0]
-
-            class_ = getattr(transformers, architecture)
-
-            # Name for this varible depends on transformers version.
-            discard_names = getattr(class_, "_tied_weights_keys", [])
-            discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []))
-
-        except Exception as e:
-            discard_names = []
-        # Convert pytorch weights to safetensors
-        utils.convert_files(local_pt_files, local_st_files, discard_names)
-
-
 @app.command()
 def download_weights(
     model_id: str,
Expand Down
6 changes: 4 additions & 2 deletions server/lorax_server/models/flash_causal_lm.py
@@ -32,6 +32,7 @@
 from lorax_server.utils.graph import GraphCache
 from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata
 from lorax_server.utils.segments import SegmentConcatBuilder, find_segments
+from lorax_server.utils.sources import HUB
 from lorax_server.utils.state import warmup_mode
 from lorax_server.utils.tokenizer import TokenizerManager

@@ -731,7 +732,7 @@ def __init__(
         sliding_window: Optional[int] = None,
         compile: bool = False,
         adapter_id: str = BASE_MODEL_ADAPTER_ID,
-        dynamic_adapter_loading_enabled: bool = True,
+        adapter_source: str = HUB,
     ):
         global SLIDING_WINDOW
         global SLIDING_WINDOW_BLOCKS
Expand All @@ -751,7 +752,8 @@ def __init__(
world_size=world_size,
sliding_window=sliding_window,
adapter_id=adapter_id,
dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
adapter_source=adapter_source,
dynamic_adapter_loading_enabled=True,
)

if sliding_window is not None:
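The same simplification repeats across the concrete model classes below (flash_gemma, flash_gpt2, flash_llama, flash_mistral): the merge-on-init path is deleted and each class now simply forwards adapter_source. A condensed, self-contained sketch of the resulting constructor plumbing (class names mirror the diff; the constant values and bodies are illustrative assumptions):

```python
# Condensed sketch of the refactor: subclasses no longer merge adapter
# weights into the base model or disable dynamic loading; they forward
# adapter_source, and dynamic loading is always left enabled.
BASE_MODEL_ADAPTER_ID = ""  # assumption: empty ID means "no adapter"
HUB = "hub"                 # assumption: mirrors lorax_server.utils.sources.HUB

class FlashCausalLMSketch:
    def __init__(
        self,
        adapter_id: str = BASE_MODEL_ADAPTER_ID,
        adapter_source: str = HUB,
        dynamic_adapter_loading_enabled: bool = True,
    ):
        self.adapter_id = adapter_id
        self.adapter_source = adapter_source
        self.dynamic_adapter_loading_enabled = dynamic_adapter_loading_enabled

class FlashLlamaSketch(FlashCausalLMSketch):
    def __init__(self, adapter_id: str = BASE_MODEL_ADAPTER_ID, adapter_source: str = HUB):
        # Before this PR: merged adapter weights here when adapter_id was set,
        # then passed dynamic_adapter_loading_enabled=False up the chain.
        super().__init__(
            adapter_id=adapter_id,
            adapter_source=adapter_source,
            dynamic_adapter_loading_enabled=True,
        )

model = FlashLlamaSketch(adapter_id="my-org/my-adapter")
assert model.dynamic_adapter_loading_enabled  # stays enabled with a default adapter
```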
21 changes: 1 addition & 20 deletions server/lorax_server/models/flash_gemma.py
@@ -12,7 +12,6 @@
     GemmaConfig,
 )
 from lorax_server.utils import (
-    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -63,29 +62,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)

         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
             dtype,
             process_group=self.process_group,
-            merged_weight_filenames=merged_weight_filenames
         )

         if config.quantize in ["gptq", "awq", "eetq"]:
@@ -107,7 +88,7 @@ def __init__(
             world_size=world_size,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )

     @property
23 changes: 1 addition & 22 deletions server/lorax_server/models/flash_gpt2.py
@@ -19,11 +19,7 @@
     LM_HEAD,
 )
 from lorax_server.utils import (
-    compute_delta_weight,
-    create_merged_weight_files,
-    get_start_stop_idxs_for_rank,
     initialize_torch_distributed,
-    load_module_map,
     weight_files,
     Weights,
 )
@@ -70,23 +66,6 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)

         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
@@ -114,7 +93,7 @@ def __init__(
             world_size=world_size,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )

     @property
21 changes: 1 addition & 20 deletions server/lorax_server/models/flash_llama.py
@@ -13,7 +13,6 @@
     LlamaConfig,
 )
 from lorax_server.utils import (
-    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -64,29 +63,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)

         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
             dtype,
             process_group=self.process_group,
-            merged_weight_filenames=merged_weight_filenames
         )

         if config.quantize in ["gptq", "awq", "eetq"]:
@@ -108,7 +89,7 @@
             world_size=world_size,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )

     @property
21 changes: 1 addition & 20 deletions server/lorax_server/models/flash_mistral.py
@@ -12,7 +12,6 @@
     MistralConfig,
 )
 from lorax_server.utils import (
-    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -61,29 +60,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)

         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
             dtype,
             process_group=self.process_group,
-            merged_weight_filenames=merged_weight_filenames
         )

         if config.quantize in ["gptq", "awq", "eetq"]:
@@ -106,7 +87,7 @@
             sliding_window=config.sliding_window,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )

     @property