Skip to content

Commit

Permalink
Fixed loading adapter from absolute s3 path (#161)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair authored Jan 4, 2024
1 parent f20789d commit 57d5470
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 8 deletions.
46 changes: 38 additions & 8 deletions server/lorax_server/utils/sources/s3.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,56 @@
import os
import time
from datetime import timedelta
from typing import Optional, List, Any
from typing import TYPE_CHECKING, Optional, List, Any, Tuple

from loguru import logger
from pathlib import Path
import boto3
from botocore.config import Config
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE


from huggingface_hub.utils import (
LocalEntryNotFoundError,
EntryNotFoundError,
)

from .source import BaseModelSource, try_to_load_from_cache

if TYPE_CHECKING:
from boto3.resources.factory.s3 import Bucket


S3_PREFIX = "s3://"


def _get_bucket_and_model_id(model_id: str) -> Tuple[str, str]:
if model_id.startswith(S3_PREFIX):
model_id_no_protocol = model_id[len(S3_PREFIX) :]
if "/" not in model_id_no_protocol:
raise ValueError(
f"Invalid model_id {model_id}. "
f"model_id should be of the form `s3://bucket_name/model_id`"
)
bucket_name, model_id = model_id_no_protocol.split("/", 1)
return bucket_name, model_id

bucket = os.getenv("PREDIBASE_MODEL_BUCKET")
if not bucket:
# assume that the id preceding the first slash is the bucket name
if "/" not in model_id:
raise ValueError(
f"Invalid model_id {model_id}. "
f"model_id should be of the form `bucket_name/model_id` "
f"if PREDIBASE_MODEL_BUCKET environment variable is not set"
)

bucket_name, model_id = model_id.split("/", 1)
return bucket_name, model_id

return bucket, model_id


def _get_bucket_resource():
def _get_bucket_resource(bucket_name: str) -> "Bucket":
"""Get the s3 client"""
config = Config(
retries=dict(
Expand All @@ -27,10 +59,7 @@ def _get_bucket_resource():
)
)
s3 = boto3.resource('s3', config=config)
bucket = os.getenv("PREDIBASE_MODEL_BUCKET")
if not bucket:
raise ValueError("PREDIBASE_MODEL_BUCKET environment variable is not set")
return s3.Bucket(bucket)
return s3.Bucket(bucket_name)


def get_s3_model_local_dir(model_id: str):
Expand Down Expand Up @@ -172,10 +201,11 @@ def __init__(self, model_id: str, revision: Optional[str] = "", extension: str =
raise ValueError(f"model_id '{model_id}' is too short for prefix filtering")

# TODO: add support for revisions of the same model
bucket, model_id = _get_bucket_and_model_id(model_id)
self.model_id = model_id
self.revision = revision
self.extension = extension
self.bucket = _get_bucket_resource()
self.bucket = _get_bucket_resource(bucket)

def remote_weight_files(self, extension: str = None):
extension = extension or self.extension
Expand Down
47 changes: 47 additions & 0 deletions server/tests/utils/test_s3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import contextlib
import os
from typing import Optional

import pytest

from lorax_server.utils.sources.s3 import _get_bucket_and_model_id


@contextlib.contextmanager
def with_env_var(key: str, value: Optional[str]):
if value is None:
yield
return

prev = os.environ.get(key)
try:
os.environ[key] = value
yield
finally:
if prev is None:
del os.environ[key]
else:
os.environ[key] = prev


@pytest.mark.parametrize(
"s3_path, env_var, expected_bucket, expected_model_id",
[
("s3://loras/foobar", None, "loras", "foobar"),
("s3://loras/foo/bar", None, "loras", "foo/bar"),
("s3://loras/foo/bar", "bucket", "loras", "foo/bar"),
("loras/foobar", None, "loras", "foobar"),
("loras/foo/bar", None, "loras", "foo/bar"),
("loras/foo/bar", "bucket", "bucket", "loras/foo/bar"),
]
)
def test_get_bucket_and_model_id(
s3_path: str,
env_var: Optional[str],
expected_bucket: str,
expected_model_id: str,
):
with with_env_var("PREDIBASE_MODEL_BUCKET", env_var):
bucket, model_id = _get_bucket_and_model_id(s3_path)
assert bucket == expected_bucket
assert model_id == expected_model_id

0 comments on commit 57d5470

Please sign in to comment.