From 75cd88a75d8331c68cc07d4535fd5fcad98a8819 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Sun, 28 Apr 2024 13:17:52 -0700 Subject: [PATCH] fix: Downloading private adapters from HF (#443) --- server/lorax_server/utils/sources/hub.py | 2 +- server/tests/adapters/test_utils.py | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) create mode 100644 server/tests/adapters/test_utils.py diff --git a/server/lorax_server/utils/sources/hub.py b/server/lorax_server/utils/sources/hub.py index d23c123d2..f23ce2f6d 100644 --- a/server/lorax_server/utils/sources/hub.py +++ b/server/lorax_server/utils/sources/hub.py @@ -187,7 +187,7 @@ def download_model_assets(self): def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[Path]: try: - return Path(hf_hub_download(self.model_id, revision=None, filename=filename)) + return Path(hf_hub_download(self.model_id, revision=None, filename=filename, token=self.api_token)) except Exception as e: if ignore_errors: return None diff --git a/server/tests/adapters/test_utils.py b/server/tests/adapters/test_utils.py new file mode 100644 index 000000000..8a94423e9 --- /dev/null +++ b/server/tests/adapters/test_utils.py @@ -0,0 +1,23 @@ +import os + +import pytest +from huggingface_hub.utils import RepositoryNotFoundError + +from lorax_server.adapters.utils import download_adapter +from lorax_server.utils.sources import HUB + + +def test_download_private_adapter_hf(): + # store and unset HUGGING_FACE_HUB_TOKEN from the environment + token = os.environ.pop("HUGGING_FACE_HUB_TOKEN", None) + assert token is not None, "HUGGING_FACE_HUB_TOKEN must be set in the environment to run this test" + + # verify download fails without the token set + with pytest.raises(RepositoryNotFoundError): + download_adapter("predibase/test-private-lora", HUB, api_token=None) + + # pass in the token and verify download succeeds + download_adapter("predibase/test-private-lora", HUB, api_token=token) + + # set the token back in the environment + os.environ["HUGGING_FACE_HUB_TOKEN"] = token