Skip to content

Commit

Permalink
Unify all invokers
Browse files Browse the repository at this point in the history
  • Loading branch information
DennisTraub authored and rlhagerm committed Mar 20, 2024
1 parent ff80bdb commit ea3667b
Show file tree
Hide file tree
Showing 12 changed files with 60 additions and 64 deletions.
Original file line number Diff line number Diff line change
@@ -1,83 +1,89 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

/**
* @typedef {(prompt: string, modelId: string) => Promise<string>} Invoker
*
* @typedef {{ invokeModel: Invoker }} Module
*/

export const FoundationModels = Object.freeze({
CLAUDE_3_HAIKU: {
modelId: "anthropic.claude-3-haiku-20240307-v1:0",
modelName: "Anthropic Claude 3 Haiku",
module: () => import("../models/anthropic_claude/claude_3.js"),
invoker: (/** @type {Function} */ module) => module.invokeModel,
invoker: (/** @type {Module} */ module) => module.invokeModel,
},
CLAUDE_3_SONNET: {
modelId: "anthropic.claude-3-sonnet-20240229-v1:0",
modelName: "Anthropic Claude 3 Sonnet",
module: () => import("../models/anthropic_claude/claude_3.js"),
invoker: (/** @type {Function} */ module) => module.invokeModel,
invoker: (/** @type {Module} */ module) => module.invokeModel,
},
CLAUDE_2_1: {
modelId: "anthropic.claude-v2:1",
modelName: "Anthropic Claude 2.1",
module: () => import("../models/anthropic_claude/claude_2.js"),
invoker: (/** @type {Function} */ module) => module.invokeMessagesApi,
invoker: (/** @type {Module} */ module) => module.invokeModel,
},
CLAUDE_2: {
modelId: "anthropic.claude-v2",
modelName: "Anthropic Claude 2.0",
module: () => import("../models/anthropic_claude/claude_2.js"),
invoker: (/** @type {Function} */ module) => module.invokeMessagesApi,
invoker: (/** @type {Module} */ module) => module.invokeModel,
},
CLAUDE_INSTANT: {
modelId: "anthropic.claude-instant-v1",
modelName: "Anthropic Claude Instant",
module: () => import("../models/anthropic_claude/claude_instant_1.js"),
invoker: (/** @type {Function} */ module) => module.invokeMessagesApi,
invoker: (/** @type {Module} */ module) => module.invokeModel,
},
JURASSIC2_MID: {
modelId: "ai21.j2-mid-v1",
modelName: "Jurassic-2 Mid",
module: () => import("../models/ai21_labs_jurassic2/jurassic2.js"),
invoker: (/** @type {Function} */ module) => module.invokeModel,
invoker: (/** @type {Module} */ module) => module.invokeModel,
},
JURASSIC2_ULTRA: {
modelId: "ai21.j2-ultra-v1",
modelName: "Jurassic-2 Ultra",
module: () => import("../models/ai21_labs_jurassic2/jurassic2.js"),
invoker: (/** @type {Function} */ module) => module.invokeModel,
invoker: (/** @type {Module} */ module) => module.invokeModel,
},
LLAMA2_CHAT_13B: {
modelId: "meta.llama2-13b-chat-v1",
modelName: "Llama 2 Chat 13B",
module: () => import("../models/meta_llama2/llama2_chat.js"),
invoker: (/** @type {Function} */ module) => module.invokeModel,
invoker: (/** @type {Module} */ module) => module.invokeModel,
},
LLAMA2_CHAT_70B: {
modelId: "meta.llama2-70b-chat-v1",
modelName: "Llama 2 Chat 70B",
module: () => import("../models/meta_llama2/llama2_chat.js"),
invoker: (/** @type {Function} */ module) => module.invokeModel,
invoker: (/** @type {Module} */ module) => module.invokeModel,
},
MISTRAL_7B: {
modelId: "mistral.mistral-7b-instruct-v0:2",
modelName: "Mistral 7B Instruct",
module: () => import("../models/mistral_ai/mistral_7b.js"),
invoker: (/** @type {Function} */ module) => module.invokeModel,
invoker: (/** @type {Module} */ module) => module.invokeModel,
},
MIXTRAL_8X7B: {
modelId: "mistral.mixtral-8x7b-instruct-v0:1",
modelName: "Mixtral 8X7B Instruct",
module: () => import("../models/mistral_ai/mixtral_8x7b.js"),
invoker: (/** @type {Function} */ module) => module.invokeModel,
invoker: (/** @type {Module} */ module) => module.invokeModel,
},
TITAN_TEXT_G1_EXPRESS: {
modelId: "amazon.titan-text-express-v1",
modelName: "Titan Text G1 - Express",
module: () => import("../models/amazon_titan/titan_text.js"),
invoker: (/** @type {Function} */ module) => module.invokeModel,
invoker: (/** @type {Module} */ module) => module.invokeModel,
},
TITAN_TEXT_G1_LITE: {
modelId: "amazon.titan-text-lite-v1",
modelName: "Titan Text G1 - Lite",
module: () => import("../models/amazon_titan/titan_text.js"),
invoker: (/** @type {Function} */ module) => module.invokeModel,
invoker: (/** @type {Module} */ module) => module.invokeModel,
},
});
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ export const invokeModel = async (prompt, modelId = "ai21.j2-mid-v1") => {
const decodedResponseBody = new TextDecoder().decode(apiResponse.body);
/** @type {ResponseBody} */
const responseBody = JSON.parse(decodedResponseBody);
return responseBody.completions.map((completion) => completion.data.text);
return responseBody.completions[0].data.text;
};

// Invoke the function if this file was run directly.
Expand All @@ -62,8 +62,8 @@ if (process.argv[1] === fileURLToPath(import.meta.url)) {

try {
console.log("-".repeat(53));
const responses = await invokeModel(prompt, modelId);
responses.forEach((response) => console.log(response));
const response = await invokeModel(prompt, modelId);
console.log(response);
} catch (err) {
console.log(err);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ export const invokeModel = async (
const decodedResponseBody = new TextDecoder().decode(apiResponse.body);
/** @type {ResponseBody} */
const responseBody = JSON.parse(decodedResponseBody);
return responseBody.results.map((result) => result.outputText);
return responseBody.results[0].outputText;
};

// Invoke the function if this file was run directly.
Expand All @@ -63,8 +63,8 @@ if (process.argv[1] === fileURLToPath(import.meta.url)) {

try {
console.log("-".repeat(53));
const responses = await invokeModel(prompt, modelId);
responses.forEach((response) => console.log(response));
const response = await invokeModel(prompt, modelId);
console.log(response);
} catch (err) {
console.log(err);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ import {
* @param {string} prompt - The input text prompt for the model to complete.
* @param {string} [modelId] - The ID of the model to use. Defaults to "anthropic.claude-v2".
*/
export const invokeMessagesApi = async (
prompt,
modelId = "anthropic.claude-v2",
) => {
export const invokeModel = async (prompt, modelId = "anthropic.claude-v2") => {
// Create a new Bedrock Runtime client instance.
const client = new BedrockRuntimeClient({ region: "us-east-1" });

Expand Down Expand Up @@ -64,7 +61,7 @@ export const invokeMessagesApi = async (
const decodedResponseBody = new TextDecoder().decode(apiResponse.body);
/** @type {MessagesResponseBody} */
const responseBody = JSON.parse(decodedResponseBody);
return responseBody.content.map((content) => content.text);
return responseBody.content[0].text;
};

/**
Expand Down Expand Up @@ -118,8 +115,8 @@ if (process.argv[1] === fileURLToPath(import.meta.url)) {
try {
console.log("-".repeat(53));
console.log("Using the Messages API:");
const responses = await invokeMessagesApi(prompt, modelId);
responses.forEach((response) => console.log(response));
const response = await invokeModel(prompt, modelId);
console.log(response);
} catch (err) {
console.log(err);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ export const invokeModel = async (
const decodedResponseBody = new TextDecoder().decode(apiResponse.body);
/** @type {MessagesResponseBody} */
const responseBody = JSON.parse(decodedResponseBody);
return responseBody.content.map((content) => content.text);
return responseBody.content[0].text;
};

/**
Expand Down Expand Up @@ -108,29 +108,23 @@ export const invokeModelWithResponseStream = async (
});
const apiResponse = await client.send(command);

let role;
let final_message = "";
let completeMessage = "";

// Decode and process the response stream
for await (const item of apiResponse.body) {
/** @type Chunk */
const chunk = JSON.parse(new TextDecoder().decode(item.chunk.bytes));
const chunk_type = chunk.type;

if (chunk_type === "message_start") {
role = chunk.message.role;
} else if (chunk_type === "content_block_delta") {
if (chunk_type === "content_block_delta") {
const text = chunk.delta.text;
final_message = final_message + text;
completeMessage = completeMessage + text;
process.stdout.write(text);
}
}

// Return the final response
return {
role: role,
content: [{ type: "text", text: final_message }],
};
return completeMessage;
};

// Invoke the function if this file was run directly.
Expand All @@ -142,7 +136,7 @@ if (process.argv[1] === fileURLToPath(import.meta.url)) {

try {
console.log("-".repeat(53));
const response = await invokeModelWithResponseStream(prompt, modelId);
const response = await invokeModel(prompt, modelId);
console.log("\n" + "-".repeat(53));
console.log("Final structured response:");
console.log(response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import {
* @param {string} prompt - The input text prompt for the model to complete.
* @param {string} [modelId] - The ID of the model to use. Defaults to "anthropic.claude-instant-v1".
*/
export const invokeMessagesApi = async (
export const invokeModel = async (
prompt,
modelId = "anthropic.claude-instant-v1",
) => {
Expand Down Expand Up @@ -62,7 +62,7 @@ export const invokeMessagesApi = async (
const decodedResponseBody = new TextDecoder().decode(apiResponse.body);
/** @type {MessageApiResponse} */
const responseBody = JSON.parse(decodedResponseBody);
return responseBody.content.map((content) => content.text);
return responseBody.content[0].text;
};

/**
Expand Down Expand Up @@ -116,8 +116,8 @@ if (process.argv[1] === fileURLToPath(import.meta.url)) {
try {
console.log("-".repeat(53));
console.log("Using the Messages API:");
const responses = await invokeMessagesApi(prompt, modelId);
responses.forEach((response) => console.log(response));
const response = await invokeModel(prompt, modelId);
console.log(response);
} catch (err) {
console.log(err);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ export const invokeModel = async (
const decodedResponseBody = new TextDecoder().decode(apiResponse.body);
/** @type {ResponseBody} */
const responseBody = JSON.parse(decodedResponseBody);
return responseBody.outputs.map((output) => output.text);
return responseBody.outputs[0].text;
};

// Invoke the function if this file was run directly.
Expand All @@ -66,8 +66,8 @@ if (process.argv[1] === fileURLToPath(import.meta.url)) {

try {
console.log("-".repeat(53));
const responses = await invokeModel(prompt, modelId);
responses.forEach((response) => console.log(response));
const response = await invokeModel(prompt, modelId);
console.log(response);
} catch (err) {
console.log(err);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ export const invokeModel = async (
const decodedResponseBody = new TextDecoder().decode(apiResponse.body);
/** @type {ResponseBody} */
const responseBody = JSON.parse(decodedResponseBody);
return responseBody.outputs.map((output) => output.text);
return responseBody.outputs[0].text;
};

// Invoke the function if this file was run directly.
Expand All @@ -66,8 +66,8 @@ if (process.argv[1] === fileURLToPath(import.meta.url)) {

try {
console.log("-".repeat(53));
const responses = await invokeModel(prompt, modelId);
responses.forEach((response) => console.log(response));
const response = await invokeModel(prompt, modelId);
console.log(response);
} catch (err) {
console.log(err);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ describe("Invoke Jurassic2 Mid", () => {
it("should return a response", async () => {
const modelId = FoundationModels.JURASSIC2_MID.modelId;
const response = await invokeModel(TEXT_PROMPT, modelId);
expectToBeANonEmptyString(response[0]);
expectToBeANonEmptyString(response);
});
});

describe("Invoke Jurassic2 Ultra", () => {
it("should return a response", async () => {
const modelId = FoundationModels.JURASSIC2_ULTRA.modelId;
const response = await invokeModel(TEXT_PROMPT, modelId);
expectToBeANonEmptyString(response[0]);
expectToBeANonEmptyString(response);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ describe("Invoke Titan Text G1 - Express", () => {
it("should return a response", async () => {
const modelId = FoundationModels.TITAN_TEXT_G1_EXPRESS.modelId;
const response = await invokeModel(TEXT_PROMPT, modelId);
expectToBeANonEmptyString(response[0]);
expectToBeANonEmptyString(response);
});
});

describe("Invoke Titan Text G1 - Lite", () => {
it("should return a response", async () => {
const modelId = FoundationModels.TITAN_TEXT_G1_LITE.modelId;
const response = await invokeModel(TEXT_PROMPT, modelId);
expectToBeANonEmptyString(response[0]);
expectToBeANonEmptyString(response);
});
});
Loading

0 comments on commit ea3667b

Please sign in to comment.