Anthropic Tool Support (#1594)
* support anthropic PDF beta

* upstream merge, remove commented out console log line

* Fixing type errors.
The Anthropic API does not yet include a "DocumentBlock" for supporting PDFs, so an extended type has been added to the endpoint.

* changed document processor to async (matching image processor)

* use the beta api types rather than custom extension

* rudimentary tool testing

* interim commit (tool re-passing, file handling)

* remove merge error

* tidy up, isolate beta classes to utils

* anthropic tool calling support.

* improve handling of directlyAnswer tool

* fix streaming

* slight tidy up to tools flow handling

* fix: don't pass tools in final generation, instead deduce tools from tool results

---------

Co-authored-by: Nathan Sarrazin <[email protected]>
evalstate and nsarrazin authored Jan 3, 2025
1 parent 18e264a commit c135e93
Showing 5 changed files with 196 additions and 37 deletions.
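For context, the flow this commit wires up follows Anthropic's standard tool-use round trip: offer the tools on the first request, let the model stop with stop_reason "tool_use", run the tools locally, then replay the tool_use/tool_result pair and generate the final answer without re-sending the tool definitions. Below is a standalone sketch against the Anthropic SDK; the model id, tool definition, and weather answer are placeholders rather than anything from this repository, and the repository itself routes this through runTools/addToolResults instead of calling the SDK directly like this.

import Anthropic from "@anthropic-ai/sdk";

const anthropic = new Anthropic(); // reads ANTHROPIC_API_KEY from the environment

// Hypothetical tool, for illustration only
const tools: Anthropic.Messages.Tool[] = [
  {
    name: "get_weather",
    description: "Look up the current weather for a city",
    input_schema: {
      type: "object",
      properties: { city: { type: "string" } },
      required: ["city"],
    },
  },
];

async function toolRoundTrip() {
  // Phase 1: offer the tools; the model may stop with stop_reason "tool_use".
  const first = await anthropic.messages.create({
    model: "claude-3-5-sonnet-latest", // placeholder model id
    max_tokens: 1024,
    tools,
    messages: [{ role: "user", content: "What's the weather in Paris?" }],
  });

  const call = first.content.find(
    (block): block is Anthropic.Messages.ToolUseBlock => block.type === "tool_use"
  );
  if (first.stop_reason !== "tool_use" || !call) return first;

  // ...run the tool locally, then:
  // Phase 2: replay the tool_use/tool_result pair and ask for the final answer,
  // without re-sending the tool definitions (the role addToolResults plays below).
  return anthropic.messages.create({
    model: "claude-3-5-sonnet-latest",
    max_tokens: 1024,
    messages: [
      { role: "user", content: "What's the weather in Paris?" },
      { role: "assistant", content: [call] },
      {
        role: "user",
        content: [{ type: "tool_result", tool_use_id: call.id, content: "22°C and sunny" }],
      },
    ],
  });
}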
146 changes: 128 additions & 18 deletions src/lib/server/endpoints/anthropic/endpointAnthropic.ts
@@ -3,9 +3,19 @@ import type { Endpoint } from "../endpoints";
import { env } from "$env/dynamic/private";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import { createImageProcessorOptionsValidator } from "../images";
import { endpointMessagesToAnthropicMessages } from "./utils";
import { endpointMessagesToAnthropicMessages, addToolResults } from "./utils";
import { createDocumentProcessorOptionsValidator } from "../document";
import type {
Tool,
ToolCall,
ToolInput,
ToolInputFile,
ToolInputFixed,
ToolInputOptional,
} from "$lib/types/Tool";
import type Anthropic from "@anthropic-ai/sdk";
import type { MessageParam } from "@anthropic-ai/sdk/resources/messages.mjs";
import directlyAnswer from "$lib/server/tools/directlyAnswer";

export const endpointAnthropicParametersSchema = z.object({
weight: z.number().int().positive().default(1),
@@ -52,23 +62,41 @@ export async function endpointAnthropic(
defaultQuery,
});

return async ({ messages, preprompt, generateSettings }) => {
return async ({
messages,
preprompt,
generateSettings,
conversationId,
tools = [],
toolResults = [],
}) => {
let system = preprompt;
if (messages?.[0]?.from === "system") {
system = messages[0].content;
}

let tokenId = 0;
if (tools.length === 0 && toolResults.length > 0) {
const toolNames = new Set(toolResults.map((tool) => tool.call.name));
tools = Array.from(toolNames).map((name) => ({
name,
description: "",
inputs: [],
})) as unknown as Tool[];
}

const parameters = { ...model.parameters, ...generateSettings };

return (async function* () {
const stream = anthropic.messages.stream({
model: model.id ?? model.name,
messages: (await endpointMessagesToAnthropicMessages(
messages,
multimodal
)) as MessageParam[],
tools: createAnthropicTools(tools),
tool_choice:
tools.length > 0 ? { type: "auto", disable_parallel_tool_use: false } : undefined,
messages: addToolResults(
await endpointMessagesToAnthropicMessages(messages, multimodal, conversationId),
toolResults
) as MessageParam[],
max_tokens: parameters?.max_new_tokens,
temperature: parameters?.temperature,
top_p: parameters?.top_p,
@@ -79,21 +107,40 @@ export async function endpointAnthropic(
while (true) {
const result = await Promise.race([stream.emitted("text"), stream.emitted("end")]);

// Stream end
if (result === undefined) {
yield {
token: {
id: tokenId++,
text: "",
logprob: 0,
special: true,
},
generated_text: await stream.finalText(),
details: null,
} satisfies TextGenerationStreamOutput;
if ("tool_use" === stream.receivedMessages[0].stop_reason) {
// this should really create a new "Assistant" message with the tool id in it.
const toolCalls: ToolCall[] = stream.receivedMessages[0].content
.filter(
(block): block is Anthropic.Messages.ContentBlock & { type: "tool_use" } =>
block.type === "tool_use"
)
.map((block) => ({
name: block.name,
parameters: block.input as Record<string, string | number | boolean>,
id: block.id,
}));

yield {
token: { id: tokenId, text: "", logprob: 0, special: false, toolCalls },
generated_text: null,
details: null,
};
} else {
yield {
token: {
id: tokenId++,
text: "",
logprob: 0,
special: true,
},
generated_text: await stream.finalText(),
details: null,
} satisfies TextGenerationStreamOutput;
}

return;
}

// Text delta
yield {
token: {
@@ -109,3 +156,66 @@
})();
};
}

function createAnthropicTools(tools: Tool[]): Anthropic.Messages.Tool[] {
return tools
.filter((tool) => tool.name !== directlyAnswer.name)
.map((tool) => {
const properties = tool.inputs.reduce((acc, input) => {
acc[input.name] = convertToolInputToJSONSchema(input);
return acc;
}, {} as Record<string, unknown>);

const required = tool.inputs
.filter((input) => input.paramType === "required")
.map((input) => input.name);

return {
name: tool.name,
description: tool.description,
input_schema: {
type: "object",
properties,
required: required.length > 0 ? required : undefined,
},
};
});
}

function convertToolInputToJSONSchema(input: ToolInput): Record<string, unknown> {
const baseSchema: Record<string, unknown> = {};
if ("description" in input) {
baseSchema["description"] = input.description || "";
}
switch (input.paramType) {
case "optional":
baseSchema["default"] = (input as ToolInputOptional).default;
break;
case "fixed":
baseSchema["const"] = (input as ToolInputFixed).value;
break;
}

if (input.type === "file") {
baseSchema["type"] = "string";
baseSchema["format"] = "binary";
baseSchema["mimeTypes"] = (input as ToolInputFile).mimeTypes;
} else {
switch (input.type) {
case "str":
baseSchema["type"] = "string";
break;
case "int":
baseSchema["type"] = "integer";
break;
case "float":
baseSchema["type"] = "number";
break;
case "bool":
baseSchema["type"] = "boolean";
break;
}
}

return baseSchema;
}
68 changes: 56 additions & 12 deletions src/lib/server/endpoints/anthropic/utils.ts
@@ -7,12 +7,16 @@ import type {
BetaMessageParam,
BetaBase64PDFBlock,
} from "@anthropic-ai/sdk/resources/beta/messages/messages.mjs";
import type { ToolResult } from "$lib/types/Tool";
import { downloadFile } from "$lib/server/files/downloadFile";
import type { ObjectId } from "mongodb";

export async function fileToImageBlock(
file: MessageFile,
opts: ImageProcessorOptions<"image/png" | "image/jpeg" | "image/webp">
): Promise<BetaImageBlockParam> {
const processor = makeImageProcessor(opts);

const { image, mime } = await processor(file);

return {
@@ -48,7 +52,8 @@ export async function endpointMessagesToAnthropicMessages(
multimodal: {
image: ImageProcessorOptions<"image/png" | "image/jpeg" | "image/webp">;
document?: FileProcessorOptions<"application/pdf">;
}
},
conversationId?: ObjectId | undefined
): Promise<BetaMessageParam[]> {
return await Promise.all(
messages
@@ -57,20 +62,59 @@
return {
role: message.from,
content: [
...(await Promise.all(
(message.files ?? []).map(async (file) => {
if (file.mime.startsWith("image/")) {
return fileToImageBlock(file, multimodal.image);
} else if (file.mime === "application/pdf" && multimodal.document) {
return fileToDocumentBlock(file, multimodal.document);
} else {
throw new Error(`Unsupported file type: ${file.mime}`);
}
})
)),
...(message.from === "user"
? await Promise.all(
(message.files ?? []).map(async (file) => {
if (file.type === "hash" && conversationId) {
file = await downloadFile(file.value, conversationId);
}

if (file.mime.startsWith("image/")) {
return fileToImageBlock(file, multimodal.image);
} else if (file.mime === "application/pdf" && multimodal.document) {
return fileToDocumentBlock(file, multimodal.document);
} else {
throw new Error(`Unsupported file type: ${file.mime}`);
}
})
)
: []),
{ type: "text", text: message.content },
],
};
})
);
}

export function addToolResults(
messages: BetaMessageParam[],
toolResults: ToolResult[]
): BetaMessageParam[] {
const id = crypto.randomUUID();
if (toolResults.length === 0) {
return messages;
}
return [
...messages,
{
role: "assistant",
content: toolResults.map((result, index) => ({
type: "tool_use",
id: `tool_${index}_${id}`,
name: result.call.name,
input: result.call.parameters,
})),
},
{
role: "user",
content: toolResults.map((result, index) => ({
type: "tool_result",
tool_use_id: `tool_${index}_${id}`,
is_error: result.status === "error",
content: JSON.stringify(
result.status === "error" ? result.message : "outputs" in result ? result.outputs : ""
),
})),
},
];
}
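As a rough illustration of the two messages addToolResults appends — the tool name, parameters, and outputs below are invented; the tool_${index}_${uuid} id scheme matches the code above:

// Hypothetical tool result from an earlier tool phase
const toolResults = [
  {
    status: "success",
    call: { name: "websearch", parameters: { query: "anthropic tool use" } },
    outputs: [{ summary: "Tool use lets Claude call external functions." }],
  },
];

// addToolResults(messages, toolResults) returns the original messages plus, roughly:
// { role: "assistant", content: [{ type: "tool_use", id: "tool_0_<uuid>",
//     name: "websearch", input: { query: "anthropic tool use" } }] }
// { role: "user", content: [{ type: "tool_result", tool_use_id: "tool_0_<uuid>",
//     is_error: false, content: "[{\"summary\":\"Tool use lets Claude call external functions.\"}]" }] }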
6 changes: 4 additions & 2 deletions src/lib/server/textGeneration/generate.ts
@@ -1,4 +1,4 @@
import type { ToolResult } from "$lib/types/Tool";
import type { ToolResult, Tool } from "$lib/types/Tool";
import {
MessageReasoningUpdateType,
MessageUpdateType,
@@ -16,7 +16,7 @@ type GenerateContext = Omit<TextGenerationContext, "messages"> & { messages: End
export async function* generate(
{ model, endpoint, conv, messages, assistant, isContinue, promptedAt }: GenerateContext,
toolResults: ToolResult[],
preprompt?: string
preprompt?: string,
tools?: Tool[]
): AsyncIterable<MessageUpdate> {
// reasoning mode is false by default
let reasoning = false;
@@ -43,6 +44,7 @@
preprompt,
continueMessage: isContinue,
generateSettings: assistant?.generateSettings,
tools,
toolResults,
isMultimodal: model.multimodal,
conversationId: conv._id,
11 changes: 7 additions & 4 deletions src/lib/server/textGeneration/index.ts
@@ -20,6 +20,7 @@ import { mergeAsyncGenerators } from "$lib/utils/mergeAsyncGenerators";
import type { TextGenerationContext } from "./types";
import type { ToolResult } from "$lib/types/Tool";
import { toolHasName } from "../tools/utils";
import directlyAnswer from "../tools/directlyAnswer";

async function* keepAlive(done: AbortSignal): AsyncGenerator<MessageUpdate, undefined, undefined> {
while (!done.aborted) {
@@ -73,11 +74,13 @@ async function* textGenerationWithoutTitle(
}

let toolResults: ToolResult[] = [];
let tools = model.tools ? await getTools(toolsPreference, ctx.assistant) : undefined;

if (model.tools) {
const tools = await getTools(toolsPreference, ctx.assistant);
const toolCallsRequired = tools.some((tool) => !toolHasName("directly_answer", tool));
if (toolCallsRequired) toolResults = yield* runTools(ctx, tools, preprompt);
if (tools) {
const toolCallsRequired = tools.some((tool) => !toolHasName(directlyAnswer.name, tool));
if (toolCallsRequired) {
toolResults = yield* runTools(ctx, tools, preprompt);
} else tools = undefined;
}

const processedMessages = await preprocessMessages(messages, webSearchResult, convId);
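Taken together with the generate.ts change above, the new plumbing simply threads two arrays from the tool phase down to the endpoint. A minimal sketch of the shapes involved, inferred from how the fields are used in this diff (the real definitions live in $lib/types/Tool and may differ in detail):

// Sketch of the data now passed as generate(ctx, toolResults, preprompt, tools)
// and forwarded to endpoint({ ..., tools, toolResults, conversationId }).
interface ToolCallSketch {
  name: string;
  parameters: Record<string, string | number | boolean>;
  id: string;
}

interface ToolResultSketch {
  call: ToolCallSketch;
  status: "success" | "error";
  outputs?: Record<string, unknown>[]; // present on success
  message?: string; // present on error
}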
2 changes: 1 addition & 1 deletion src/lib/server/textGeneration/tools.ts
@@ -213,7 +213,7 @@ export async function* runTools(
}

// if we dont see a tool call in the first 25 chars, something is going wrong and we abort
if (rawText.length > 25 && !(rawText.includes("```json") || rawText.includes("{"))) {
if (rawText.length > 100 && !(rawText.includes("```json") || rawText.includes("{"))) {
return [];
}

