enh: Update Predibase integration to support v2 API #403

Merged (5 commits) on Apr 11, 2024
42 changes: 33 additions & 9 deletions server/lorax_server/utils/sources/__init__.py
@@ -19,8 +19,9 @@
 LOCAL = "local"
 PBASE = "pbase"

-PREDIBASE_MODEL_URL_ENDPOINT = "/v1/models/version/name/{}"
-PREDIBASE_MODEL_VERSION_URL_ENDPOINT = "/v1/models/version/name/{}?version={}"
+LEGACY_PREDIBASE_MODEL_URL_ENDPOINT = "/v1/models/version/name/{}"
+LEGACY_PREDIBASE_MODEL_VERSION_URL_ENDPOINT = "/v1/models/version/name/{}?version={}"
+PREDIBASE_ADAPTER_VERSION_URL_ENDPOINT = "/v2/repos/{}/version/{}"
 PREDIBASE_GATEWAY_ENDPOINT = os.getenv("PREDIBASE_GATEWAY_ENDPOINT", "https://api.predibase.com")
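
For reviewers who want to check how these templates combine with the gateway base URL, here is a minimal sketch. `build_urls` is a hypothetical helper, not part of the PR. The gateway is hard-coded for illustration, and the sketch assumes the v2 template starts with a leading `/`, since the URLs are built by plain string concatenation.

```python
# Templates mirroring the constants in this hunk. The leading "/" on the v2
# template is an assumption: without it, concatenation with the gateway base
# would yield "https://api.predibase.comv2/...".
LEGACY_MODEL_URL = "/v1/models/version/name/{}"
LEGACY_MODEL_VERSION_URL = "/v1/models/version/name/{}?version={}"
ADAPTER_VERSION_URL = "/v2/repos/{}/version/{}"
GATEWAY = "https://api.predibase.com"  # hard-coded here; the PR reads an env var


def build_urls(model_id: str, gateway: str = GATEWAY):
    """Return (v2_url_or_None, legacy_url) for a pbase model id.

    Hypothetical helper for illustration only.
    """
    parts = model_id.split("/")
    if len(parts) == 1:
        # Name only: the v2 endpoint needs a version, so only legacy applies.
        (name,) = parts
        return None, gateway + LEGACY_MODEL_URL.format(name)
    if len(parts) == 2:
        name, version = parts
        return (
            gateway + ADAPTER_VERSION_URL.format(name, version),
            gateway + LEGACY_MODEL_VERSION_URL.format(name, version),
        )
    raise ValueError(f"Invalid model id {model_id}")
```

A name-only id such as `my-adapter` yields no v2 URL, while `my-adapter/3` yields both a v2 and a legacy URL.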


@@ -30,19 +31,42 @@ def map_pbase_model_id_to_s3(model_id: str, api_token: str) -> str:
         raise ValueError("api_token must be provided for a model of source pbase")
     headers = {"Authorization": f"Bearer {api_token}"}
     name_components = model_id.split("/")
-    # version is optional
+
+    url = None
+    legacy_url = None
+
     if len(name_components) == 1:
         name = name_components[0]
-        url = PREDIBASE_GATEWAY_ENDPOINT + PREDIBASE_MODEL_URL_ENDPOINT.format(name)
+        legacy_url = PREDIBASE_GATEWAY_ENDPOINT + LEGACY_PREDIBASE_MODEL_URL_ENDPOINT.format(name)
     elif len(name_components) == 2:
         name, version = name_components
-        url = PREDIBASE_GATEWAY_ENDPOINT + PREDIBASE_MODEL_VERSION_URL_ENDPOINT.format(name, version)
+        url = PREDIBASE_GATEWAY_ENDPOINT + PREDIBASE_ADAPTER_VERSION_URL_ENDPOINT.format(name, version)
+        legacy_url = PREDIBASE_GATEWAY_ENDPOINT + LEGACY_PREDIBASE_MODEL_VERSION_URL_ENDPOINT.format(name, version)
     else:
         raise ValueError(f"Invalid model id {model_id}")
-    resp = requests.get(url, headers=headers)
-    resp.raise_for_status()
-    uuid, best_run_id = resp.json()["uuid"], resp.json()["bestRunID"]
-    return f"{uuid}/{best_run_id}/artifacts/model/model_weights/"
+
+    def fetch_legacy_url():
+        r = requests.get(legacy_url, headers=headers)
+        r.raise_for_status()
+        uuid, best_run_id = r.json()["uuid"], r.json()["bestRunID"]
+        return f"{uuid}/{best_run_id}/artifacts/model/model_weights/"
+
+    if url is not None:
+        try:
Collaborator

Just trying to think about this in terms of customer impact: we really only expect staging to use this for the next few days, until we migrate all of the LLM Models to the Adapters table, plus a very brief window in Prod while we migrate the data, right?

+            # Try to retrieve data using the new endpoint.
+            resp = requests.get(url, headers=headers)
+            resp.raise_for_status()
+        except requests.RequestException:
+            # Not found in new path, fall back to legacy endpoint.
+            return fetch_legacy_url()
+
+        path = resp.json().get("adapterPath", None)
+        if path is None:
+            raise RuntimeError(f"Adapter {model_id} is not yet available")
+        return path
+    else:
+        # Use legacy path only since new endpoint requires both name and version number.
Contributor

Looks like this code is identical to 54:58. Can we pull it into a function?

def fetch_legacy_url(legacy_url):
    resp = requests.get(legacy_url, headers=headers)
    resp.raise_for_status()
    uuid, best_run_id = resp.json()["uuid"], resp.json()["bestRunID"]
    return f"{uuid}/{best_run_id}/artifacts/model/model_weights/"

Contributor Author

Done.

+        return fetch_legacy_url()


# TODO(travis): refactor into registry pattern
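
The lookup order in this hunk (v2 first, legacy on failure, legacy only when no version is given) can be exercised without network access. The sketch below is one way to do that, not the PR's code. `resolve_adapter_path`, `fetch_json`, and `FetchError` are hypothetical names: `fetch_json` stands in for `requests.get(...)` plus `raise_for_status()` plus `.json()`, and `FetchError` stands in for `requests.RequestException`.

```python
from typing import Callable, Optional


class FetchError(Exception):
    """Stand-in for requests.RequestException in this sketch."""


def resolve_adapter_path(
    model_id: str,
    url: Optional[str],
    legacy_url: str,
    fetch_json: Callable[[str], dict],
) -> str:
    """Resolve a storage path, preferring the v2 endpoint when a URL exists.

    `fetch_json` is a hypothetical callable that returns the decoded JSON
    body, or raises FetchError on any HTTP or network failure.
    """

    def fetch_legacy() -> str:
        body = fetch_json(legacy_url)
        return f"{body['uuid']}/{body['bestRunID']}/artifacts/model/model_weights/"

    if url is None:
        # Name-only ids can only be resolved through the legacy endpoint.
        return fetch_legacy()

    try:
        body = fetch_json(url)
    except FetchError:
        # The v2 lookup failed for any reason: fall back to legacy.
        return fetch_legacy()

    path = body.get("adapterPath")
    if path is None:
        # The v2 record exists but has no artifact path yet.
        raise RuntimeError(f"Adapter {model_id} is not yet available")
    return path
```

Note that, as in the diff, any request failure on the v2 endpoint triggers the fallback, not only a 404, whereas a v2 record without `adapterPath` raises instead of falling back.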