Skip to content

Commit

Permalink
ai: generate: remove endpoint from experiment & remove beta from path (
Browse files Browse the repository at this point in the history
…#2318)

* ai: generate: remove endpoint from experiment & remove beta from path

* revert broken schema

* remove beta path from schema

* generate schema & remove unused middleware
  • Loading branch information
gioelecerati authored Oct 3, 2024
1 parent 6715c2a commit af93493
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 31 deletions.
50 changes: 30 additions & 20 deletions packages/api/src/controllers/generate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ afterEach(async () => {
await clearDatabase(server);
});

const testBothRoutes = (testFn) => {
describe("generate route", () => {
testFn("/generate");
});

describe("beta generate route", () => {
testFn("/beta/generate");
});
};

describe("controllers/generate", () => {
let client: TestClient;
let adminUser: User;
Expand Down Expand Up @@ -145,9 +155,9 @@ describe("controllers/generate", () => {
return form;
};

describe("API proxies", () => {
it("should call the AI Gateway for generate API /audio-to-text", async () => {
const res = await client.fetch("/beta/generate/audio-to-text", {
testBothRoutes((basePath) => {
it(`should call the AI Gateway for ${basePath}/audio-to-text`, async () => {
const res = await client.fetch(`${basePath}/audio-to-text`, {
method: "POST",
body: buildMultipartBody(
{},
Expand All @@ -162,8 +172,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ "audio-to-text": 1 });
});

it("should call the AI Gateway for generate API /text-to-image", async () => {
const res = await client.post("/beta/generate/text-to-image", {
it(`should call the AI Gateway for ${basePath}/text-to-image`, async () => {
const res = await client.post(`${basePath}/text-to-image`, {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(200);
Expand All @@ -174,8 +184,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ "text-to-image": 1 });
});

it("should call the AI Gateway for generate API /image-to-image", async () => {
const res = await client.fetch("/beta/generate/image-to-image", {
it(`should call the AI Gateway for ${basePath}/image-to-image`, async () => {
const res = await client.fetch(`${basePath}/image-to-image`, {
method: "POST",
body: buildMultipartBody({
prompt: "replace the suit with a bathing suit",
Expand All @@ -189,8 +199,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ "image-to-image": 1 });
});

it("should call the AI Gateway for generate API /image-to-video", async () => {
const res = await client.fetch("/beta/generate/image-to-video", {
it(`should call the AI Gateway for ${basePath}/image-to-video`, async () => {
const res = await client.fetch(`${basePath}/image-to-video`, {
method: "POST",
body: buildMultipartBody({}),
});
Expand All @@ -202,8 +212,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ "image-to-video": 1 });
});

it("should call the AI Gateway for generate API /upscale", async () => {
const res = await client.fetch("/beta/generate/upscale", {
it(`should call the AI Gateway for ${basePath}/upscale`, async () => {
const res = await client.fetch(`${basePath}/upscale`, {
method: "POST",
body: buildMultipartBody({ prompt: "enhance" }),
});
Expand All @@ -215,8 +225,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ upscale: 1 });
});

it("should call the AI Gateway for generate API /segment-anything-2", async () => {
const res = await client.fetch("/beta/generate/segment-anything-2", {
it(`should call the AI Gateway for ${basePath}/segment-anything-2`, async () => {
const res = await client.fetch(`${basePath}/segment-anything-2`, {
method: "POST",
body: buildMultipartBody({}),
});
Expand Down Expand Up @@ -260,7 +270,7 @@ describe("controllers/generate", () => {

for (const [title, input, error] of testCases) {
it(title, async () => {
const res = await client.fetch("/beta/generate/image-to-image", {
const res = await client.fetch("/generate/image-to-image", {
method: "POST",
body: input,
});
Expand All @@ -287,7 +297,7 @@ describe("controllers/generate", () => {
}

it("should log all requests to db", async () => {
const res = await client.post("/beta/generate/text-to-image", {
const res = await client.post("/generate/text-to-image", {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(200);
Expand Down Expand Up @@ -325,7 +335,7 @@ describe("controllers/generate", () => {
`{"details":{"msg":"sudden error"}}`,
);

const res = await client.post("/beta/generate/text-to-image", {
const res = await client.post("/generate/text-to-image", {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(500);
Expand All @@ -345,7 +355,7 @@ describe("controllers/generate", () => {
it("should log non JSON outputs as strings to db", async () => {
mockFetchHttpError(418, "text/plain", `I'm not Jason`);

const res = await client.post("/beta/generate/text-to-image", {
const res = await client.post("/generate/text-to-image", {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(418);
Expand All @@ -364,7 +374,7 @@ describe("controllers/generate", () => {
mockedFetchWithTimeout.mockImplementation(() => {
throw new Error("on your face");
});
const res = await client.post("/beta/generate/text-to-image", {
const res = await client.post("/generate/text-to-image", {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(500);
Expand Down Expand Up @@ -394,10 +404,10 @@ describe("controllers/generate", () => {

const makeAiGenReq = (pipeline: (typeof pipelines)[number]) =>
pipeline === "text-to-image"
? client.post(`/beta/generate/${pipeline}`, {
? client.post(`/generate/${pipeline}`, {
prompt: "whatever you feel like",
})
: client.fetch(`/beta/generate/${pipeline}`, {
: client.fetch(`/generate/${pipeline}`, {
method: "POST",
body: buildMultipartBody(
pipeline === "image-to-video" ? {} : { prompt: "make magic" },
Expand Down
2 changes: 0 additions & 2 deletions packages/api/src/controllers/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ const aiGenerateDurationMetric = new promclient.Histogram({

const app = Router();

app.use(experimentSubjectsOnly("ai-generate"));

const rateLimiter: RequestHandler = async (req, res, next) => {
const now = Date.now();
const [[{ count, min }]] = await db.aiGenerateLog.find(
Expand Down
2 changes: 2 additions & 0 deletions packages/api/src/controllers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ export default {
"api-token": apiToken,
asset,
auth,
generate,
// TODO: Remove beta paths
"beta/generate": generate,
broadcaster,
clip,
Expand Down
137 changes: 130 additions & 7 deletions packages/api/src/schema/ai-api-schema.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
openapi: 3.1.0
paths:
/api/beta/generate/text-to-image:
/api/generate/text-to-image:
post:
tags:
- generate
Expand Down Expand Up @@ -60,7 +60,7 @@ paths:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: textToImage
/api/beta/generate/image-to-image:
/api/generate/image-to-image:
post:
tags:
- generate
Expand Down Expand Up @@ -120,7 +120,7 @@ paths:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: imageToImage
/api/beta/generate/image-to-video:
/api/generate/image-to-video:
post:
tags:
- generate
Expand Down Expand Up @@ -180,7 +180,7 @@ paths:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: imageToVideo
/api/beta/generate/upscale:
/api/generate/upscale:
post:
tags:
- generate
Expand Down Expand Up @@ -240,7 +240,7 @@ paths:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: upscale
/api/beta/generate/audio-to-text:
/api/generate/audio-to-text:
post:
tags:
- generate
Expand Down Expand Up @@ -308,7 +308,7 @@ paths:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: audioToText
/api/beta/generate/segment-anything-2:
/api/generate/segment-anything-2:
post:
tags:
- generate
Expand Down Expand Up @@ -368,6 +368,65 @@ paths:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: segmentAnything2
/api/generate/llm:
post:
tags:
- generate
summary: LLM
description: Generate text using a language model.
operationId: genLLM
requestBody:
content:
application/x-www-form-urlencoded:
schema:
$ref: '#/components/schemas/Body_genLLM'
required: true
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/LLMResponse'
'400':
description: Bad Request
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/HTTPError'
- $ref: '#/components/schemas/studio-api-error'
'401':
description: Unauthorized
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/HTTPError'
- $ref: '#/components/schemas/studio-api-error'
'422':
description: Validation Error
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/HTTPValidationError'
- $ref: '#/components/schemas/studio-api-error'
'500':
description: Internal Server Error
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/HTTPError'
- $ref: '#/components/schemas/studio-api-error'
default:
description: Error
content:
application/json:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: llm
components:
schemas:
APIError:
Expand Down Expand Up @@ -414,6 +473,14 @@ components:
title: Model Id
description: Hugging Face model ID used for image generation.
default: timbrooks/instruct-pix2pix
loras:
type: string
title: Loras
description: >-
A LoRA (Low-Rank Adaptation) model and its corresponding weight for
image generation. Example: { "latent-consistency/lcm-lora-sdxl":
1.0, "nerijs/pixel-art-xl": 1.2}.
default: ''
strength:
type: number
title: Strength
Expand Down Expand Up @@ -533,6 +600,41 @@ components:
- image
title: Body_genImageToVideo
additionalProperties: false
Body_genLLM:
properties:
prompt:
type: string
title: Prompt
model_id:
type: string
title: Model Id
default: ''
system_msg:
type: string
title: System Msg
default: ''
temperature:
type: number
title: Temperature
default: 0.7
max_tokens:
type: integer
title: Max Tokens
default: 256
history:
type: string
title: History
default: '[]'
stream:
type: boolean
title: Stream
default: false
type: object
required:
- prompt
- model_id
title: Body_genLLM
additionalProperties: false
Body_genSegmentAnything2:
properties:
image:
Expand All @@ -544,7 +646,7 @@ components:
type: string
title: Model Id
description: Hugging Face model ID used for image generation.
default: 'facebook/sam2-hiera-large'
default: facebook/sam2-hiera-large
point_coords:
type: string
title: Point Coords
Expand Down Expand Up @@ -667,6 +769,19 @@ components:
- images
title: ImageResponse
description: Response model for image generation.
LLMResponse:
properties:
response:
type: string
title: Response
tokens_used:
type: integer
title: Tokens Used
type: object
required:
- response
- tokens_used
title: LLMResponse
MasksResponse:
properties:
masks:
Expand Down Expand Up @@ -734,6 +849,14 @@ components:
title: Model Id
description: Hugging Face model ID used for image generation.
default: SG161222/RealVisXL_V4.0_Lightning
loras:
type: string
title: Loras
description: >-
A LoRA (Low-Rank Adaptation) model and its corresponding weight for
image generation. Example: { "latent-consistency/lcm-lora-sdxl":
1.0, "nerijs/pixel-art-xl": 1.2}.
default: ''
prompt:
type: string
title: Prompt
Expand Down
4 changes: 2 additions & 2 deletions packages/api/src/schema/pull-ai-schema.js
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ const downloadAiSchema = async () => {

// patches to the paths section
schema.paths = mapObject(schema.paths, (path, value) => {
// prefix paths with /api/beta/generate
path = `/api/beta/generate${path}`;
// prefix paths with /api/generate
path = `/api/generate${path}`;
// remove security field
delete value.post.security;
// add Studio API error as oneOf to all of the error responses
Expand Down

0 comments on commit af93493

Please sign in to comment.