Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Currently, we support the following providers:
- [Sambanova](https://sambanova.ai)
- [Scaleway](https://www.scaleway.com/en/generative-apis/)
- [Clarifai](https://clarifai.com)
- [CometAPI](https://www.cometapi.com/?utm_source=huggingface&utm_campaign=integration&utm_medium=integration&utm_content=integration)
- [Together](https://together.xyz)
- [Baseten](https://baseten.co)
- [Blackforestlabs](https://blackforestlabs.ai)
Expand Down Expand Up @@ -98,6 +99,7 @@ Only a subset of models are supported when requesting third-party providers. You
- [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
- [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
- [Scaleway supported models](https://huggingface.co/api/partners/scaleway/models)
- [CometAPI supported models](https://huggingface.co/api/partners/cometapi/models)
- [Together supported models](https://huggingface.co/api/partners/together/models)
- [Baseten supported models](https://huggingface.co/api/partners/baseten/models)
- [Clarifai supported models](https://huggingface.co/api/partners/clarifai/models)
Expand Down
6 changes: 6 additions & 0 deletions packages/inference/src/lib/getProviderHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import * as Clarifai from "../providers/clarifai.js";
import * as BlackForestLabs from "../providers/black-forest-labs.js";
import * as Cerebras from "../providers/cerebras.js";
import * as Cohere from "../providers/cohere.js";
import * as CometAPI from "../providers/cometapi.js";
import * as FalAI from "../providers/fal-ai.js";
import * as FeatherlessAI from "../providers/featherless-ai.js";
import * as Fireworks from "../providers/fireworks-ai.js";
Expand Down Expand Up @@ -72,6 +73,11 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
cohere: {
conversational: new Cohere.CohereConversationalTask(),
},
cometapi: {
conversational: new CometAPI.CometAPIConversationalTask(),
"text-generation": new CometAPI.CometAPITextGenerationTask(),
"feature-extraction": new CometAPI.CometAPIFeatureExtractionTask(),
},
"fal-ai": {
"text-to-image": new FalAI.FalAITextToImageTask(),
"text-to-speech": new FalAI.FalAITextToSpeechTask(),
Expand Down
95 changes: 95 additions & 0 deletions packages/inference/src/providers/cometapi.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/**
* See the registered mapping of HF model ID => CometAPI model ID here:
*
* https://huggingface.co/api/partners/cometapi/models
*
* This is a publicly available mapping.
*
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
*
* - If you work at CometAPI and want to update this mapping, please use the model mapping API we provide on huggingface.co
* - If you're a community member and want to add a new supported HF model to CometAPI, please open an issue on the present repo
* and we will tag CometAPI team members.
*
* Thanks!
*/
import type { FeatureExtractionOutput, TextGenerationOutput } from "@huggingface/tasks";
import type { BodyParams } from "../types.js";
import { InferenceClientProviderOutputError } from "../errors.js";

import type { FeatureExtractionTaskHelper } from "./providerHelper.js";
import { BaseConversationalTask, TaskProviderHelper, BaseTextGenerationTask } from "./providerHelper.js";

// No trailing "/v1" here: every route appended to this base already carries the
// "v1/" prefix (the shared conversational/text-generation helpers use
// "v1/chat/completions", and this file's makeRoute() returns "v1/embeddings"),
// so including it in the base URL would yield ".../v1/v1/..." request URLs.
const COMETAPI_API_BASE_URL = "https://api.cometapi.com";

/** Shape of the OpenAI-compatible embeddings response returned by CometAPI. */
interface CometAPIEmbeddingsResponse {
	data: Array<{
		embedding: number[];
	}>;
}

/**
 * Chat-completions ("conversational") task for CometAPI.
 *
 * All request building and response parsing is inherited from
 * BaseConversationalTask; this subclass only binds the provider name and
 * base URL for CometAPI's OpenAI-compatible API.
 */
export class CometAPIConversationalTask extends BaseConversationalTask {
	constructor() {
		super("cometapi", COMETAPI_API_BASE_URL);
	}
}

/**
 * Text-generation task for CometAPI, built on the provider's
 * OpenAI-compatible completions endpoint.
 */
export class CometAPITextGenerationTask extends BaseTextGenerationTask {
	constructor() {
		super("cometapi", COMETAPI_API_BASE_URL);
	}

	override preparePayload(params: BodyParams): Record<string, unknown> {
		// Map the HF "inputs" field onto the OpenAI-style "prompt" field while
		// forwarding the remaining user args untouched. "model" is spread first on
		// purpose so an explicit `args.model` from the caller takes precedence.
		return {
			model: params.model,
			...params.args,
			prompt: params.args.inputs,
		};
	}

	/**
	 * Extracts `choices[0].text` from the provider response.
	 *
	 * @throws InferenceClientProviderOutputError if the payload does not match
	 *   the expected OpenAI completions shape.
	 */
	override async getResponse(response: unknown): Promise<TextGenerationOutput> {
		// Narrow the untyped payload step by step instead of trusting the provider.
		if (
			typeof response === "object" &&
			response !== null &&
			"choices" in response &&
			Array.isArray(response.choices) &&
			response.choices.length > 0
		) {
			const completion: unknown = response.choices[0];
			if (
				typeof completion === "object" &&
				!!completion &&
				"text" in completion &&
				// Accept any string, including "" — an empty completion is a valid
				// (if unhelpful) result, not a malformed response. The previous
				// truthiness check (`completion.text && …`) wrongly rejected it.
				typeof completion.text === "string"
			) {
				return {
					generated_text: completion.text,
				};
			}
		}
		throw new InferenceClientProviderOutputError("Received malformed response from CometAPI text generation API");
	}
}

export class CometAPIFeatureExtractionTask extends TaskProviderHelper implements FeatureExtractionTaskHelper {
constructor() {
super("cometapi", COMETAPI_API_BASE_URL);
}

preparePayload(params: BodyParams): Record<string, unknown> {
return {
input: params.args.inputs,
model: params.model,
};
}

makeRoute(): string {
return "v1/embeddings";
}

async getResponse(response: CometAPIEmbeddingsResponse): Promise<FeatureExtractionOutput> {
return response.data.map((item) => item.embedding);
}
}
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
cerebras: {},
clarifai: {},
cohere: {},
cometapi: {},
"fal-ai": {},
"featherless-ai": {},
"fireworks-ai": {},
Expand Down
1 change: 1 addition & 0 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ export const INFERENCE_PROVIDERS = [
"cerebras",
"clarifai",
"cohere",
"cometapi",
"fal-ai",
"featherless-ai",
"fireworks-ai",
Expand Down
78 changes: 78 additions & 0 deletions packages/inference/test/InferenceClient.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1626,6 +1626,84 @@ describe.skip("InferenceClient", () => {
TIMEOUT
);

// Live integration tests for the CometAPI provider. They run inside the
// enclosing (currently skipped) "InferenceClient" suite and hit the real API,
// so a COMETAPI_KEY environment variable is required to exercise them.
describe.concurrent(
"CometAPI",
() => {
// "dummy" fallback lets the suite be constructed without a key present.
const client = new InferenceClient(env.COMETAPI_KEY ?? "dummy");

// Register hardcoded HF model ID -> CometAPI provider ID mappings so the
// tests don't depend on the huggingface.co model-mapping API being populated.
HARDCODED_MODEL_INFERENCE_MAPPING.cometapi = {
"openai/gpt-4o-mini": {
provider: "cometapi",
hfModelId: "openai/gpt-4o-mini",
providerId: "gpt-4o-mini",
status: "live",
// NOTE(review): this mapping is tagged "conversational" but is also used
// by the textGeneration test below — confirm the client accepts that
// task pairing.
task: "conversational",
},
"openai/text-embedding-3-small": {
provider: "cometapi",
hfModelId: "openai/text-embedding-3-small",
providerId: "text-embedding-3-small",
task: "feature-extraction",
status: "live",
},
};

it("chatCompletion", async () => {
const res = await client.chatCompletion({
model: "openai/gpt-4o-mini",
provider: "cometapi",
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
tool_choice: "none",
});
// Only asserts when choices came back; an empty choices array passes silently.
if (res.choices && res.choices.length > 0) {
const completion = res.choices[0].message?.content;
expect(completion).toMatch(/(to )?(two|2)/i);
}
});

it("chatCompletion stream", async () => {
const stream = client.chatCompletionStream({
model: "openai/gpt-4o-mini",
provider: "cometapi",
messages: [{ role: "system", content: "Complete the equation 1 + 1 = , just the answer" }],
}) as AsyncGenerator<ChatCompletionStreamOutput>;
// Accumulate the streamed delta fragments into one string before asserting.
let out = "";
for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0) {
out += chunk.choices[0].delta.content;
}
}
expect(out).toMatch(/(two|2)/i);
});

it("textGeneration", async () => {
// temperature 0 keeps the output as deterministic as the API allows.
const res = await client.textGeneration({
model: "openai/gpt-4o-mini",
provider: "cometapi",
inputs: "Once upon a time,",
temperature: 0,
max_tokens: 20,
});

expect(res).toHaveProperty("generated_text");
expect(typeof res.generated_text).toBe("string");
expect(res.generated_text.length).toBeGreaterThan(0);
});

it("featureExtraction", async () => {
const res = await client.featureExtraction({
model: "openai/text-embedding-3-small",
provider: "cometapi",
inputs: "That is a happy person",
});

// Expect a nested numeric array (one embedding vector per input).
expect(res).toBeInstanceOf(Array);
expect(res[0]).toEqual(expect.arrayContaining([expect.any(Number)]));
});
},
TIMEOUT
);

describe.concurrent("3rd party providers", () => {
it("chatCompletion - fails with unsupported model", async () => {
expect(
Expand Down