From af93493f93c7cbca2beb96cd8ae52a26bfa398e0 Mon Sep 17 00:00:00 2001 From: gioelecerati <50955448+gioelecerati@users.noreply.github.com> Date: Thu, 3 Oct 2024 10:15:30 +0200 Subject: [PATCH] ai: generate: remove endpoint from experiment & remove beta from path (#2318) * ai: generate: remove endpoint from experiment & remove beta from path * revert broken schema * remove beta path from schema * generate schema & remove unused middleware --- packages/api/src/controllers/generate.test.ts | 50 ++++--- packages/api/src/controllers/generate.ts | 2 - packages/api/src/controllers/index.ts | 2 + packages/api/src/schema/ai-api-schema.yaml | 137 +++++++++++++++++- packages/api/src/schema/pull-ai-schema.js | 4 +- 5 files changed, 164 insertions(+), 31 deletions(-) diff --git a/packages/api/src/controllers/generate.test.ts b/packages/api/src/controllers/generate.test.ts index 4b2c0f078..c58ca4f00 100644 --- a/packages/api/src/controllers/generate.test.ts +++ b/packages/api/src/controllers/generate.test.ts @@ -71,6 +71,16 @@ afterEach(async () => { await clearDatabase(server); }); +const testBothRoutes = (testFn) => { + describe("generate route", () => { + testFn("/generate"); + }); + + describe("beta generate route", () => { + testFn("/beta/generate"); + }); +}; + describe("controllers/generate", () => { let client: TestClient; let adminUser: User; @@ -145,9 +155,9 @@ describe("controllers/generate", () => { return form; }; - describe("API proxies", () => { - it("should call the AI Gateway for generate API /audio-to-text", async () => { - const res = await client.fetch("/beta/generate/audio-to-text", { + testBothRoutes((basePath) => { + it(`should call the AI Gateway for ${basePath}/audio-to-text`, async () => { + const res = await client.fetch(`${basePath}/audio-to-text`, { method: "POST", body: buildMultipartBody( {}, @@ -162,8 +172,8 @@ describe("controllers/generate", () => { expect(aiGatewayCalls).toEqual({ "audio-to-text": 1 }); }); - it("should call the AI Gateway for generate API /text-to-image", async () => { - const res = await client.post("/beta/generate/text-to-image", { + it(`should call the AI Gateway for ${basePath}/text-to-image`, async () => { + const res = await client.post(`${basePath}/text-to-image`, { prompt: "a man in a suit and tie", }); expect(res.status).toBe(200); @@ -174,8 +184,8 @@ describe("controllers/generate", () => { expect(aiGatewayCalls).toEqual({ "text-to-image": 1 }); }); - it("should call the AI Gateway for generate API /image-to-image", async () => { - const res = await client.fetch("/beta/generate/image-to-image", { + it(`should call the AI Gateway for ${basePath}/image-to-image`, async () => { + const res = await client.fetch(`${basePath}/image-to-image`, { method: "POST", body: buildMultipartBody({ prompt: "replace the suit with a bathing suit", @@ -189,8 +199,8 @@ describe("controllers/generate", () => { expect(aiGatewayCalls).toEqual({ "image-to-image": 1 }); }); - it("should call the AI Gateway for generate API /image-to-video", async () => { - const res = await client.fetch("/beta/generate/image-to-video", { + it(`should call the AI Gateway for ${basePath}/image-to-video`, async () => { + const res = await client.fetch(`${basePath}/image-to-video`, { method: "POST", body: buildMultipartBody({}), }); @@ -202,8 +212,8 @@ describe("controllers/generate", () => { expect(aiGatewayCalls).toEqual({ "image-to-video": 1 }); }); - it("should call the AI Gateway for generate API /upscale", async () => { - const res = await client.fetch("/beta/generate/upscale", { + it(`should call the AI Gateway for ${basePath}/upscale`, async () => { + const res = await client.fetch(`${basePath}/upscale`, { method: "POST", body: buildMultipartBody({ prompt: "enhance" }), }); @@ -215,8 +225,8 @@ describe("controllers/generate", () => { expect(aiGatewayCalls).toEqual({ upscale: 1 }); }); - it("should call the AI Gateway for generate API /segment-anything-2", async () => { - const res = await client.fetch("/beta/generate/segment-anything-2", { + it(`should call the AI Gateway for ${basePath}/segment-anything-2`, async () => { + const res = await client.fetch(`${basePath}/segment-anything-2`, { method: "POST", body: buildMultipartBody({}), }); @@ -260,7 +270,7 @@ describe("controllers/generate", () => { for (const [title, input, error] of testCases) { it(title, async () => { - const res = await client.fetch("/beta/generate/image-to-image", { + const res = await client.fetch("/generate/image-to-image", { method: "POST", body: input, }); @@ -287,7 +297,7 @@ describe("controllers/generate", () => { } it("should log all requests to db", async () => { - const res = await client.post("/beta/generate/text-to-image", { + const res = await client.post("/generate/text-to-image", { prompt: "a man in a suit and tie", }); expect(res.status).toBe(200); @@ -325,7 +335,7 @@ describe("controllers/generate", () => { `{"details":{"msg":"sudden error"}}`, ); - const res = await client.post("/beta/generate/text-to-image", { + const res = await client.post("/generate/text-to-image", { prompt: "a man in a suit and tie", }); expect(res.status).toBe(500); @@ -345,7 +355,7 @@ describe("controllers/generate", () => { it("should log non JSON outputs as strings to db", async () => { mockFetchHttpError(418, "text/plain", `I'm not Jason`); - const res = await client.post("/beta/generate/text-to-image", { + const res = await client.post("/generate/text-to-image", { prompt: "a man in a suit and tie", }); expect(res.status).toBe(418); @@ -364,7 +374,7 @@ describe("controllers/generate", () => { mockedFetchWithTimeout.mockImplementation(() => { throw new Error("on your face"); }); - const res = await client.post("/beta/generate/text-to-image", { + const res = await client.post("/generate/text-to-image", { prompt: "a man in a suit and tie", }); expect(res.status).toBe(500); @@ -394,10 +404,10 @@ describe("controllers/generate", () => { const makeAiGenReq = (pipeline: (typeof pipelines)[number]) => pipeline === "text-to-image" - ? client.post(`/beta/generate/${pipeline}`, { + ? client.post(`/generate/${pipeline}`, { prompt: "whatever you feel like", }) - : client.fetch(`/beta/generate/${pipeline}`, { + : client.fetch(`/generate/${pipeline}`, { method: "POST", body: buildMultipartBody( pipeline === "image-to-video" ? {} : { prompt: "make magic" }, diff --git a/packages/api/src/controllers/generate.ts b/packages/api/src/controllers/generate.ts index f9a02c85e..3016e14b2 100644 --- a/packages/api/src/controllers/generate.ts +++ b/packages/api/src/controllers/generate.ts @@ -35,8 +35,6 @@ const aiGenerateDurationMetric = new promclient.Histogram({ const app = Router(); -app.use(experimentSubjectsOnly("ai-generate")); - const rateLimiter: RequestHandler = async (req, res, next) => { const now = Date.now(); const [[{ count, min }]] = await db.aiGenerateLog.find( diff --git a/packages/api/src/controllers/index.ts b/packages/api/src/controllers/index.ts index f9863cd7f..bb0332a6a 100644 --- a/packages/api/src/controllers/index.ts +++ b/packages/api/src/controllers/index.ts @@ -34,6 +34,8 @@ export default { "api-token": apiToken, asset, auth, + generate, + // TODO: Remove beta paths "beta/generate": generate, broadcaster, clip, diff --git a/packages/api/src/schema/ai-api-schema.yaml b/packages/api/src/schema/ai-api-schema.yaml index 7dbd0284d..fd3f90288 100644 --- a/packages/api/src/schema/ai-api-schema.yaml +++ b/packages/api/src/schema/ai-api-schema.yaml @@ -1,6 +1,6 @@ openapi: 3.1.0 paths: - /api/beta/generate/text-to-image: + /api/generate/text-to-image: post: tags: - generate @@ -60,7 +60,7 @@ paths: schema: $ref: '#/components/schemas/studio-api-error' x-speakeasy-name-override: textToImage - /api/beta/generate/image-to-image: + /api/generate/image-to-image: post: tags: - generate @@ -120,7 +120,7 @@ paths: schema: $ref: '#/components/schemas/studio-api-error' x-speakeasy-name-override: imageToImage - /api/beta/generate/image-to-video: + /api/generate/image-to-video: post: tags: - generate @@ -180,7 +180,7 @@ paths: schema: $ref: '#/components/schemas/studio-api-error' x-speakeasy-name-override: imageToVideo - /api/beta/generate/upscale: + /api/generate/upscale: post: tags: - generate @@ -240,7 +240,7 @@ paths: schema: $ref: '#/components/schemas/studio-api-error' x-speakeasy-name-override: upscale - /api/beta/generate/audio-to-text: + /api/generate/audio-to-text: post: tags: - generate @@ -308,7 +308,7 @@ paths: schema: $ref: '#/components/schemas/studio-api-error' x-speakeasy-name-override: audioToText - /api/beta/generate/segment-anything-2: + /api/generate/segment-anything-2: post: tags: - generate @@ -368,6 +368,65 @@ paths: schema: $ref: '#/components/schemas/studio-api-error' x-speakeasy-name-override: segmentAnything2 + /api/generate/llm: + post: + tags: + - generate + summary: LLM + description: Generate text using a language model. + operationId: genLLM + requestBody: + content: + application/x-www-form-urlencoded: + schema: + $ref: '#/components/schemas/Body_genLLM' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/LLMResponse' + '400': + description: Bad Request + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + '401': + description: Unauthorized + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + '422': + description: Validation Error + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPValidationError' + - $ref: '#/components/schemas/studio-api-error' + '500': + description: Internal Server Error + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/HTTPError' + - $ref: '#/components/schemas/studio-api-error' + default: + description: Error + content: + application/json: + schema: + $ref: '#/components/schemas/studio-api-error' + x-speakeasy-name-override: llm components: schemas: APIError: @@ -414,6 +473,14 @@ components: title: Model Id description: Hugging Face model ID used for image generation. default: timbrooks/instruct-pix2pix + loras: + type: string + title: Loras + description: >- + A LoRA (Low-Rank Adaptation) model and its corresponding weight for + image generation. Example: { "latent-consistency/lcm-lora-sdxl": + 1.0, "nerijs/pixel-art-xl": 1.2}. + default: '' strength: type: number title: Strength @@ -533,6 +600,41 @@ components: - image title: Body_genImageToVideo additionalProperties: false + Body_genLLM: + properties: + prompt: + type: string + title: Prompt + model_id: + type: string + title: Model Id + default: '' + system_msg: + type: string + title: System Msg + default: '' + temperature: + type: number + title: Temperature + default: 0.7 + max_tokens: + type: integer + title: Max Tokens + default: 256 + history: + type: string + title: History + default: '[]' + stream: + type: boolean + title: Stream + default: false + type: object + required: + - prompt + - model_id + title: Body_genLLM + additionalProperties: false Body_genSegmentAnything2: properties: image: @@ -544,7 +646,7 @@ components: type: string title: Model Id description: Hugging Face model ID used for image generation. - default: 'facebook/sam2-hiera-large' + default: facebook/sam2-hiera-large point_coords: type: string title: Point Coords @@ -667,6 +769,19 @@ components: - images title: ImageResponse description: Response model for image generation. + LLMResponse: + properties: + response: + type: string + title: Response + tokens_used: + type: integer + title: Tokens Used + type: object + required: + - response + - tokens_used + title: LLMResponse MasksResponse: properties: masks: @@ -734,6 +849,14 @@ components: title: Model Id description: Hugging Face model ID used for image generation. default: SG161222/RealVisXL_V4.0_Lightning + loras: + type: string + title: Loras + description: >- + A LoRA (Low-Rank Adaptation) model and its corresponding weight for + image generation. Example: { "latent-consistency/lcm-lora-sdxl": + 1.0, "nerijs/pixel-art-xl": 1.2}. + default: '' prompt: type: string title: Prompt diff --git a/packages/api/src/schema/pull-ai-schema.js b/packages/api/src/schema/pull-ai-schema.js index 02645051e..9b95937ef 100644 --- a/packages/api/src/schema/pull-ai-schema.js +++ b/packages/api/src/schema/pull-ai-schema.js @@ -62,8 +62,8 @@ const downloadAiSchema = async () => { // patches to the paths section schema.paths = mapObject(schema.paths, (path, value) => { - // prefix paths with /api/beta/generate - path = `/api/beta/generate${path}`; + // prefix paths with /api/generate + path = `/api/generate${path}`; // remove security field delete value.post.security; // add Studio API error as oneOf to all of the error responses