Skip to content

Commit

Permalink
enh: Update Predibase integration to support v2 API (#403)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffreyftang authored Apr 11, 2024
1 parent 30174d7 commit ce501cd
Showing 1 changed file with 33 additions and 9 deletions.
42 changes: 33 additions & 9 deletions server/lorax_server/utils/sources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
LOCAL = "local"
PBASE = "pbase"

# Same paths as the LEGACY_* constants below, kept under their earlier names.
PREDIBASE_MODEL_URL_ENDPOINT = "/v1/models/version/name/{}"
PREDIBASE_MODEL_VERSION_URL_ENDPOINT = "/v1/models/version/name/{}?version={}"
# v1 lookups: by model name only, or by model name plus an explicit version.
LEGACY_PREDIBASE_MODEL_URL_ENDPOINT = "/v1/models/version/name/{}"
LEGACY_PREDIBASE_MODEL_VERSION_URL_ENDPOINT = "/v1/models/version/name/{}?version={}"
# v2 lookup by (repo name, version). Every endpoint path is appended directly to
# PREDIBASE_GATEWAY_ENDPOINT, whose default has no trailing slash, so the path
# must start with "/" -- without it the host and path run together
# ("https://api.predibase.comv2/...") and the v2 request can never succeed.
PREDIBASE_ADAPTER_VERSION_URL_ENDPOINT = "/v2/repos/{}/version/{}"
# Base URL of the Predibase gateway; overridable through the environment.
PREDIBASE_GATEWAY_ENDPOINT = os.getenv("PREDIBASE_GATEWAY_ENDPOINT", "https://api.predibase.com")


Expand All @@ -30,19 +31,42 @@ def map_pbase_model_id_to_s3(model_id: str, api_token: str) -> str:
raise ValueError("api_token must be provided to for a model of source pbase")
headers = {"Authorization": f"Bearer {api_token}"}
name_components = model_id.split("/")
# version is optional

url = None
legacy_url = None

if len(name_components) == 1:
name = name_components[0]
url = PREDIBASE_GATEWAY_ENDPOINT + PREDIBASE_MODEL_URL_ENDPOINT.format(name)
legacy_url = PREDIBASE_GATEWAY_ENDPOINT + LEGACY_PREDIBASE_MODEL_URL_ENDPOINT.format(name)
elif len(name_components) == 2:
name, version = name_components
url = PREDIBASE_GATEWAY_ENDPOINT + PREDIBASE_MODEL_VERSION_URL_ENDPOINT.format(name, version)
url = PREDIBASE_GATEWAY_ENDPOINT + PREDIBASE_ADAPTER_VERSION_URL_ENDPOINT.format(name, version)
legacy_url = PREDIBASE_GATEWAY_ENDPOINT + LEGACY_PREDIBASE_MODEL_VERSION_URL_ENDPOINT.format(name, version)
else:
raise ValueError(f"Invalid model id {model_id}")
resp = requests.get(url, headers=headers)
resp.raise_for_status()
uuid, best_run_id = resp.json()["uuid"], resp.json()["bestRunID"]
return f"{uuid}/{best_run_id}/artifacts/model/model_weights/"

def fetch_legacy_url():
    """Resolve the model weights path through the legacy (v1) endpoint.

    Issues a GET to ``legacy_url`` with the enclosing function's ``headers``.

    Returns:
        The path ``"{uuid}/{bestRunID}/artifacts/model/model_weights/"`` built
        from the ``uuid`` and ``bestRunID`` fields of the JSON response.

    Raises:
        requests.HTTPError: if the legacy endpoint answers with an error status.
        KeyError: if the response body lacks ``uuid`` or ``bestRunID``.
    """
    r = requests.get(legacy_url, headers=headers)
    r.raise_for_status()
    # Read the legacy response itself. The outer ``resp`` is either unassigned
    # (name-only model ids never query the new endpoint) or holds the failed
    # new-endpoint response, so it must not be consulted here.
    payload = r.json()
    uuid, best_run_id = payload["uuid"], payload["bestRunID"]
    return f"{uuid}/{best_run_id}/artifacts/model/model_weights/"

if url is not None:
try:
# Try to retrieve data using the new endpoint.
resp = requests.get(url, headers=headers)
resp.raise_for_status()
except requests.RequestException:
# Not found in new path, fall back to legacy endpoint.
return fetch_legacy_url()

path = resp.json().get("adapterPath", None)
if path is None:
raise RuntimeError(f"Adapter {model_id} is not yet available")
return path
else:
# Use legacy path only since new endpoint requires both name and version number.
return fetch_legacy_url()


# TODO(travis): refactor into registry pattern
Expand Down

0 comments on commit ce501cd

Please sign in to comment.