Skip to content

Commit

Permalink
use env var to override predibase api token when preloading adapters …
Browse files Browse the repository at this point in the history
…during init
  • Loading branch information
noyoshi committed Jul 16, 2024
1 parent 8773123 commit 19590a5
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions server/lorax_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,11 +279,17 @@ async def serve_inner(
pass

if preloaded_adapter_ids:
_adapter_source = enum_string_to_adapter_source(adapter_source)
adapter_preload_api_token = None
if _adapter_source == generate_pb2.AdapterSource.PBASE:
# Derive the predibase token from an env variable if we are using predibase adapters.
adapter_preload_api_token = os.getenv("PREDIBASE_API_TOKEN")
logger.info(f"Preloading {len(preloaded_adapter_ids)} adapters")
requests = [
generate_pb2.DownloadAdapterRequest(
adapter_parameters=generate_pb2.AdapterParameters(adapter_ids=[adapter_id]),
adapter_source=enum_string_to_adapter_source(adapter_source),
adapter_source=_adapter_source,
api_token=adapter_preload_api_token,
)
for adapter_id in preloaded_adapter_ids
]
Expand All @@ -298,14 +304,14 @@ async def serve_inner(
# TODO(travis): load weights into GPU memory as well
for i, adapter_id in enumerate(preloaded_adapter_ids):
if adapter_source == PBASE:
adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token=None)
adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token=adapter_preload_api_token)
adapter_source = S3

model.load_adapter(
generate_pb2.AdapterParameters(adapter_ids=[adapter_id]),
adapter_source,
adapter_index=i + 1,
api_token=None,
api_token=adapter_preload_api_token,
dynamic=True,
)

Expand Down

0 comments on commit 19590a5

Please sign in to comment.