diff --git a/packages/inference/README.md b/packages/inference/README.md index db3f64b35f..98208262c9 100644 --- a/packages/inference/README.md +++ b/packages/inference/README.md @@ -61,6 +61,7 @@ Currently, we support the following providers: - [Sambanova](https://sambanova.ai) - [Scaleway](https://www.scaleway.com/en/generative-apis/) - [Clarifai](http://clarifai.com) +- [CometAPI](https://www.cometapi.com/?utm_source=huggingface&utm_campaign=integration&utm_medium=integration&utm_content=integration) - [Together](https://together.xyz) - [Baseten](https://baseten.co) - [Blackforestlabs](https://blackforestlabs.ai) @@ -98,6 +99,7 @@ Only a subset of models are supported when requesting third-party providers. You - [Replicate supported models](https://huggingface.co/api/partners/replicate/models) - [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models) - [Scaleway supported models](https://huggingface.co/api/partners/scaleway/models) +- [CometAPI supported models](https://huggingface.co/api/partners/cometapi/models) - [Together supported models](https://huggingface.co/api/partners/together/models) - [Baseten supported models](https://huggingface.co/api/partners/baseten/models) - [Clarifai supported models](https://huggingface.co/api/partners/clarifai/models) diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts index ee39f50342..3d70686052 100644 --- a/packages/inference/src/lib/getProviderHelper.ts +++ b/packages/inference/src/lib/getProviderHelper.ts @@ -3,6 +3,7 @@ import * as Clarifai from "../providers/clarifai.js"; import * as BlackForestLabs from "../providers/black-forest-labs.js"; import * as Cerebras from "../providers/cerebras.js"; import * as Cohere from "../providers/cohere.js"; +import * as CometAPI from "../providers/cometapi.js"; import * as FalAI from "../providers/fal-ai.js"; import * as FeatherlessAI from "../providers/featherless-ai.js"; import * as Fireworks from "../providers/fireworks-ai.js"; @@ -72,6 +73,11 @@ export const PROVIDERS: Record CometAPI model ID here: + * + * https://huggingface.co/api/partners/cometapi/models + * + * This is a publicly available mapping. + * + * If you want to try to run inference for a new model locally before it's registered on huggingface.co, + * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes. + * + * - If you work at CometAPI and want to update this mapping, please use the model mapping API we provide on huggingface.co + * - If you're a community member and want to add a new supported HF model to CometAPI, please open an issue on the present repo + * and we will tag CometAPI team members. + * + * Thanks! + */ +import type { FeatureExtractionOutput, TextGenerationOutput } from "@huggingface/tasks"; +import type { BodyParams } from "../types.js"; +import { InferenceClientProviderOutputError } from "../errors.js"; + +import type { FeatureExtractionTaskHelper } from "./providerHelper.js"; +import { BaseConversationalTask, TaskProviderHelper, BaseTextGenerationTask } from "./providerHelper.js"; + +const COMETAPI_API_BASE_URL = "https://api.cometapi.com/v1"; + +interface CometAPIEmbeddingsResponse { + data: Array<{ + embedding: number[]; + }>; +} + +export class CometAPIConversationalTask extends BaseConversationalTask { + constructor() { + super("cometapi", COMETAPI_API_BASE_URL); + } +} + +export class CometAPITextGenerationTask extends BaseTextGenerationTask { + constructor() { + super("cometapi", COMETAPI_API_BASE_URL); + } + + override preparePayload(params: BodyParams): Record { + return { + model: params.model, + ...params.args, + prompt: params.args.inputs, + }; + } + + override async getResponse(response: unknown): Promise { + if ( + typeof response === "object" && + response !== null && + "choices" in response && + Array.isArray(response.choices) && + response.choices.length > 0 + ) { + const completion: unknown = response.choices[0]; + if ( + typeof completion === "object" && + !!completion && + "text" in completion && + completion.text && + typeof completion.text === "string" + ) { + return { + generated_text: completion.text, + }; + } + } + throw new InferenceClientProviderOutputError("Received malformed response from CometAPI text generation API"); + } +} + +export class CometAPIFeatureExtractionTask extends TaskProviderHelper implements FeatureExtractionTaskHelper { + constructor() { + super("cometapi", COMETAPI_API_BASE_URL); + } + + preparePayload(params: BodyParams): Record { + return { + input: params.args.inputs, + model: params.model, + }; + } + + makeRoute(): string { + return "v1/embeddings"; + } + + async getResponse(response: CometAPIEmbeddingsResponse): Promise { + return response.data.map((item) => item.embedding); + } +} diff --git a/packages/inference/src/providers/consts.ts b/packages/inference/src/providers/consts.ts index f93d890535..620280e5a6 100644 --- a/packages/inference/src/providers/consts.ts +++ b/packages/inference/src/providers/consts.ts @@ -23,6 +23,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record< cerebras: {}, clarifai: {}, cohere: {}, + cometapi: {}, "fal-ai": {}, "featherless-ai": {}, "fireworks-ai": {}, diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts index 215252efda..f476a86581 100644 --- a/packages/inference/src/types.ts +++ b/packages/inference/src/types.ts @@ -50,6 +50,7 @@ export const INFERENCE_PROVIDERS = [ "cerebras", "clarifai", "cohere", + "cometapi", "fal-ai", "featherless-ai", "fireworks-ai", diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts index 3ea6b55241..be87a5245d 100644 --- a/packages/inference/test/InferenceClient.spec.ts +++ b/packages/inference/test/InferenceClient.spec.ts @@ -1626,6 +1626,84 @@ describe.skip("InferenceClient", () => { TIMEOUT ); + describe.concurrent( + "CometAPI", + () => { + const client = new InferenceClient(env.COMETAPI_KEY ?? "dummy"); + + HARDCODED_MODEL_INFERENCE_MAPPING.cometapi = { + "openai/gpt-4o-mini": { + provider: "cometapi", + hfModelId: "openai/gpt-4o-mini", + providerId: "gpt-4o-mini", + status: "live", + task: "conversational", + }, + "openai/text-embedding-3-small": { + provider: "cometapi", + hfModelId: "openai/text-embedding-3-small", + providerId: "text-embedding-3-small", + task: "feature-extraction", + status: "live", + }, + }; + + it("chatCompletion", async () => { + const res = await client.chatCompletion({ + model: "openai/gpt-4o-mini", + provider: "cometapi", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + tool_choice: "none", + }); + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toMatch(/(to )?(two|2)/i); + } + }); + + it("chatCompletion stream", async () => { + const stream = client.chatCompletionStream({ + model: "openai/gpt-4o-mini", + provider: "cometapi", + messages: [{ role: "system", content: "Complete the equation 1 + 1 = , just the answer" }], + }) as AsyncGenerator; + let out = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + out += chunk.choices[0].delta.content; + } + } + expect(out).toMatch(/(two|2)/i); + }); + + it("textGeneration", async () => { + const res = await client.textGeneration({ + model: "openai/gpt-4o-mini", + provider: "cometapi", + inputs: "Once upon a time,", + temperature: 0, + max_tokens: 20, + }); + + expect(res).toHaveProperty("generated_text"); + expect(typeof res.generated_text).toBe("string"); + expect(res.generated_text.length).toBeGreaterThan(0); + }); + + it("featureExtraction", async () => { + const res = await client.featureExtraction({ + model: "openai/text-embedding-3-small", + provider: "cometapi", + inputs: "That is a happy person", + }); + + expect(res).toBeInstanceOf(Array); + expect(res[0]).toEqual(expect.arrayContaining([expect.any(Number)])); + }); + }, + TIMEOUT + ); + describe.concurrent("3rd party providers", () => { it("chatCompletion - fails with unsupported model", async () => { expect(