From 345291632cf2f934a910a2beea5f4b5c6d210155 Mon Sep 17 00:00:00 2001
From: Honglin Cao
Date: Thu, 24 Apr 2025 00:43:03 -0400
Subject: [PATCH 01/10] add centml

---
 packages/inference/src/providers/centml.ts | 66 ++++++++++++++++++++++
 packages/inference/src/types.ts            |  1 +
 2 files changed, 67 insertions(+)
 create mode 100644 packages/inference/src/providers/centml.ts

diff --git a/packages/inference/src/providers/centml.ts b/packages/inference/src/providers/centml.ts
new file mode 100644
index 0000000000..9a8f5de386
--- /dev/null
+++ b/packages/inference/src/providers/centml.ts
@@ -0,0 +1,66 @@
+/**
+ * CentML provider implementation for serverless inference.
+ * This provider supports chat completions and text generation through CentML's serverless endpoints.
+ */
+import type { ChatCompletionOutput, TextGenerationOutput } from "@huggingface/tasks";
+import { InferenceOutputError } from "../lib/InferenceOutputError";
+import type { BodyParams } from "../types";
+import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";
+
+const CENTML_API_BASE_URL = "https://api.centml.ai";
+
+export class CentMLConversationalTask extends BaseConversationalTask {
+	constructor() {
+		super("centml", CENTML_API_BASE_URL);
+	}
+
+	override preparePayload(params: BodyParams): Record<string, unknown> {
+		const { args, model } = params;
+		return {
+			...args,
+			model,
+			api_key: args.accessToken, // Use the accessToken from args
+		};
+	}
+
+	override async getResponse(response: ChatCompletionOutput): Promise<ChatCompletionOutput> {
+		if (
+			typeof response === "object" &&
+			Array.isArray(response?.choices) &&
+			typeof response?.created === "number" &&
+			typeof response?.id === "string" &&
+			typeof response?.model === "string" &&
+			typeof response?.usage === "object"
+		) {
+			return response;
+		}
+
+		throw new InferenceOutputError("Expected ChatCompletionOutput");
+	}
+}
+
+export class CentMLTextGenerationTask extends BaseTextGenerationTask {
+	constructor() {
+		super("centml", CENTML_API_BASE_URL);
+	}
+
+	override preparePayload(params: BodyParams): Record<string, unknown> {
+		const { args, model } = params;
+		return {
+			...args,
+			model,
+			api_key: args.accessToken, // Use the accessToken from args
+		};
+	}
+
+	override async getResponse(response: TextGenerationOutput): Promise<TextGenerationOutput> {
+		if (
+			typeof response === "object" &&
+			typeof response?.generated_text === "string"
+		) {
+			return response;
+		}
+
+		throw new InferenceOutputError("Expected TextGenerationOutput");
+	}
+}
\ No newline at end of file
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index e5870f6ef3..9714baf08e 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -40,6 +40,7 @@ export type InferenceTask = Exclude<PipelineType, "other"> | "conversational";
 export const INFERENCE_PROVIDERS = [
 	"black-forest-labs",
 	"cerebras",
+	"centml",
 	"cohere",
 	"fal-ai",
 	"featherless-ai",

From f4ded013bfb5e53b7300be6d832dfaa9eaf5db61 Mon Sep 17 00:00:00 2001
From: Honglin Cao
Date: Thu, 24 Apr 2025 22:52:52 -0600
Subject: [PATCH 02/10] add README and populate getProviderHelper.ts

---
 packages/inference/README.md                    | 2 ++
 packages/inference/src/lib/getProviderHelper.ts | 5 +++++
 2 files changed, 7 insertions(+)

diff --git a/packages/inference/README.md b/packages/inference/README.md
index ad5c9fc1a6..f86f56fba8 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -60,6 +60,7 @@ Currently, we support the following providers:
 - [Blackforestlabs](https://blackforestlabs.ai)
 - [Cohere](https://cohere.com)
 - [Cerebras](https://cerebras.ai/)
+- [CentML](https://centml.ai)
 - [Groq](https://groq.com)
 
 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
@@ -89,6 +90,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Together supported models](https://huggingface.co/api/partners/together/models)
 - [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
 - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
+- [CentML supported models](https://huggingface.co/api/partners/centml/models)
 - [Groq supported models](https://console.groq.com/docs/models)
 - [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index 14bd941987..8ec77a6c69 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -1,5 +1,6 @@
 import * as BlackForestLabs from "../providers/black-forest-labs";
 import * as Cerebras from "../providers/cerebras";
+import * as CentML from "../providers/centml";
 import * as Cohere from "../providers/cohere";
 import * as FalAI from "../providers/fal-ai";
 import * as FeatherlessAI from "../providers/featherless-ai";
@@ -55,6 +56,10 @@ export const PROVIDERS: Record<
 	cerebras: {
 		conversational: new Cerebras.CerebrasConversationalTask(),
 	},
+	centml: {
+		conversational: new CentML.CentMLConversationalTask(),
+		"text-generation": new CentML.CentMLTextGenerationTask(),
+	},
 	cohere: {
 		conversational: new Cohere.CohereConversationalTask(),
 	},

From: Honglin Cao
Date: Fri, 25 Apr 2025 21:48:35 -0600
Subject: [PATCH 03/10] fix typo / formatting

---
 packages/inference/src/providers/centml.ts | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/packages/inference/src/providers/centml.ts b/packages/inference/src/providers/centml.ts
index 9a8f5de386..ef736f3755 100644
--- a/packages/inference/src/providers/centml.ts
+++ b/packages/inference/src/providers/centml.ts
@@ -7,7 +7,7 @@ import { InferenceOutputError } from "../lib/InferenceOutputError";
 import type { BodyParams } from "../types";
 import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";
 
-const CENTML_API_BASE_URL = "https://api.centml.ai";
+const CENTML_API_BASE_URL = "https://api.centml.com";
 
 export class CentMLConversationalTask extends BaseConversationalTask {
 	constructor() {
@@ -63,4 +63,4 @@ export class CentMLTextGenerationTask extends BaseTextGenerationTask {
 
 		throw new InferenceOutputError("Expected TextGenerationOutput");
 	}
-}
\ No newline at end of file
+}

From 0a68414719a365e9382db00edc3fc7345d27c0f7 Mon Sep 17 00:00:00 2001
From: Honglin Cao
Date: Fri, 25 Apr 2025 23:08:50 -0600
Subject: [PATCH 04/10] fix centml ls

---
 packages/inference/src/providers/centml.ts | 93 ++++++++++++----------
 packages/inference/src/providers/consts.ts |  8 ++
 2 files changed, 57 insertions(+), 44 deletions(-)

diff --git a/packages/inference/src/providers/centml.ts b/packages/inference/src/providers/centml.ts
index ef736f3755..9a135b942b 100644
--- a/packages/inference/src/providers/centml.ts
+++ b/packages/inference/src/providers/centml.ts
@@ -10,57 +10,62 @@
 const CENTML_API_BASE_URL = "https://api.centml.com";
 
 export class CentMLConversationalTask extends BaseConversationalTask {
-	constructor() {
-		super("centml", CENTML_API_BASE_URL);
-	}
+	constructor() {
+		super("centml", CENTML_API_BASE_URL);
+	}
+
+	override makeRoute(): string {
+		return "openai/v1/chat/completions";
+	}
 
-	override preparePayload(params: BodyParams): Record<string, unknown> {
-		const { args, model } = params;
-		return {
-			...args,
-			model,
-			api_key: args.accessToken, // Use the accessToken from args
-		};
-	}
+	override preparePayload(params: BodyParams): Record<string, unknown> {
+		const { args, model } = params;
+		return {
+			...args,
+			model,
+			api_key: args.accessToken,
+		};
+	}
 
-	override async getResponse(response: ChatCompletionOutput): Promise<ChatCompletionOutput> {
-		if (
-			typeof response === "object" &&
-			Array.isArray(response?.choices) &&
-			typeof response?.created === "number" &&
-			typeof response?.id === "string" &&
-			typeof response?.model === "string" &&
-			typeof response?.usage === "object"
-		) {
-			return response;
-		}
+	override async getResponse(response: ChatCompletionOutput): Promise<ChatCompletionOutput> {
+		if (
+			typeof response === "object" &&
+			Array.isArray(response?.choices) &&
+			typeof response?.created === "number" &&
+			typeof response?.id === "string" &&
+			typeof response?.model === "string" &&
+			typeof response?.usage === "object"
+		) {
+			return response;
+		}
 
-		throw new InferenceOutputError("Expected ChatCompletionOutput");
-	}
+		throw new InferenceOutputError("Expected ChatCompletionOutput");
+	}
 }
 
 export class CentMLTextGenerationTask extends BaseTextGenerationTask {
-	constructor() {
-		super("centml", CENTML_API_BASE_URL);
-	}
+	constructor() {
+		super("centml", CENTML_API_BASE_URL);
+	}
+
+	override makeRoute(): string {
+		return "openai/v1/completions";
+	}
 
-	override preparePayload(params: BodyParams): Record<string, unknown> {
-		const { args, model } = params;
-		return {
-			...args,
-			model,
-			api_key: args.accessToken, // Use the accessToken from args
-		};
-	}
+	override preparePayload(params: BodyParams): Record<string, unknown> {
+		const { args, model } = params;
+		return {
+			...args,
+			model,
+			api_key: args.accessToken,
+		};
+	}
 
-	override async getResponse(response: TextGenerationOutput): Promise<TextGenerationOutput> {
-		if (
-			typeof response === "object" &&
-			typeof response?.generated_text === "string"
-		) {
-			return response;
-		}
+	override async getResponse(response: TextGenerationOutput): Promise<TextGenerationOutput> {
+		if (typeof response === "object" && typeof response?.generated_text === "string") {
+			return response;
+		}
 
-		throw new InferenceOutputError("Expected TextGenerationOutput");
-	}
+		throw new InferenceOutputError("Expected TextGenerationOutput");
+	}
 }
diff --git a/packages/inference/src/providers/consts.ts b/packages/inference/src/providers/consts.ts
index 50253e2b4e..8417c6473c 100644
--- a/packages/inference/src/providers/consts.ts
+++ b/packages/inference/src/providers/consts.ts
@@ -21,6 +21,14 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
 	 */
 	"black-forest-labs": {},
 	cerebras: {},
+	centml: {
+		"meta-llama/Llama-3.2-3B-Instruct": {
+			hfModelId: "meta-llama/Llama-3.2-3B-Instruct",
+			providerId: "meta-llama/Llama-3.2-3B-Instruct", // CentML expects same id
+			status: "live", // or "staging" if you prefer the warning
+			task: "conversational" // <-- WidgetType from @huggingface/tasks
+		}
+	},
 	cohere: {},
 	"fal-ai": {},
 	"featherless-ai": {},

From 7f121fee1a2186124ac9d489d51226c47d1af71c Mon Sep 17 00:00:00 2001
From: Honglin Cao
Date: Tue, 29 Apr 2025 00:08:28 -0600
Subject: [PATCH 05/10] remove from const.ts

---
 packages/inference/src/providers/consts.ts | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/packages/inference/src/providers/consts.ts b/packages/inference/src/providers/consts.ts
index 8417c6473c..0dc2afc70e 100644
--- a/packages/inference/src/providers/consts.ts
+++ b/packages/inference/src/providers/consts.ts
@@ -21,14 +21,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
 	 */
 	"black-forest-labs": {},
 	cerebras: {},
-	centml: {
-		"meta-llama/Llama-3.2-3B-Instruct": {
-			hfModelId: "meta-llama/Llama-3.2-3B-Instruct",
-			providerId: "meta-llama/Llama-3.2-3B-Instruct", // CentML expects same id
-			status: "live", // or "staging" if you prefer the warning
-			task: "conversational" // <-- WidgetType from @huggingface/tasks
-		}
-	},
+	centml: {},
 	cohere: {},
 	"fal-ai": {},
 	"featherless-ai": {},

From e99c0c96e400931be38d88aa9719638f84174879 Mon Sep 17 00:00:00 2001
From: Honglin Cao
Date: Tue, 29 Apr 2025 00:52:56 -0600
Subject: [PATCH 06/10] add centml tests, enable tests

---
 .../inference/test/InferenceClient.spec.ts | 143 +++++++++++++++++-
 1 file changed, 142 insertions(+), 1 deletion(-)

diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index c64a396d37..0849810bb0 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -22,7 +22,7 @@ if (!env.HF_TOKEN) {
 	console.warn("Set HF_TOKEN in the env to run the tests for better rate limits");
 }
 
-describe.skip("InferenceClient", () => {
+describe("InferenceClient", () => {
 	// Individual tests can be ran without providing an api key, however running all tests without an api key will result in rate limiting error.
 
 	describe("backward compatibility", () => {
@@ -1875,4 +1875,145 @@
 		},
 		TIMEOUT
 	);
+
+	describe.concurrent(
+		"CentML",
+		() => {
+			const client = new InferenceClient(env.HF_CENTML_KEY ?? "dummy");
+
+			HARDCODED_MODEL_INFERENCE_MAPPING["centml"] = {
+				"meta-llama/Llama-3.2-3B-Instruct": {
+					hfModelId: "meta-llama/Llama-3.2-3B-Instruct",
+					providerId: "meta-llama/Llama-3.2-3B-Instruct",
+					status: "live",
+					task: "conversational",
+				},
+				"meta-llama/Llama-3.2-3B": {
+					hfModelId: "meta-llama/Llama-3.2-3B",
+					providerId: "meta-llama/Llama-3.2-3B",
+					status: "live",
+					task: "text-generation",
+				},
+			};
+
+			describe("chat completions", () => {
+				it("basic chat completion", async () => {
+					const res = await client.chatCompletion({
+						model: "meta-llama/Llama-3.2-3B-Instruct",
+						provider: "centml",
+						messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+					});
+					if (res.choices && res.choices.length > 0) {
+						const completion = res.choices[0].message?.content;
+						expect(completion).toContain("two");
+					}
+				});
+
+				it("chat completion with multiple messages", async () => {
+					const res = await client.chatCompletion({
+						model: "meta-llama/Llama-3.2-3B-Instruct",
+						provider: "centml",
+						messages: [
+							{ role: "system", content: "You are a helpful assistant." },
+							{ role: "user", content: "What is 2+2?" },
+							{ role: "assistant", content: "The answer is 4." },
+							{ role: "user", content: "What is 3+3?" },
+						],
+					});
+					if (res.choices && res.choices.length > 0) {
+						const completion = res.choices[0].message?.content;
+						expect(completion).toContain("6");
+					}
+				});
+
+				it("chat completion with parameters", async () => {
+					const res = await client.chatCompletion({
+						model: "meta-llama/Llama-3.2-3B-Instruct",
+						provider: "centml",
+						messages: [{ role: "user", content: "Write a short poem about AI" }],
+						temperature: 0.7,
+						max_tokens: 100,
+						top_p: 0.9,
+					});
+					if (res.choices && res.choices.length > 0 && res.choices[0].message?.content) {
+						const completion = res.choices[0].message.content;
+						expect(completion).toBeTruthy();
+						expect(completion.length).toBeGreaterThan(0);
+					}
+				});
+
+				it("chat completion stream", async () => {
+					const stream = client.chatCompletionStream({
+						model: "meta-llama/Llama-3.2-3B-Instruct",
+						provider: "centml",
+						messages: [{ role: "user", content: "Say 'this is a test'" }],
+						stream: true,
+					}) as AsyncGenerator<ChatCompletionStreamOutput>;
+
+					let fullResponse = "";
+					for await (const chunk of stream) {
+						if (chunk.choices && chunk.choices.length > 0) {
+							const content = chunk.choices[0].delta?.content;
+							if (content) {
+								fullResponse += content;
+							}
+						}
+					}
+
+					expect(fullResponse).toBeTruthy();
+					expect(fullResponse.length).toBeGreaterThan(0);
+				});
+			});
+
+			describe("text generation", () => {
+				it("basic text generation", async () => {
+					const res = await client.textGeneration({
+						model: "meta-llama/Llama-3.2-3B",
+						provider: "centml",
+						inputs: "The capital of France is",
+					});
+					expect(res).toMatchObject({
+						generated_text: expect.stringContaining("Paris"),
+					});
+				});
+
+				it("text generation with parameters", async () => {
+					const res = await client.textGeneration({
+						model: "meta-llama/Llama-3.2-3B",
+						provider: "centml",
+						inputs: "Once upon a time",
+						parameters: {
+							max_new_tokens: 50,
+							temperature: 0.7,
+							top_p: 0.9,
+							do_sample: true,
+						},
+					});
+					expect(res).toMatchObject({
+						generated_text: expect.any(String),
+					});
+					expect(res.generated_text.length).toBeGreaterThan(0);
+				});
+
+				it("text generation stream", async () => {
+					const stream = client.textGenerationStream({
+						model: "meta-llama/Llama-3.2-3B",
+						provider: "centml",
+						inputs: "The future of AI is",
+					});
+
+					let fullResponse = "";
+					for await (const chunk of stream) {
+						if (chunk.token?.text) {
+							fullResponse += chunk.token.text;
+						}
+					}
+
+					expect(fullResponse).toBeTruthy();
+					expect(fullResponse.length).toBeGreaterThan(0);
+				});
+			});
+		},
+		TIMEOUT
+	);
 });

From 4942d493a885d0fb61beabaa663c937a198886fc Mon Sep 17 00:00:00 2001
From: Honglin Cao
Date: Tue, 29 Apr 2025 01:07:58 -0600
Subject: [PATCH 07/10] remove some override

---
 packages/inference/src/providers/centml.ts | 44 ----------------------
 1 file changed, 44 deletions(-)

diff --git a/packages/inference/src/providers/centml.ts b/packages/inference/src/providers/centml.ts
index 9a135b942b..7735a1980d 100644
--- a/packages/inference/src/providers/centml.ts
+++ b/packages/inference/src/providers/centml.ts
@@ -2,9 +2,6 @@
  * CentML provider implementation for serverless inference.
  * This provider supports chat completions and text generation through CentML's serverless endpoints.
  */
-import type { ChatCompletionOutput, TextGenerationOutput } from "@huggingface/tasks";
-import { InferenceOutputError } from "../lib/InferenceOutputError";
-import type { BodyParams } from "../types";
 import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";
 
 const CENTML_API_BASE_URL = "https://api.centml.com";
@@ -17,30 +14,6 @@ export class CentMLConversationalTask extends BaseConversationalTask {
 	override makeRoute(): string {
 		return "openai/v1/chat/completions";
 	}
-
-	override preparePayload(params: BodyParams): Record<string, unknown> {
-		const { args, model } = params;
-		return {
-			...args,
-			model,
-			api_key: args.accessToken,
-		};
-	}
-
-	override async getResponse(response: ChatCompletionOutput): Promise<ChatCompletionOutput> {
-		if (
-			typeof response === "object" &&
-			Array.isArray(response?.choices) &&
-			typeof response?.created === "number" &&
-			typeof response?.id === "string" &&
-			typeof response?.model === "string" &&
-			typeof response?.usage === "object"
-		) {
-			return response;
-		}
-
-		throw new InferenceOutputError("Expected ChatCompletionOutput");
-	}
 }
 
 export class CentMLTextGenerationTask extends BaseTextGenerationTask {
@@ -51,21 +24,4 @@ export class CentMLTextGenerationTask extends BaseTextGenerationTask {
 	override makeRoute(): string {
 		return "openai/v1/completions";
 	}
-
-	override preparePayload(params: BodyParams): Record<string, unknown> {
-		const { args, model } = params;
-		return {
-			...args,
-			model,
-			api_key: args.accessToken,
-		};
-	}
-
-	override async getResponse(response: TextGenerationOutput): Promise<TextGenerationOutput> {
-		if (typeof response === "object" && typeof response?.generated_text === "string") {
-			return response;
-		}
-
-		throw new InferenceOutputError("Expected TextGenerationOutput");
-	}
 }

From ac5b19586bf8da9322b1afca50d134200117d8d2 Mon Sep 17 00:00:00 2001
From: Honglin Cao
Date: Tue, 29 Apr 2025 01:21:17 -0600
Subject: [PATCH 08/10] remove textGen api, currently not officially supported via platform

---
 .../inference/src/lib/getProviderHelper.ts |  1 -
 packages/inference/src/providers/centml.ts | 10 ----
 .../inference/test/InferenceClient.spec.ts | 54 -------------
 3 files changed, 65 deletions(-)

diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index 8ec77a6c69..835fd9738c 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -58,7 +58,6 @@ export const PROVIDERS: Record<
 	centml: {
 		conversational: new CentML.CentMLConversationalTask(),
-		"text-generation": new CentML.CentMLTextGenerationTask(),
 	},
 	cohere: {
 		conversational: new Cohere.CohereConversationalTask(),
diff --git a/packages/inference/src/providers/centml.ts b/packages/inference/src/providers/centml.ts
index 7735a1980d..e721defc99 100644
--- a/packages/inference/src/providers/centml.ts
+++ b/packages/inference/src/providers/centml.ts
@@ -15,13 +15,3 @@ export class CentMLConversationalTask extends BaseConversationalTask {
 		return "openai/v1/chat/completions";
 	}
 }
-
-export class CentMLTextGenerationTask extends BaseTextGenerationTask {
-	constructor() {
-		super("centml", CENTML_API_BASE_URL);
-	}
-
-	override makeRoute(): string {
-		return "openai/v1/completions";
-	}
-}
diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index 0849810bb0..fa2570b402 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -1888,12 +1888,6 @@ describe("InferenceClient", () => {
 					status: "live",
 					task: "conversational",
 				},
-				"meta-llama/Llama-3.2-3B": {
-					hfModelId: "meta-llama/Llama-3.2-3B",
-					providerId: "meta-llama/Llama-3.2-3B",
-					status: "live",
-					task: "text-generation",
-				},
 			};
 
 			describe("chat completions", () => {
@@ -1965,54 +1959,6 @@ describe("InferenceClient", () => {
 					expect(fullResponse).toBeTruthy();
 					expect(fullResponse.length).toBeGreaterThan(0);
 				});
 			});
 
-			describe("text generation", () => {
-				it("basic text generation", async () => {
-					const res = await client.textGeneration({
-						model: "meta-llama/Llama-3.2-3B",
-						provider: "centml",
-						inputs: "The capital of France is",
-					});
-					expect(res).toMatchObject({
-						generated_text: expect.stringContaining("Paris"),
-					});
-				});
-
-				it("text generation with parameters", async () => {
-					const res = await client.textGeneration({
-						model: "meta-llama/Llama-3.2-3B",
-						provider: "centml",
-						inputs: "Once upon a time",
-						parameters: {
-							max_new_tokens: 50,
-							temperature: 0.7,
-							top_p: 0.9,
-							do_sample: true,
-						},
-					});
-					expect(res).toMatchObject({
-						generated_text: expect.any(String),
-					});
-					expect(res.generated_text.length).toBeGreaterThan(0);
-				});
-
-				it("text generation stream", async () => {
-					const stream = client.textGenerationStream({
-						model: "meta-llama/Llama-3.2-3B",
-						provider: "centml",
-						inputs: "The future of AI is",
-					});
-
-					let fullResponse = "";
-					for await (const chunk of stream) {
-						if (chunk.token?.text) {
-							fullResponse += chunk.token.text;
-						}
-					}
-
-					expect(fullResponse).toBeTruthy();
-					expect(fullResponse.length).toBeGreaterThan(0);
-				});
-			});
 		},
 		TIMEOUT
 	);

From 06f62e4975a5149201cbc91b77b77326aa82aa03 Mon Sep 17 00:00:00 2001
From: Honglin Cao
Date: Tue, 29 Apr 2025 01:22:35 -0600
Subject: [PATCH 09/10] skip tests

---
 packages/inference/test/InferenceClient.spec.ts | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index fa2570b402..eba56588d9 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -22,7 +22,7 @@ if (!env.HF_TOKEN) {
 	console.warn("Set HF_TOKEN in the env to run the tests for better rate limits");
 }
 
-describe("InferenceClient", () => {
+describe.skip("InferenceClient", () => {
 	// Individual tests can be ran without providing an api key, however running all tests without an api key will result in rate limiting error.
 
 	describe("backward compatibility", () => {

From 6270e1952b5978850f9410a79cf5a8c60b532624 Mon Sep 17 00:00:00 2001
From: Celina Hanouti
Date: Wed, 30 Apr 2025 12:21:10 +0200
Subject: [PATCH 10/10] format

---
 packages/inference/src/providers/centml.ts      | 2 +-
 packages/inference/test/InferenceClient.spec.ts | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/packages/inference/src/providers/centml.ts b/packages/inference/src/providers/centml.ts
index e721defc99..f1b7776a71 100644
--- a/packages/inference/src/providers/centml.ts
+++ b/packages/inference/src/providers/centml.ts
@@ -2,7 +2,7 @@
  * CentML provider implementation for serverless inference.
  * This provider supports chat completions and text generation through CentML's serverless endpoints.
  */
-import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";
+import { BaseConversationalTask } from "./providerHelper";
 
 const CENTML_API_BASE_URL = "https://api.centml.com";
 
diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index ec37a98c99..7b4a1035c6 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -2063,4 +2063,4 @@ describe.skip("InferenceClient", () => {
 		},
 		TIMEOUT
 	);
-});
\ No newline at end of file
+});
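---

Usage sketch: once this series is applied, callers opt into the new provider by passing `provider: "centml"` to the inference functions. This is a minimal illustration, assuming a valid CentML API key is available in an `HF_CENTML_KEY` environment variable (a hypothetical name mirroring the one read in the test suite above) and that `meta-llama/Llama-3.2-3B-Instruct` is mapped as live on CentML's side — neither is guaranteed by the patches themselves:

```ts
import { InferenceClient } from "@huggingface/inference";

// HF_CENTML_KEY is an assumed variable name, matching the test suite above.
const client = new InferenceClient(process.env.HF_CENTML_KEY ?? "");

// provider: "centml" routes the request through CentMLConversationalTask,
// i.e. POST https://api.centml.com/openai/v1/chat/completions.
const res = await client.chatCompletion({
	model: "meta-llama/Llama-3.2-3B-Instruct",
	provider: "centml",
	messages: [{ role: "user", content: "Complete this sentence: one plus one is equal " }],
});

console.log(res.choices[0]?.message?.content);
```

Note that after PATCH 08/10 only the conversational task is registered for CentML, so `textGeneration` calls will not resolve a CentML helper; chat completion is the supported path.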