Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add similarity search batch processing and tests #437

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions tools/similarity_search/.prettierrc.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"trailingComma": "all",
"tabWidth": 2,
"semi": false,
"singleQuote": false
}
1 change: 1 addition & 0 deletions tools/similarity_search/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
},
"devDependencies": {
"@cloudflare/workers-types": "^4.20240403.0",
"prettier": "^3.2.5",
"wrangler": "^3.47.0"
}
}
7 changes: 7 additions & 0 deletions tools/similarity_search/src/env.d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
/**
 * Bindings and variables the similarity-search worker reads from its environment.
 */
type Env = {
// Base URL of the AI Gateway universal endpoint; the request handler appends the Workers AI model path directly to it.
API_GATEWAY_UNIVERSAL_API: string
// NOTE(review): presumably the value the `X-API-Key` request header is compared against — the auth middleware is not visible here; confirm.
API_KEY_TOKEN_CHECK: string
// Sent as the Bearer token on embedding requests made through the gateway.
WORKERS_API_KEY: string
// Vectorize index the handler queries (topK: 1) with each embedding.
VECTORIZE_INDEX: VectorizeIndex
// Maximum number of `text` items per request; the handler coerces it with Number() and falls back to 100.
MAX_INPUT: number
}
122 changes: 103 additions & 19 deletions tools/similarity_search/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import { Hono } from "hono"

type Env = {
API_KEY_TOKEN_CHECK: string
AI: Ai
VECTORIZE_INDEX: VectorizeIndex
}
import { type AiTextEmbeddingsOutput } from "@cloudflare/workers-types/experimental"

type TextEntry = {
text: string
text: string | string[]
namespace: string
}
/**
 * JSON body the handler expects back from the gateway's Workers AI endpoint.
 * NOTE(review): shape assumed to follow the Cloudflare API response envelope — confirm against the gateway documentation.
 */
type GatewayAiResponse = {
// Embedding output; the handler reads `result.data[0]` as the query vector.
result: AiTextEmbeddingsOutput
// A falsy value is treated by the handler as a failed embedding call.
success: boolean
errors: Array<{
message: string
code: number
}>
messages: Array<unknown>
}

const app = new Hono<{ Bindings: Env }>()

Expand All @@ -28,24 +32,104 @@ app.use("*", async (c, next) => {
})

app.post("/", async (c) => {
const data = await c.req.json<TextEntry>()
// Format https://gateway.ai.cloudflare.com/v1/{ACCOUNT_ID}/{SLUG}/
// The model path is concatenated straight onto this value, so the trailing slash is required.
const apiGatewayUniversalApi = c.env.API_GATEWAY_UNIVERSAL_API
if (!apiGatewayUniversalApi) {
return c.text("Missing gateway URL", 500)
}

// Sent as the Bearer token on every embedding request.
const workersApikey = c.env.WORKERS_API_KEY
if (!workersApikey) {
return c.text("Missing workers API key", 500)
}

// A body that cannot be parsed as JSON is answered with 400 instead of letting the exception propagate.
let data: TextEntry
try {
data = await c.req.json<TextEntry>()
} catch (error) {
return c.text("Cannot parse JSON input", 400)
}

const { text, namespace } = data

if (typeof text !== "string" || typeof namespace !== "string") {
return c.text("Invalid JSON format", 400)
// Normalize `text` into an array so that single and batch inputs share one code path.
let texts: string[]
if (typeof text === "string" && text.length) {
texts = [text]
} else if (
Array.isArray(text) &&
text.every((element) => typeof element === "string" && element.length)
) {
// NOTE(review): `every` returns true for an empty array, so `text: []` is accepted here — confirm that is intended.
texts = text
} else {
return c.text(
"Invalid JSON format, property `text` must be a non-empty string or array of non-empty strings",
400,
)
}
// An unset, non-numeric or zero MAX_INPUT falls back to 100 items.
const MAX_INPUT = Number(c.env.MAX_INPUT) || 100
if (texts.length > MAX_INPUT) {
return c.text(
`Too big input, property \`text\` can have max ${MAX_INPUT} items`,
400,
)
}
// `namespace` is validated only after `text`, so a request with both invalid reports the `text` problem.
if (typeof namespace !== "string") {
return c.text(
"Invalid JSON format, property `namespace` must be a string",
400,
)
}

const modelResp = await c.env.AI.run("@cf/baai/bge-base-en-v1.5", {
text: [text]
})
const vector = modelResp.data[0]
const searchResponse = await c.env.VECTORIZE_INDEX.query(vector, {
namespace,
topK: 1
// Resolve each text with its own gateway request to enable caching per
// individual request rather than per batch.
// The position is taken from `map` itself: a counter shared between the
// concurrently running callbacks and bumped after an `await` would follow
// completion order, not input order, and name the wrong item in errors.
const requests = texts.map(async (text, index) => {
  const response = await fetch(
    `${apiGatewayUniversalApi}workers-ai/@cf/baai/bge-base-en-v1.5`,
    {
      method: "POST",
      headers: {
        Authorization: `Bearer ${workersApikey}`,
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
        text,
      }),
    },
  )
  const json = (await response.json()) as GatewayAiResponse
  const vector = json?.result?.data?.[0]

  // An unsuccessful call, or a response without an embedding, rejects this
  // promise; an empty `data` array is rejected too instead of passing an
  // undefined vector on to the index query.
  if (!json?.success || !vector) {
    console.error(`Workers AI error at index ${index}`, json)
    throw new Error(`Workers AI error for text index ${index}`)
  }

  // Only the single best match is requested for each text.
  return await c.env.VECTORIZE_INDEX.query(vector, {
    namespace,
    topK: 1,
  })
})
const similarityScore = searchResponse.matches[0]?.score || 0

return c.json({ similarity_score: similarityScore })
// Wait for every lookup; a single rejection fails the whole request with a 500.
let settled
try {
  settled = await Promise.all(requests)
} catch (failure) {
  console.error(`Batch error - ${failure}`)
  return c.text(`An error occurred - ${failure}`, 500)
}

// A lookup without any match (or with a falsy score) is reported as 0.
const scores = settled.map((lookup) => lookup.matches[0]?.score || 0)

// Mirror the input shape: a scalar `text` yields a scalar score, an array yields an array.
const similarity = typeof text === "string" ? scores[0] : scores
return c.json({ similarity_score: similarity })
})

export default app
3 changes: 3 additions & 0 deletions tools/similarity_search/test/env.d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
declare module "cloudflare:test" {
interface ProvidedEnv extends Env {}
}
75 changes: 70 additions & 5 deletions tools/similarity_search/test/index.spec.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { SELF } from "cloudflare:test"
import { describe, it, expect } from "vitest"
import { SELF, env } from "cloudflare:test"
import { describe, it, expect, vi } from "vitest"

import "../src/index"

Expand All @@ -8,15 +8,80 @@ describe("Authentication", () => {
const response = await SELF.fetch("https://example.com/", {
method: "POST",
headers: {
"Content-Type": "application/json"
"Content-Type": "application/json",
},
body: JSON.stringify({
text: "Sample text",
namespace: "test-namespace"
})
namespace: "test-namespace",
}),
})

expect(response.status).toBe(401)
expect(await response.text()).toBe("Unauthorized")
})
})

describe("Single message processing", () => {
  it("returns single scalar result when single scalar text is given", async () => {
    // A plain string in `text` must come back as one scalar score, not an array.
    const payload = {
      text: "Sample text",
      namespace: "test-namespace",
    }
    const init = {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        "X-API-Key": "test-api-key",
      },
      body: JSON.stringify(payload),
    }

    const reply = await SELF.fetch("https://example.com/", init)

    expect(reply.status).toBe(200)
    expect(await reply.text()).toEqual('{"similarity_score":0.5678}')
  })
})

describe("Batch message processing", () => {
  // Posts `text` as an authenticated batch request to the worker under test.
  const postBatch = (text: string[]) =>
    SELF.fetch("https://example.com/", {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        "X-API-Key": "test-api-key",
      },
      body: JSON.stringify({
        text,
        namespace: "test-namespace",
      }),
    })

  it("limits max inputs", async () => {
    // Four items exceed the MAX_INPUT of 3 configured for the test bindings.
    const reply = await postBatch([
      "This is a story about an orange cloud",
      "This is a story about a llama",
      "This is a story about a hugging emoji",
      "This is a story about overwhelming courage",
    ])

    expect(reply.status).toBe(400)
    expect(await reply.text()).toEqual(
      "Too big input, property `text` can have max 3 items",
    )
  })

  it("returns array results when multiple texts are given", async () => {
    // An array in `text` must come back as an array with one score per item.
    const reply = await postBatch([
      "This is a story about an orange cloud",
      "This is a story about a llama",
      "This is a story about a hugging emoji",
    ])

    expect(reply.status).toBe(200)
    expect(await reply.text()).toEqual(
      '{"similarity_score":[0.5678,0.5678,0.5678]}',
    )
  })
})
20 changes: 10 additions & 10 deletions tools/similarity_search/test/tsconfig.json
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
{
"extends": "../tsconfig.json",
"compilerOptions": {
"moduleResolution": "bundler",
"types": [
"@cloudflare/workers-types/experimental",
"@cloudflare/vitest-pool-workers"
]
},
"include": ["./**/*.ts", "../src/env.d.ts"]
}
"extends": "../tsconfig.json",
"compilerOptions": {
"moduleResolution": "bundler",
"types": [
"@cloudflare/workers-types/experimental",
"@cloudflare/vitest-pool-workers"
]
},
"include": ["./**/*.ts", "../src/env.d.ts"]
}
3 changes: 2 additions & 1 deletion tools/similarity_search/vitest.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ export default defineWorkersConfig({
},
miniflare: {
bindings: {
API_KEY_TOKEN_CHECK: "test-api-key"
API_KEY_TOKEN_CHECK: "test-api-key",
MAX_INPUT: 3
},
wrappedBindings: {
AI: {
Expand Down
9 changes: 8 additions & 1 deletion tools/similarity_search/wrangler.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,11 @@ compatibility_flags = ["nodejs_compat"]
# database_id = ""

# [ai]
# binding = "AI"
# binding = "AI"

[[vectorize]]
binding = "VECTORIZE_INDEX"
index_name = "embeddings-index"

[limits]
cpu_ms = 50