
Commit 027c4d2

[InferenceClient] Server-side auto-routing for conversational task (#1810)
Equivalent Python PR: huggingface/huggingface_hub#3448. Discussed in [private DMs](https://huggingface.slack.com/archives/C07KX53FZTK/p1759998694399239). Now that we have server-side routing at `https://router.huggingface.co/v1/chat/completions`, it is best to use it in the JS client: it centralizes logic between the JS and Python clients and saves one HTTP call. We still keep client-side routing for all other tasks.
1 parent 1179960 commit 027c4d2
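
Concretely, "server-side routing" means the router itself exposes an OpenAI-compatible chat endpoint and picks a provider for the requested model. A minimal sketch of a raw call to that endpoint (the payload shape follows the OpenAI chat API; the model id is illustrative, not part of this commit):

```ts
// Direct call to the server-side auto-router; the router selects a provider itself.
const res = await fetch("https://router.huggingface.co/v1/chat/completions", {
	method: "POST",
	headers: {
		Authorization: `Bearer ${process.env.HF_TOKEN}`, // requires a Hugging Face token
		"Content-Type": "application/json",
	},
	body: JSON.stringify({
		model: "meta-llama/Llama-3.1-8B-Instruct", // illustrative model id
		messages: [{ role: "user", content: "Hello!" }],
	}),
});
const completion = await res.json();
```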

4 files changed (+43, −4 lines changed)
packages/inference/src/errors.ts

Lines changed: 7 additions & 0 deletions
@@ -17,6 +17,13 @@ export class InferenceClientInputError extends InferenceClientError {
     }
 }
 
+export class InferenceClientRoutingError extends InferenceClientError {
+    constructor(message: string) {
+        super(message);
+        this.name = "RoutingError";
+    }
+}
+
 interface HttpRequest {
     url: string;
     method: string;
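
For callers, the new error behaves like the other `InferenceClientError` subclasses. A minimal handling sketch, assuming the class is re-exported from the package root like the existing error classes (the model id is illustrative):

```ts
import { InferenceClient, InferenceClientRoutingError } from "@huggingface/inference";

const client = new InferenceClient(process.env.SOME_PROVIDER_KEY); // not an hf_ token

try {
	await client.chatCompletion({
		model: "meta-llama/Llama-3.1-8B-Instruct", // illustrative model id
		messages: [{ role: "user", content: "Hello!" }],
	});
} catch (err) {
	if (err instanceof InferenceClientRoutingError) {
		// e.g. auto-routing was selected with a non-Hugging Face API key
		console.error("Routing error:", err.message);
	} else {
		throw err;
	}
}
```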

packages/inference/src/lib/getInferenceProviderMapping.ts

Lines changed: 11 additions & 0 deletions
@@ -124,6 +124,17 @@ export async function getInferenceProviderMapping(
     }
 ): Promise<InferenceProviderMappingEntry | null> {
     const logger = getLogger();
+    if (params.provider === ("auto" as InferenceProvider) && params.task === "conversational") {
+        // Special case for auto + conversational to avoid extra API calls
+        // Call directly the server-side auto router
+        return {
+            hfModelId: params.modelId,
+            provider: "auto",
+            providerId: params.modelId,
+            status: "live",
+            task: "conversational",
+        };
+    }
     if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
         return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
     }
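
The "extra API calls" in question are the provider-mapping lookups against the Hub. A sketch of the request the short-circuit skips for auto + conversational (the endpoint shape is an assumption based on how the client currently fetches mappings, and the model id is illustrative):

```ts
// Mapping lookup that is no longer needed for provider "auto" + task "conversational".
// (URL shape assumed from the client's existing mapping fetch; illustrative only.)
const modelId = "meta-llama/Llama-3.1-8B-Instruct"; // illustrative
const res = await fetch(
	`https://huggingface.co/api/models/${modelId}?expand[]=inferenceProviderMapping`
);
const { inferenceProviderMapping } = await res.json();
// The synthesized mapping entry above stands in for this response.
```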

packages/inference/src/providers/providerHelper.ts

Lines changed: 15 additions & 2 deletions
@@ -47,7 +47,7 @@ import type {
     ZeroShotImageClassificationOutput,
 } from "@huggingface/tasks";
 import { HF_ROUTER_URL } from "../config.js";
-import { InferenceClientProviderOutputError } from "../errors.js";
+import { InferenceClientProviderOutputError, InferenceClientRoutingError } from "../errors.js";
 import type { AudioToAudioOutput } from "../tasks/audio/audioToAudio.js";
 import type { BaseArgs, BodyParams, HeaderParams, InferenceProvider, RequestArgs, UrlParams } from "../types.js";
 import { toArray } from "../utils/toArray.js";
@@ -62,7 +62,7 @@ import type { ImageSegmentationArgs } from "../tasks/cv/imageSegmentation.js";
 export abstract class TaskProviderHelper {
     constructor(
         readonly provider: InferenceProvider,
-        private baseUrl: string,
+        protected baseUrl: string,
         readonly clientSideRoutingOnly: boolean = false
     ) {}
 
@@ -369,3 +369,16 @@ export class BaseTextGenerationTask extends TaskProviderHelper implements TextGe
         throw new InferenceClientProviderOutputError("Expected Array<{generated_text: string}>");
     }
 }
+
+export class AutoRouterConversationalTask extends BaseConversationalTask {
+    constructor() {
+        super("auto" as InferenceProvider, "https://router.huggingface.co");
+    }
+
+    override makeBaseUrl(params: UrlParams): string {
+        if (params.authMethod !== "hf-token") {
+            throw new InferenceClientRoutingError("Cannot select auto-router when using non-Hugging Face API key.");
+        }
+        return this.baseUrl;
+    }
+}
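
Since `BaseConversationalTask` routes chat requests to `{baseUrl}/v1/chat/completions` in the current client, the override above effectively pins every auto-routed conversational request to the server-side router while refusing non-Hugging Face credentials up front. A sketch of the guard (the partial `UrlParams` cast elides fields the override does not read, and `"provider-key"` stands in for any non-HF auth method):

```ts
const helper = new AutoRouterConversationalTask();

// A Hugging Face token is accepted: the router origin is returned.
helper.makeBaseUrl({ authMethod: "hf-token" } as UrlParams);
// -> "https://router.huggingface.co"

// Anything else is rejected before a request is built.
helper.makeBaseUrl({ authMethod: "provider-key" } as UrlParams);
// -> throws InferenceClientRoutingError
```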

packages/inference/src/tasks/nlp/chatCompletion.ts

Lines changed: 10 additions & 2 deletions
@@ -3,6 +3,8 @@ import { resolveProvider } from "../../lib/getInferenceProviderMapping.js";
 import { getProviderHelper } from "../../lib/getProviderHelper.js";
 import type { BaseArgs, Options } from "../../types.js";
 import { innerRequest } from "../../utils/request.js";
+import type { ConversationalTaskHelper, TaskProviderHelper } from "../../providers/providerHelper.js";
+import { AutoRouterConversationalTask } from "../../providers/providerHelper.js";
 
 /**
  * Use the chat completion endpoint to generate a response to a prompt, using OpenAI message completion API no stream
@@ -11,8 +13,14 @@ export async function chatCompletion(
     args: BaseArgs & ChatCompletionInput,
     options?: Options
 ): Promise<ChatCompletionOutput> {
-    const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
-    const providerHelper = getProviderHelper(provider, "conversational");
+    let providerHelper: ConversationalTaskHelper & TaskProviderHelper;
+    if (!args.provider || args.provider === "auto") {
+        // Special case: we have a dedicated auto-router for conversational models. No need to fetch provider mapping.
+        providerHelper = new AutoRouterConversationalTask();
+    } else {
+        const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+        providerHelper = getProviderHelper(provider, "conversational");
+    }
     const { data: response } = await innerRequest<ChatCompletionOutput>(args, providerHelper, {
         ...options,
         task: "conversational",
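
End to end, the change is invisible to callers: omitting `provider` (or passing `"auto"`) now sends a single request straight to the router, while an explicit provider keeps the existing client-side resolution. A usage sketch (the model id and the `"sambanova"` provider are illustrative):

```ts
import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient(process.env.HF_TOKEN);

// No provider (or provider: "auto"): goes directly to
// https://router.huggingface.co/v1/chat/completions, with no mapping lookup.
const auto = await client.chatCompletion({
	model: "meta-llama/Llama-3.1-8B-Instruct",
	messages: [{ role: "user", content: "Hello!" }],
});

// Explicit provider: the provider mapping is still resolved client-side.
const pinned = await client.chatCompletion({
	model: "meta-llama/Llama-3.1-8B-Instruct",
	provider: "sambanova",
	messages: [{ role: "user", content: "Hello!" }],
});
```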
