From c04b48667a824d82778bd5de81788b458c2cdb80 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Tue, 9 Apr 2024 14:58:57 -0500 Subject: [PATCH 1/5] wip --- server/lorax_server/utils/sources/__init__.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/server/lorax_server/utils/sources/__init__.py b/server/lorax_server/utils/sources/__init__.py index ed852ae98..da0666b2b 100644 --- a/server/lorax_server/utils/sources/__init__.py +++ b/server/lorax_server/utils/sources/__init__.py @@ -19,8 +19,9 @@ LOCAL = "local" PBASE = "pbase" -PREDIBASE_MODEL_URL_ENDPOINT = "/v1/models/version/name/{}" -PREDIBASE_MODEL_VERSION_URL_ENDPOINT = "/v1/models/version/name/{}?version={}" +LEGACY_PREDIBASE_MODEL_URL_ENDPOINT = "/v1/models/version/name/{}" +LEGACY_PREDIBASE_MODEL_VERSION_URL_ENDPOINT = "/v1/models/version/name/{}?version={}" +PREDIBASE_ADAPTER_VERSION_URL_ENDPOINT = "v2/repos/{}/version/{}" PREDIBASE_GATEWAY_ENDPOINT = os.getenv("PREDIBASE_GATEWAY_ENDPOINT", "https://api.predibase.com") @@ -33,10 +34,10 @@ def map_pbase_model_id_to_s3(model_id: str, api_token: str) -> str: # version is optional if len(name_components) == 1: name = name_components[0] - url = PREDIBASE_GATEWAY_ENDPOINT + PREDIBASE_MODEL_URL_ENDPOINT.format(name) + url = PREDIBASE_GATEWAY_ENDPOINT + LEGACY_PREDIBASE_MODEL_URL_ENDPOINT.format(name) elif len(name_components) == 2: name, version = name_components - url = PREDIBASE_GATEWAY_ENDPOINT + PREDIBASE_MODEL_VERSION_URL_ENDPOINT.format(name, version) + url = PREDIBASE_GATEWAY_ENDPOINT + LEGACY_PREDIBASE_MODEL_VERSION_URL_ENDPOINT.format(name, version) else: raise ValueError(f"Invalid model id {model_id}") resp = requests.get(url, headers=headers) From 8b05c35787ea7ed7c0476d000e679e6f22cfd4cb Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Tue, 9 Apr 2024 16:10:41 -0500 Subject: [PATCH 2/5] fallback --- server/lorax_server/utils/sources/__init__.py | 37 +++++++++++++++---- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/server/lorax_server/utils/sources/__init__.py b/server/lorax_server/utils/sources/__init__.py index da0666b2b..d7317fec7 100644 --- a/server/lorax_server/utils/sources/__init__.py +++ b/server/lorax_server/utils/sources/__init__.py @@ -31,19 +31,42 @@ def map_pbase_model_id_to_s3(model_id: str, api_token: str) -> str: raise ValueError("api_token must be provided to for a model of source pbase") headers = {"Authorization": f"Bearer {api_token}"} name_components = model_id.split("/") - # version is optional + + url = None + legacy_url = None + if len(name_components) == 1: name = name_components[0] - url = PREDIBASE_GATEWAY_ENDPOINT + LEGACY_PREDIBASE_MODEL_URL_ENDPOINT.format(name) + legacy_url = PREDIBASE_GATEWAY_ENDPOINT + LEGACY_PREDIBASE_MODEL_URL_ENDPOINT.format(name) elif len(name_components) == 2: name, version = name_components - url = PREDIBASE_GATEWAY_ENDPOINT + LEGACY_PREDIBASE_MODEL_VERSION_URL_ENDPOINT.format(name, version) + url = PREDIBASE_GATEWAY_ENDPOINT + PREDIBASE_ADAPTER_VERSION_URL_ENDPOINT.format(name, version) + legacy_url = PREDIBASE_GATEWAY_ENDPOINT + LEGACY_PREDIBASE_MODEL_VERSION_URL_ENDPOINT.format(name, version) else: raise ValueError(f"Invalid model id {model_id}") - resp = requests.get(url, headers=headers) - resp.raise_for_status() - uuid, best_run_id = resp.json()["uuid"], resp.json()["bestRunID"] - return f"{uuid}/{best_run_id}/artifacts/model/model_weights/" + + if url is not None: + try: + # Try to retrieve data using the new endpoint. + resp = requests.get(url, headers=headers) + resp.raise_for_status() + except: + # Not found in new path, fall back to legacy endpoint. + resp = requests.get(legacy_url, headers=headers) + resp.raise_for_status() + uuid, best_run_id = resp.json()["uuid"], resp.json()["bestRunID"] + return f"{uuid}/{best_run_id}/artifacts/model/model_weights/" + + path = resp.json().get("adapterPath", None) + if path is None: + raise RuntimeError(f"Adapter {model_id} is not yet available") + return path + else: + # Use legacy path only since new endpoint requires both name and version number. + resp = requests.get(legacy_url, headers=headers) + resp.raise_for_status() + uuid, best_run_id = resp.json()["uuid"], resp.json()["bestRunID"] + return f"{uuid}/{best_run_id}/artifacts/model/model_weights/" # TODO(travis): refactor into registry pattern From ac171f160fe1923ccef834e01bc5fd7974f291ef Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Tue, 9 Apr 2024 16:16:28 -0500 Subject: [PATCH 3/5] noqa --- server/lorax_server/utils/sources/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/lorax_server/utils/sources/__init__.py b/server/lorax_server/utils/sources/__init__.py index d7317fec7..0f3720062 100644 --- a/server/lorax_server/utils/sources/__init__.py +++ b/server/lorax_server/utils/sources/__init__.py @@ -50,7 +50,7 @@ def map_pbase_model_id_to_s3(model_id: str, api_token: str) -> str: # Try to retrieve data using the new endpoint. resp = requests.get(url, headers=headers) resp.raise_for_status() - except: + except: # ruff: noqa E722 # Not found in new path, fall back to legacy endpoint. resp = requests.get(legacy_url, headers=headers) resp.raise_for_status() From da21bb7d70d8ef137fa7db4637043a48fb2d2d42 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Tue, 9 Apr 2024 16:21:08 -0500 Subject: [PATCH 4/5] fix except --- server/lorax_server/utils/sources/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/lorax_server/utils/sources/__init__.py b/server/lorax_server/utils/sources/__init__.py index 0f3720062..3e5047bf0 100644 --- a/server/lorax_server/utils/sources/__init__.py +++ b/server/lorax_server/utils/sources/__init__.py @@ -50,7 +50,7 @@ def map_pbase_model_id_to_s3(model_id: str, api_token: str) -> str: # Try to retrieve data using the new endpoint. resp = requests.get(url, headers=headers) resp.raise_for_status() - except: # ruff: noqa E722 + except requests.RequestException: # Not found in new path, fall back to legacy endpoint. resp = requests.get(legacy_url, headers=headers) resp.raise_for_status() From 01f18758411970eeebb1ee7be621bd3abdebb985 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Wed, 10 Apr 2024 17:22:07 -0500 Subject: [PATCH 5/5] refactor --- server/lorax_server/utils/sources/__init__.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/server/lorax_server/utils/sources/__init__.py b/server/lorax_server/utils/sources/__init__.py index 3e5047bf0..937974187 100644 --- a/server/lorax_server/utils/sources/__init__.py +++ b/server/lorax_server/utils/sources/__init__.py @@ -45,6 +45,12 @@ def map_pbase_model_id_to_s3(model_id: str, api_token: str) -> str: else: raise ValueError(f"Invalid model id {model_id}") + def fetch_legacy_url(): + r = requests.get(legacy_url, headers=headers) + r.raise_for_status() + uuid, best_run_id = resp.json()["uuid"], resp.json()["bestRunID"] + return f"{uuid}/{best_run_id}/artifacts/model/model_weights/" + if url is not None: try: # Try to retrieve data using the new endpoint. @@ -52,10 +58,7 @@ def map_pbase_model_id_to_s3(model_id: str, api_token: str) -> str: resp.raise_for_status() except requests.RequestException: # Not found in new path, fall back to legacy endpoint. - resp = requests.get(legacy_url, headers=headers) - resp.raise_for_status() - uuid, best_run_id = resp.json()["uuid"], resp.json()["bestRunID"] - return f"{uuid}/{best_run_id}/artifacts/model/model_weights/" + return fetch_legacy_url() path = resp.json().get("adapterPath", None) if path is None: @@ -63,10 +66,7 @@ def map_pbase_model_id_to_s3(model_id: str, api_token: str) -> str: return path else: # Use legacy path only since new endpoint requires both name and version number. - resp = requests.get(legacy_url, headers=headers) - resp.raise_for_status() - uuid, best_run_id = resp.json()["uuid"], resp.json()["bestRunID"] - return f"{uuid}/{best_run_id}/artifacts/model/model_weights/" + return fetch_legacy_url() # TODO(travis): refactor into registry pattern