diff --git a/src/huggingface_hub/inference/_providers/__init__.py b/src/huggingface_hub/inference/_providers/__init__.py index a938d2ba48..096879e396 100644 --- a/src/huggingface_hub/inference/_providers/__init__.py +++ b/src/huggingface_hub/inference/_providers/__init__.py @@ -6,7 +6,7 @@ ) from huggingface_hub.utils import logging -from ._common import TaskProviderHelper, _fetch_inference_provider_mapping +from ._common import AutoRouterConversationalTask, TaskProviderHelper, _fetch_inference_provider_mapping from .black_forest_labs import BlackForestLabsTextToImageTask from .cerebras import CerebrasConversationalTask from .cohere import CohereConversationalTask @@ -71,6 +71,8 @@ PROVIDER_OR_POLICY_T = Union[PROVIDER_T, Literal["auto"]] +CONVERSATIONAL_AUTO_ROUTER = AutoRouterConversationalTask() + PROVIDERS: dict[PROVIDER_T, dict[str, TaskProviderHelper]] = { "black-forest-labs": { "text-to-image": BlackForestLabsTextToImageTask(), @@ -201,13 +203,19 @@ def get_provider_helper( if provider is None: logger.info( - "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers." + "No provider specified for task `conversational`. Defaulting to server-side auto routing." + if task == "conversational" + else "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers." ) provider = "auto" if provider == "auto": if model is None: raise ValueError("Specifying a model is required when provider is 'auto'") + if task == "conversational": + # Special case: we have a dedicated auto-router for conversational models. No need to fetch provider mapping. 
+ return CONVERSATIONAL_AUTO_ROUTER + provider_mapping = _fetch_inference_provider_mapping(model) provider = next(iter(provider_mapping)).provider diff --git a/src/huggingface_hub/inference/_providers/_common.py b/src/huggingface_hub/inference/_providers/_common.py index adbe05e611..4804c2ff06 100644 --- a/src/huggingface_hub/inference/_providers/_common.py +++ b/src/huggingface_hub/inference/_providers/_common.py @@ -278,6 +278,44 @@ def _prepare_payload_as_dict( return filter_none({"messages": inputs, **parameters, "model": provider_mapping_info.provider_id}) +class AutoRouterConversationalTask(BaseConversationalTask): + """ + Auto-router for conversational tasks. + + We let the Hugging Face router select the best provider for the model, based on availability and user preferences. + This is a special case since the selection is done server-side (avoids an extra API call to fetch the provider mapping). + """ + + def __init__(self): + super().__init__(provider="auto", base_url="https://router.huggingface.co") + + def _prepare_base_url(self, api_key: str) -> str: + """Return the base URL to use for the request. + + Overridden here: provider selection happens server-side, so no provider suffix is appended.""" + # Auto-routing goes through the HF router, so a Hugging Face token is required + if not api_key.startswith("hf_"): + raise ValueError("Cannot select auto-router when using non-Hugging Face API key.") + else: + return self.base_url # No `/auto` suffix in the URL + + def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping: + """ + In auto-router, we don't need to fetch provider mapping info. + We just return a dummy mapping info with provider_id set to the HF model ID. + """ + if model is None: + raise ValueError("Please provide an HF model ID.") + + return InferenceProviderMapping( + provider="auto", + hf_model_id=model, + providerId=model, + status="live", + task="conversational", + ) + + class BaseTextGenerationTask(TaskProviderHelper): + """ + Base class for text-generation (completion) tasks. 
diff --git a/tests/test_inference_providers.py b/tests/test_inference_providers.py index 6aa2d23aba..4ca5c9dc14 100644 --- a/tests/test_inference_providers.py +++ b/tests/test_inference_providers.py @@ -9,6 +9,7 @@ from huggingface_hub.inference._common import RequestParameters from huggingface_hub.inference._providers import PROVIDERS, get_provider_helper from huggingface_hub.inference._providers._common import ( + AutoRouterConversationalTask, BaseConversationalTask, BaseTextGenerationTask, TaskProviderHelper, @@ -193,6 +194,47 @@ def test_prepare_url(self, mocker): helper._prepare_route.assert_called_once_with("test-model", "sk_test_token") +class TestAutoRouterConversationalTask: + def test_properties(self): + helper = AutoRouterConversationalTask() + assert helper.provider == "auto" + assert helper.base_url == "https://router.huggingface.co" + assert helper.task == "conversational" + + def test_prepare_mapping_info_is_fake(self): + helper = AutoRouterConversationalTask() + mapping_info = helper._prepare_mapping_info("test-model") + assert mapping_info.hf_model_id == "test-model" + assert mapping_info.provider_id == "test-model" + assert mapping_info.task == "conversational" + assert mapping_info.status == "live" + + def test_prepare_request(self): + helper = AutoRouterConversationalTask() + + request = helper.prepare_request( + inputs=[{"role": "user", "content": "Hello!"}], + parameters={"model": "test-model", "frequency_penalty": 1.0}, + headers={}, + model="test-model", + api_key="hf_test_token", + ) + + # Use auto-router URL + assert request.url == "https://router.huggingface.co/v1/chat/completions" + + # The rest is the expected request for a Chat Completion API + assert request.headers["authorization"] == "Bearer hf_test_token" + assert request.json == { + "messages": [{"role": "user", "content": "Hello!"}], + "model": "test-model", + "frequency_penalty": 1.0, + } + assert request.task == "conversational" + assert request.model == "test-model" + assert 
request.data is None + + class TestBlackForestLabsProvider: def test_prepare_headers_bfl_key(self): helper = BlackForestLabsTextToImageTask() @@ -1670,7 +1712,7 @@ def test_filter_none(data: dict, expected: dict): assert filter_none(data) == expected -def test_get_provider_helper_auto(mocker): +def test_get_provider_helper_auto_non_conversational(mocker): """Test the 'auto' provider selection logic.""" mock_provider_a_helper = mocker.Mock(spec=TaskProviderHelper) @@ -1692,3 +1734,13 @@ def test_get_provider_helper_auto(mocker): PROVIDERS.pop("provider-a", None) PROVIDERS.pop("provider-b", None) + + +def test_get_provider_helper_auto_conversational(): + """Test the 'auto' provider selection logic for conversational task. + + In practice, no HTTP call is made to the Hub because routing is done server-side. + """ + helper = get_provider_helper(provider="auto", task="conversational", model="test-model") + + assert isinstance(helper, AutoRouterConversationalTask)