Skip to content

Commit

Permalink
Add AI LLM endpoint (#2319)
Browse files Browse the repository at this point in the history
* Add AI LLM endpoint

* Fix type generation

* fix validator lookup and test
  • Loading branch information
mjh1 authored Oct 7, 2024
1 parent af93493 commit d25ae47
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 5 deletions.
25 changes: 22 additions & 3 deletions packages/api/src/controllers/generate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ describe("controllers/generate", () => {
"image-to-video",
"upscale",
"segment-anything-2",
"llm",
];
for (const api of apis) {
aiGatewayServer.app.post(`/${api}`, async (req, res) => {
Expand Down Expand Up @@ -145,13 +146,18 @@ describe("controllers/generate", () => {
textFields: Record<string, any>,
multipartField = { name: "image", contentType: "image/png" },
) => {
const form = buildForm(textFields);
form.append(multipartField.name, "dummy", {
contentType: multipartField.contentType,
});
return form;
};

const buildForm = (textFields: Record<string, any>) => {
const form = new FormData();
for (const [k, v] of Object.entries(textFields)) {
form.append(k, v);
}
form.append(multipartField.name, "dummy", {
contentType: multipartField.contentType,
});
return form;
};

Expand Down Expand Up @@ -237,6 +243,19 @@ describe("controllers/generate", () => {
});
expect(aiGatewayCalls).toEqual({ "segment-anything-2": 1 });
});

it("should call the AI Gateway for generate API /llm", async () => {
const res = await client.fetch("/beta/generate/llm", {
method: "POST",
body: buildForm({ prompt: "foo" }),
});
expect(res.status).toBe(200);
expect(await res.json()).toEqual({
message: "success",
reqContentType: expect.stringMatching("^multipart/form-data"),
});
expect(aiGatewayCalls).toEqual({ llm: 1 });
});
});

describe("validates multipart schema", () => {
Expand Down
5 changes: 5 additions & 0 deletions packages/api/src/controllers/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import { BadRequestError } from "../store/errors";
import { fetchWithTimeout, kebabToCamel } from "../util";
import { experimentSubjectsOnly } from "./experiment";
import { pathJoin2 } from "./helpers";
import validators from "../schema/validators";

const AI_GATEWAY_TIMEOUT = 10 * 60 * 1000; // 10 minutes
const RATE_LIMIT_WINDOW = 60 * 1000; // 1 minute
Expand Down Expand Up @@ -179,6 +180,9 @@ function registerGenerateHandler(
if (isJSONReq) {
payloadParsers = [validatePost(`${camelType}Params`)];
} else {
if (!validators[`Body_gen${camelType}`]) {
camelType = type.toUpperCase();
}
payloadParsers = [
multipart.any(),
validateFormData(`Body_gen${camelType}`),
Expand Down Expand Up @@ -252,5 +256,6 @@ registerGenerateHandler("image-to-video");
registerGenerateHandler("upscale");
registerGenerateHandler("audio-to-text");
registerGenerateHandler("segment-anything-2");
registerGenerateHandler("llm");

export default app;
3 changes: 1 addition & 2 deletions packages/api/src/schema/ai-api-schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ components:
model_id:
type: string
title: Model Id
default: ''
default: meta-llama/Meta-Llama-3.1-8B-Instruct
system_msg:
type: string
title: System Msg
Expand All @@ -632,7 +632,6 @@ components:
type: object
required:
- prompt
- model_id
title: Body_genLLM
additionalProperties: false
Body_genSegmentAnything2:
Expand Down
4 changes: 4 additions & 0 deletions packages/api/src/schema/compile-schemas.js
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ function removeAllTitles(schema) {
}
}

if (schema.oneOf && Array.isArray(schema.oneOf)) {
schema.oneOf = schema.oneOf.map((item) => removeAllTitles(item));
}

return schema;
}

Expand Down
2 changes: 2 additions & 0 deletions packages/api/src/schema/db-schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1454,13 +1454,15 @@ components:
- image-to-video
- upscale
- segment-anything-2
- llm
request:
oneOf:
- $ref: "./ai-api-schema.yaml#/components/schemas/TextToImageParams"
- $ref: "./ai-api-schema.yaml#/components/schemas/Body_genImageToImage"
- $ref: "./ai-api-schema.yaml#/components/schemas/Body_genImageToVideo"
- $ref: "./ai-api-schema.yaml#/components/schemas/Body_genUpscale"
- $ref: "./ai-api-schema.yaml#/components/schemas/Body_genSegmentAnything2"
- $ref: "./ai-api-schema.yaml#/components/schemas/Body_genLLM"
statusCode:
type: integer
description: HTTP status code received from the AI gateway
Expand Down
1 change: 1 addition & 0 deletions packages/api/src/schema/pull-ai-schema.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export const defaultModels = {
upscale: "stabilityai/stable-diffusion-x4-upscaler",
"audio-to-text": "openai/whisper-large-v3",
"segment-anything-2": "facebook/sam2-hiera-large",
llm: "meta-llama/Meta-Llama-3.1-8B-Instruct",
};
const schemaDir = path.resolve(__dirname, ".");
const aiSchemaUrl =
Expand Down

0 comments on commit d25ae47

Please sign in to comment.