diff --git a/packages/event-handler/package.json b/packages/event-handler/package.json index c99c3f4486..667cebc435 100644 --- a/packages/event-handler/package.json +++ b/packages/event-handler/package.json @@ -39,6 +39,16 @@ "default": "./lib/esm/appsync-events/index.js" } }, + "./bedrock-agent": { + "require": { + "types": "./lib/cjs/bedrock-agent/index.d.ts", + "default": "./lib/cjs/bedrock-agent/index.js" + }, + "import": { + "types": "./lib/esm/bedrock-agent/index.d.ts", + "default": "./lib/esm/bedrock-agent/index.js" + } + }, "./types": { "require": { "types": "./lib/cjs/types/index.d.ts", @@ -56,6 +66,10 @@ "./lib/cjs/appsync-events/index.d.ts", "./lib/esm/appsync-events/index.d.ts" ], + "bedrock-agent": [ + "./lib/cjs/bedrock-agent/index.d.ts", + "./lib/esm/bedrock-agent/index.d.ts" + ], "types": [ "./lib/cjs/types/index.d.ts", "./lib/esm/types/index.d.ts" diff --git a/packages/event-handler/src/bedrock-agent/BedrockAgentFunctionResolver.ts b/packages/event-handler/src/bedrock-agent/BedrockAgentFunctionResolver.ts new file mode 100644 index 0000000000..93b30f0ece --- /dev/null +++ b/packages/event-handler/src/bedrock-agent/BedrockAgentFunctionResolver.ts @@ -0,0 +1,227 @@ +import { EnvironmentVariablesService } from '@aws-lambda-powertools/commons'; +import type { Context } from 'aws-lambda'; +import type { + BedrockAgentFunctionResponse, + Configuration, + ParameterValue, + ResolverOptions, + ResponseOptions, + Tool, + ToolFunction, +} from '../types/bedrock-agent.js'; +import type { GenericLogger } from '../types/common.js'; +import { assertBedrockAgentFunctionEvent } from './utils.js'; + +export class BedrockAgentFunctionResolver { + readonly #tools: Map = new Map(); + readonly #envService: EnvironmentVariablesService; + readonly #logger: Pick; + + constructor(options?: ResolverOptions) { + this.#envService = new EnvironmentVariablesService(); + const alcLogLevel = this.#envService.get('AWS_LAMBDA_LOG_LEVEL'); + this.#logger = options?.logger ?? { + debug: alcLogLevel === 'DEBUG' ? console.debug : () => {}, + error: console.error, + warn: console.warn, + }; + } + + /** + * Register a tool function for the Bedrock Agent. + * + * This method registers a function that can be invoked by a Bedrock Agent. + * + * @example + * ```ts + * import { BedrockAgentFunctionResolver } from '@aws-lambda-powertools/event-handler/bedrock-agent-function'; + * + * const app = new BedrockAgentFunctionResolver(); + * + * app.tool(async (params) => { + * const { name } = params; + * return `Hello, ${name}!`; + * }, { + * name: 'greeting', + * description: 'Greets a person by name', + * }); + * + * export const handler = async (event, context) => + * app.resolve(event, context); + * ``` + * + * The method also works as a class method decorator: + * + * @example + * ```ts + * import { BedrockAgentFunctionResolver } from '@aws-lambda-powertools/event-handler/bedrock-agent-function'; + * + * const app = new BedrockAgentFunctionResolver(); + * + * class Lambda { + * @app.tool({ name: 'greeting', description: 'Greets a person by name' }) + * async greeting(params) { + * const { name } = params; + * return `Hello, ${name}!`; + * } + * + * async handler(event, context) { + * return app.resolve(event, context); + * } + * } + * + * const lambda = new Lambda(); + * export const handler = lambda.handler.bind(lambda); + * ``` + * + * @param fn - The tool function + * @param config - The configuration object for the tool + */ + public tool>( + fn: ToolFunction, + config: Configuration + ): undefined; + public tool>( + config: Configuration + ): MethodDecorator; + public tool>( + fnOrConfig: ToolFunction | Configuration, + config?: Configuration + ): MethodDecorator | undefined { + // When used as a method (not a decorator) + if (typeof fnOrConfig === 'function') { + this.#registerTool(fnOrConfig, config as Configuration); + return; + } + + // When used as a decorator + return (_target, _propertyKey, descriptor: PropertyDescriptor) => { + const toolFn = descriptor.value as ToolFunction; + this.#registerTool(toolFn, fnOrConfig); + return descriptor; + }; + } + + #registerTool>( + handler: ToolFunction, + config: Configuration + ): void { + const { name } = config; + + if (this.#tools.size >= 5) { + this.#logger.warn( + `The maximum number of tools that can be registered is 5. Tool ${name} will not be registered.` + ); + return; + } + + if (this.#tools.has(name)) { + this.#logger.warn( + `Tool ${name} already registered. Overwriting with new definition.` + ); + } + + this.#tools.set(name, { + handler: handler as ToolFunction, + config, + }); + this.#logger.debug(`Tool ${name} has been registered.`); + } + + #buildResponse(options: ResponseOptions): BedrockAgentFunctionResponse { + const { + actionGroup, + function: func, + body, + errorType, + sessionAttributes, + promptSessionAttributes, + } = options; + + return { + messageVersion: '1.0', + response: { + actionGroup, + function: func, + functionResponse: { + responseState: errorType, + responseBody: { + TEXT: { + body, + }, + }, + }, + }, + sessionAttributes, + promptSessionAttributes, + }; + } + + async resolve( + event: unknown, + context: Context + ): Promise { + assertBedrockAgentFunctionEvent(event); + + const { + function: toolName, + parameters = [], + actionGroup, + sessionAttributes, + promptSessionAttributes, + } = event; + + const tool = this.#tools.get(toolName); + + if (tool == null) { + this.#logger.error(`Tool ${toolName} has not been registered.`); + return this.#buildResponse({ + actionGroup, + function: toolName, + body: 'Error: tool has not been registered in handler.', + }); + } + + const toolParams: Record = {}; + for (const param of parameters) { + switch (param.type) { + case 'boolean': { + toolParams[param.name] = param.value === 'true'; + break; + } + case 'number': + case 'integer': { + toolParams[param.name] = Number(param.value); + break; + } + // this default will also catch array types but we leave them as strings + // because we cannot reliably parse them + default: { + toolParams[param.name] = param.value; + break; + } + } + } + + try { + const res = await tool.handler(toolParams, { event, context }); + const body = res == null ? '' : JSON.stringify(res); + return this.#buildResponse({ + actionGroup, + function: toolName, + body, + sessionAttributes, + promptSessionAttributes, + }); + } catch (error) { + this.#logger.error(`An error occurred in tool ${toolName}.`, error); + return this.#buildResponse({ + actionGroup, + function: toolName, + body: `Error when invoking tool: ${error}`, + sessionAttributes, + promptSessionAttributes, + }); + } + } +} diff --git a/packages/event-handler/src/bedrock-agent/index.ts b/packages/event-handler/src/bedrock-agent/index.ts new file mode 100644 index 0000000000..a18e9dd726 --- /dev/null +++ b/packages/event-handler/src/bedrock-agent/index.ts @@ -0,0 +1 @@ +export { BedrockAgentFunctionResolver } from './BedrockAgentFunctionResolver.js'; diff --git a/packages/event-handler/src/bedrock-agent/utils.ts b/packages/event-handler/src/bedrock-agent/utils.ts new file mode 100644 index 0000000000..c951502d59 --- /dev/null +++ b/packages/event-handler/src/bedrock-agent/utils.ts @@ -0,0 +1,55 @@ +import { isRecord, isString } from '@aws-lambda-powertools/commons/typeutils'; +import type { BedrockAgentFunctionEvent } from '../types/bedrock-agent.js'; + +/** + * Asserts that the provided event is a BedrockAgentFunctionEvent. + * + * @param event - The incoming event to check + * @throws Error if the event is not a valid BedrockAgentFunctionEvent + */ +export function assertBedrockAgentFunctionEvent( + event: unknown +): asserts event is BedrockAgentFunctionEvent { + const isValid = + isRecord(event) && + 'actionGroup' in event && + isString(event.actionGroup) && + 'function' in event && + isString(event.function) && + (!('parameters' in event) || + (Array.isArray(event.parameters) && + event.parameters.every( + (param) => + isRecord(param) && + 'name' in param && + isString(param.name) && + 'type' in param && + isString(param.type) && + 'value' in param && + isString(param.value) + ))) && + 'messageVersion' in event && + isString(event.messageVersion) && + 'agent' in event && + isRecord(event.agent) && + 'name' in event.agent && + isString(event.agent.name) && + 'id' in event.agent && + isString(event.agent.id) && + 'alias' in event.agent && + isString(event.agent.alias) && + 'version' in event.agent && + isString(event.agent.version) && + 'inputText' in event && + isString(event.inputText) && + 'sessionId' in event && + isString(event.sessionId) && + 'sessionAttributes' in event && + isRecord(event.sessionAttributes) && + 'promptSessionAttributes' in event && + isRecord(event.promptSessionAttributes); + + if (!isValid) { + throw new Error('Event is not a valid BedrockAgentFunctionEvent'); + } +} diff --git a/packages/event-handler/src/types/appsync-events.ts b/packages/event-handler/src/types/appsync-events.ts index 7c391f3312..1367cacd93 100644 --- a/packages/event-handler/src/types/appsync-events.ts +++ b/packages/event-handler/src/types/appsync-events.ts @@ -1,22 +1,7 @@ import type { Context } from 'aws-lambda'; import type { RouteHandlerRegistry } from '../appsync-events/RouteHandlerRegistry.js'; import type { Router } from '../appsync-events/Router.js'; - -// #region Shared - -// biome-ignore lint/suspicious/noExplicitAny: We intentionally use `any` here to represent any type of data and keep the logger is as flexible as possible. -type Anything = any; - -/** - * Interface for a generic logger object. - */ -type GenericLogger = { - trace?: (...content: Anything[]) => void; - debug: (...content: Anything[]) => void; - info?: (...content: Anything[]) => void; - warn: (...content: Anything[]) => void; - error: (...content: Anything[]) => void; -}; +import type { Anything, GenericLogger } from './common.js'; // #region OnPublish fn diff --git a/packages/event-handler/src/types/bedrock-agent.ts b/packages/event-handler/src/types/bedrock-agent.ts new file mode 100644 index 0000000000..f707636885 --- /dev/null +++ b/packages/event-handler/src/types/bedrock-agent.ts @@ -0,0 +1,109 @@ +import type { JSONValue } from '@aws-lambda-powertools/commons/types'; +import type { Context } from 'aws-lambda'; +import type { GenericLogger } from '../types/common.js'; + +type Configuration = { + name: string; + description: string; +}; + +type Parameter = { + name: string; + type: 'string' | 'number' | 'integer' | 'boolean' | 'array'; + value: string; +}; + +type ParameterPrimitives = string | number | boolean; + +type ParameterValue = ParameterPrimitives | Array; + +type ToolFunction> = ( + params: TParams, + options?: { + event?: BedrockAgentFunctionEvent; + context?: Context; + } +) => Promise; + +type Tool> = { + handler: ToolFunction; + config: Configuration; +}; + +type FunctionIdentifier = { + actionGroup: string; + function: string; +}; + +type FunctionInvocation = FunctionIdentifier & { + parameters?: Array; +}; + +type BedrockAgentFunctionEvent = FunctionInvocation & { + messageVersion: string; + agent: { + name: string; + id: string; + alias: string; + version: string; + }; + inputText: string; + sessionId: string; + sessionAttributes: Record; + promptSessionAttributes: Record; +}; + +type ResponseState = 'ERROR' | 'REPROMPT'; + +type TextResponseBody = { + TEXT: { + body: string; + }; +}; + +type SessionData = { + sessionAttributes?: Record; + promptSessionAttributes?: Record; +}; + +type BedrockAgentFunctionResponse = SessionData & { + messageVersion: string; + response: FunctionIdentifier & { + functionResponse: { + responseState?: ResponseState; + responseBody: TextResponseBody; + }; + }; +}; + +type ResponseOptions = FunctionIdentifier & + SessionData & { + body: string; + errorType?: ResponseState; + }; + +/** + * Options for the {@link BedrockAgentFunctionResolver} class + */ +type ResolverOptions = { + /** + * A logger instance to be used for logging debug, warning, and error messages. + * + * When no logger is provided, we'll only log warnings and errors using the global `console` object. + */ + logger?: GenericLogger; +}; + +export type { + Configuration, + Tool, + ToolFunction, + Parameter, + ParameterValue, + FunctionIdentifier, + FunctionInvocation, + BedrockAgentFunctionEvent, + BedrockAgentFunctionResponse, + ResponseOptions, + ResolverOptions, +}; diff --git a/packages/event-handler/src/types/common.ts b/packages/event-handler/src/types/common.ts new file mode 100644 index 0000000000..a3f71b397b --- /dev/null +++ b/packages/event-handler/src/types/common.ts @@ -0,0 +1,15 @@ +// biome-ignore lint/suspicious/noExplicitAny: We intentionally use `any` here to represent any type of data and keep the logger is as flexible as possible. +type Anything = any; + +/** + * Interface for a generic logger object. + */ +type GenericLogger = { + trace?: (...content: Anything[]) => void; + debug: (...content: Anything[]) => void; + info?: (...content: Anything[]) => void; + warn: (...content: Anything[]) => void; + error: (...content: Anything[]) => void; +}; + +export type { Anything, GenericLogger }; diff --git a/packages/event-handler/src/types/index.ts b/packages/event-handler/src/types/index.ts index 424189e05c..eee1ee9f8c 100644 --- a/packages/event-handler/src/types/index.ts +++ b/packages/event-handler/src/types/index.ts @@ -8,3 +8,14 @@ export type { RouteOptions, RouterOptions, } from './appsync-events.js'; + +export type { + BedrockAgentFunctionEvent, + BedrockAgentFunctionResponse, + ResolverOptions, +} from './bedrock-agent.js'; + +export type { + GenericLogger, + Anything, +} from './common.js'; diff --git a/packages/event-handler/tests/unit/AppSyncEventsResolver.test.ts b/packages/event-handler/tests/unit/appsync-events/AppSyncEventsResolver.test.ts similarity index 98% rename from packages/event-handler/tests/unit/AppSyncEventsResolver.test.ts rename to packages/event-handler/tests/unit/appsync-events/AppSyncEventsResolver.test.ts index 8da6f88c02..d75bede42f 100644 --- a/packages/event-handler/tests/unit/AppSyncEventsResolver.test.ts +++ b/packages/event-handler/tests/unit/appsync-events/AppSyncEventsResolver.test.ts @@ -3,11 +3,11 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'; import { AppSyncEventsResolver, UnauthorizedException, -} from '../../src/appsync-events/index.js'; +} from '../../../src/appsync-events/index.js'; import { onPublishEventFactory, onSubscribeEventFactory, -} from '../helpers/factories.js'; +} from '../../helpers/factories.js'; describe('Class: AppSyncEventsResolver', () => { beforeEach(() => { diff --git a/packages/event-handler/tests/unit/RouteHandlerRegistry.test.ts b/packages/event-handler/tests/unit/appsync-events/RouteHandlerRegistry.test.ts similarity index 96% rename from packages/event-handler/tests/unit/RouteHandlerRegistry.test.ts rename to packages/event-handler/tests/unit/appsync-events/RouteHandlerRegistry.test.ts index fc0431967b..03bc8fdbfe 100644 --- a/packages/event-handler/tests/unit/RouteHandlerRegistry.test.ts +++ b/packages/event-handler/tests/unit/appsync-events/RouteHandlerRegistry.test.ts @@ -1,6 +1,6 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'; -import { RouteHandlerRegistry } from '../../src/appsync-events/RouteHandlerRegistry.js'; -import type { RouteHandlerOptions } from '../../src/types/appsync-events.js'; +import { RouteHandlerRegistry } from '../../../src/appsync-events/RouteHandlerRegistry.js'; +import type { RouteHandlerOptions } from '../../../src/types/appsync-events.js'; describe('Class: RouteHandlerRegistry', () => { class MockRouteHandlerRegistry extends RouteHandlerRegistry { diff --git a/packages/event-handler/tests/unit/Router.test.ts b/packages/event-handler/tests/unit/appsync-events/Router.test.ts similarity index 97% rename from packages/event-handler/tests/unit/Router.test.ts rename to packages/event-handler/tests/unit/appsync-events/Router.test.ts index 90fa2492a1..c9fa3d9382 100644 --- a/packages/event-handler/tests/unit/Router.test.ts +++ b/packages/event-handler/tests/unit/appsync-events/Router.test.ts @@ -1,5 +1,5 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'; -import { Router } from '../../src/appsync-events/index.js'; +import { Router } from '../../../src/appsync-events/index.js'; describe('Class: Router', () => { beforeEach(() => { diff --git a/packages/event-handler/tests/unit/bedrock-agent/BedrockAgentFunctionResolver.test.ts b/packages/event-handler/tests/unit/bedrock-agent/BedrockAgentFunctionResolver.test.ts new file mode 100644 index 0000000000..f0d929a668 --- /dev/null +++ b/packages/event-handler/tests/unit/bedrock-agent/BedrockAgentFunctionResolver.test.ts @@ -0,0 +1,666 @@ +import context from '@aws-lambda-powertools/testing-utils/context'; +import type { Context } from 'aws-lambda'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; +import { BedrockAgentFunctionResolver } from '../../../src/bedrock-agent/index.js'; +import type { + BedrockAgentFunctionEvent, + Configuration, + Parameter, + ToolFunction, +} from '../../../src/types/bedrock-agent'; + +function createEvent(functionName: string, parameters?: Parameter[]) { + return { + messageVersion: '1.0', + agent: { + name: 'agentName', + id: 'agentId', + alias: 'agentAlias', + version: '1', + }, + sessionId: 'sessionId', + inputText: 'inputText', + function: functionName, + ...(parameters == null ? {} : { parameters }), + actionGroup: 'myActionGroup', + sessionAttributes: {}, + promptSessionAttributes: {}, + }; +} + +describe('Class: BedrockAgentFunctionResolver', () => { + beforeEach(() => { + vi.unstubAllEnvs(); + }); + + it.each([ + { + name: 'null event', + invalidEvent: null, + }, + { + name: 'missing required fields', + invalidEvent: { + function: 'test-tool', + }, + }, + { + name: 'invalid parameters structure', + invalidEvent: { + function: 'test-tool', + actionGroup: 'testGroup', + messageVersion: '1.0', + agent: { + name: 'agentName', + id: 'agentId', + alias: 'agentAlias', + version: '1', + }, + inputText: 'test input', + sessionId: 'session123', + parameters: 'not an array', + sessionAttributes: {}, + promptSessionAttributes: {}, + }, + }, + ])('throws when given an invalid event: $name', async ({ invalidEvent }) => { + // Prepare + const app = new BedrockAgentFunctionResolver(); + + app.tool(async () => 'test', { + name: 'test-tool', + description: 'Test tool', + }); + + // Act & Assert + await expect(app.resolve(invalidEvent, context)).rejects.toThrow( + 'Event is not a valid BedrockAgentFunctionEvent' + ); + }); + + it('uses a default logger with only warnings if none is provided', () => { + // Prepare + const app = new BedrockAgentFunctionResolver(); + + app.tool( + async (params: { arg: string }) => { + return params.arg; + }, + { + name: 'identity', + description: 'Returns its arg', + } + ); + + // Assess + expect(console.debug).not.toHaveBeenCalled(); + }); + + it('emits debug message when AWS_LAMBDA_LOG_LEVEL is set to DEBUG', () => { + // Prepare + vi.stubEnv('AWS_LAMBDA_LOG_LEVEL', 'DEBUG'); + const app = new BedrockAgentFunctionResolver(); + + app.tool( + async (params: { arg: string }) => { + return params.arg; + }, + { + name: 'identity', + description: 'Returns its arg', + } + ); + + // Assess + expect(console.debug).toHaveBeenCalled(); + }); + + it('only allows five tools to be registered', async () => { + // Prepare + const app = new BedrockAgentFunctionResolver(); + + for (const num of [1, 2, 3, 4, 5]) { + app.tool( + async (params: { arg: string }) => { + return params.arg; + }, + { + name: `identity${num}`, + description: 'Returns its arg', + } + ); + } + + app.tool( + async (params: { a: number; b: number }) => { + return params.a + params.b; + }, + { + name: 'mult', + description: 'Multiplies two numbers', + } + ); + + const event = createEvent('mult', [ + { + name: 'a', + type: 'number', + value: '1', + }, + { + name: 'b', + type: 'number', + value: '2', + }, + ]); + + // Act + const actual = await app.resolve(event, context); + + // Assess + expect(console.warn).toHaveBeenLastCalledWith( + 'The maximum number of tools that can be registered is 5. Tool mult will not be registered.' + ); + expect(actual.response.function).toEqual('mult'); + expect(actual.response.functionResponse.responseBody.TEXT.body).toEqual( + 'Error: tool has not been registered in handler.' + ); + }); + + it('overwrites tools with the same name and uses the latest definition', async () => { + // Prepare + const app = new BedrockAgentFunctionResolver(); + + const event = createEvent('math', [ + { + name: 'a', + type: 'number', + value: '10', + }, + { + name: 'b', + type: 'number', + value: '2', + }, + ]); + + app.tool( + async (params: { a: number; b: number }) => { + return params.a + params.b; + }, + { + name: 'math', + description: 'Adds two numbers', + } + ); + + const addResult = await app.resolve(event, context); + expect(addResult.response.function).toEqual('math'); + expect(addResult.response.functionResponse.responseBody.TEXT.body).toEqual( + '12' + ); + + app.tool( + async (params: { a: number; b: number }) => { + return params.a * params.b; + }, + { + name: 'math', + description: 'Multiplies two numbers', + } + ); + + const multiplyResult = await app.resolve(event, context); + expect(multiplyResult.response.function).toEqual('math'); + expect( + multiplyResult.response.functionResponse.responseBody.TEXT.body + ).toEqual('20'); + }); + + it('accepts custom logger', async () => { + // Prepare + vi.stubEnv('AWS_LAMBDA_LOG_LEVEL', 'DEBUG'); + + const logger = { + debug: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + }; + const app = new BedrockAgentFunctionResolver({ logger }); + + app.tool( + async (params: { arg: string }) => { + return params.arg; + }, + { + name: 'identity', + description: 'Returns its arg', + } + ); + + app.tool( + async (params: { arg: string }) => { + return params.arg; + }, + { + name: 'identity', + description: 'Returns its arg', + } + ); + + app.tool( + async (_params) => { + throw new Error(); + }, + { + name: 'error', + description: 'errors', + } + ); + + // Act + await app.resolve(createEvent('noop'), context); + await app.resolve(createEvent('error'), context).catch(() => {}); + + // Assess + expect(logger.warn).toHaveBeenCalled(); + expect(logger.error).toHaveBeenCalled(); + expect(logger.debug).toHaveBeenCalled(); + }); + + it('tool function has access to the event variable', async () => { + // Prepare + const app = new BedrockAgentFunctionResolver(); + + app.tool( + async (_params, options) => { + return options?.event; + }, + { + name: 'event-accessor', + description: 'Accesses the event object', + } + ); + + const event = createEvent('event-accessor'); + + // Act + const result = await app.resolve(event, context); + + // Assess + expect(result.response.function).toEqual('event-accessor'); + expect(result.response.functionResponse.responseBody.TEXT.body).toEqual( + JSON.stringify(event) + ); + }); + + it('can be invoked using the decorator pattern', async () => { + // Prepare + const app = new BedrockAgentFunctionResolver(); + + class Lambda { + @app.tool({ name: 'hello', description: 'Says hello' }) + async helloWorld() { + return 'Hello, world!'; + } + + @app.tool({ name: 'add', description: 'Adds two numbers' }) + async add(params: { a: number; b: number }) { + const { a, b } = params; + return a + b; + } + + public async handler(event: BedrockAgentFunctionEvent, context: Context) { + return app.resolve(event, context); + } + } + + const lambda = new Lambda(); + + const addEvent = createEvent('add', [ + { + name: 'a', + type: 'number', + value: '1', + }, + { + name: 'b', + type: 'number', + value: '2', + }, + ]); + + // Act + const actual = await lambda.handler(addEvent, context); + + // Assess + expect(actual.response.function).toEqual('add'); + expect(actual.response.functionResponse.responseBody.TEXT.body).toEqual( + '3' + ); + }); + + it.each([ + { + toolFunction: async () => ({ + name: 'John Doe', + age: 30, + isActive: true, + address: { + street: '123 Main St', + city: 'Anytown', + }, + }), + toolParams: { + name: 'object', + description: 'Returns an object', + }, + expected: + '{"name":"John Doe","age":30,"isActive":true,"address":{"street":"123 Main St","city":"Anytown"}}', + }, + { + toolFunction: async () => [1, 'two', false, null], + toolParams: { + name: 'array', + description: 'Returns an array', + }, + expected: '[1,"two",false,null]', + }, + ])( + 'handles function that returns $toolParams.name', + async ({ toolFunction, toolParams, expected }) => { + // Prepare + const app = new BedrockAgentFunctionResolver(); + + app.tool(toolFunction, toolParams); + + // Act + const actual = await app.resolve(createEvent(toolParams.name), context); + + // Asses + expect(actual.response.function).toEqual(toolParams.name); + expect(actual.response.functionResponse.responseBody.TEXT.body).toEqual( + expected + ); + } + ); + + it.each([ + { + toolFunction: async () => null, + toolParams: { + name: 'null', + description: 'Returns null', + }, + }, + { + toolFunction: async () => void 0, + toolParams: { + name: 'undefined', + description: 'Returns undefined', + }, + }, + ])( + 'handles functions that return $toolParams.name by returning an empty string', + async ({ toolFunction, toolParams }) => { + // Prepare + const app = new BedrockAgentFunctionResolver(); + + app.tool(toolFunction, toolParams); + + // Assess + const actual = await app.resolve(createEvent(toolParams.name), context); + + // Act + expect(actual.response.function).toEqual(toolParams.name); + expect(actual.response.functionResponse.responseBody.TEXT.body).toEqual( + '' + ); + } + ); + + it('correctly parses boolean parameters', async () => { + // Prepare + const toolFunction: ToolFunction<{ arg: boolean }> = async ( + params, + _options + ) => params.arg; + + const toolParams: Configuration = { + name: 'boolean', + description: 'Handles boolean parameters', + }; + + const parameters: Parameter[] = [ + { name: 'arg', type: 'boolean', value: 'true' }, + ]; + + const app = new BedrockAgentFunctionResolver(); + app.tool(toolFunction, toolParams); + + //Act + const actual = await app.resolve( + createEvent(toolParams.name, parameters), + context + ); + + // Assess + expect(actual.response.function).toEqual(toolParams.name); + expect(actual.response.functionResponse.responseBody.TEXT.body).toEqual( + 'true' + ); + }); + + it('correctly parses number parameters', async () => { + // Prepare + const toolFunction: ToolFunction<{ arg: number }> = async ( + params, + _options + ) => params.arg + 10; + + const toolParams: Configuration = { + name: 'number', + description: 'Handles number parameters', + }; + + const parameters: Parameter[] = [ + { name: 'arg', type: 'number', value: '42' }, + ]; + + const app = new BedrockAgentFunctionResolver(); + app.tool(toolFunction, toolParams); + + // Act + const actual = await app.resolve( + createEvent(toolParams.name, parameters), + context + ); + + // Assess + expect(actual.response.function).toEqual(toolParams.name); + expect(actual.response.functionResponse.responseBody.TEXT.body).toEqual( + '52' + ); + }); + + it('correctly parses integer parameters', async () => { + // Prepare + const toolFunction: ToolFunction<{ arg: number }> = async ( + params, + _options + ) => params.arg + 10; + + const toolParams: Configuration = { + name: 'integer', + description: 'Handles integer parameters', + }; + + const parameters: Parameter[] = [ + { name: 'arg', type: 'integer', value: '37' }, + ]; + + const app = new BedrockAgentFunctionResolver(); + app.tool(toolFunction, toolParams); + + // Act + const actual = await app.resolve( + createEvent(toolParams.name, parameters), + context + ); + + // Assess + expect(actual.response.function).toEqual(toolParams.name); + expect(actual.response.functionResponse.responseBody.TEXT.body).toEqual( + '47' + ); + }); + + it('correctly parses string parameters', async () => { + // Prepare + const toolFunction: ToolFunction<{ arg: string }> = async ( + params, + _options + ) => `String: ${params.arg}`; + + const toolParams: Configuration = { + name: 'string', + description: 'Handles string parameters', + }; + + const parameters: Parameter[] = [ + { name: 'arg', type: 'string', value: 'hello world' }, + ]; + + const app = new BedrockAgentFunctionResolver(); + app.tool(toolFunction, toolParams); + + // Act + const actual = await app.resolve( + createEvent(toolParams.name, parameters), + context + ); + + // Assess + expect(actual.response.function).toEqual(toolParams.name); + expect(actual.response.functionResponse.responseBody.TEXT.body).toEqual( + '"String: hello world"' + ); + }); + + it('correctly parses array parameters', async () => { + // Prepare + const toolFunction: ToolFunction<{ arg: string }> = async ( + params, + _options + ) => `Array as string: ${params.arg}`; + + const toolParams: Configuration = { + name: 'array', + description: 'Handles array parameters (as string)', + }; + + const parameters: Parameter[] = [ + { name: 'arg', type: 'array', value: '[1,2,3]' }, + ]; + + const app = new BedrockAgentFunctionResolver(); + app.tool(toolFunction, toolParams); + + // Act + const actual = await app.resolve( + createEvent(toolParams.name, parameters), + context + ); + + // Assess + expect(actual.response.function).toEqual(toolParams.name); + expect(actual.response.functionResponse.responseBody.TEXT.body).toEqual( + '"Array as string: [1,2,3]"' + ); + }); + + it('handles functions that throw errors', async () => { + // Prepare + const app = new BedrockAgentFunctionResolver(); + + app.tool( + async (_params, _options) => { + throw new Error('Something went wrong'); + }, + { + name: 'error-tool', + description: 'Throws an error', + } + ); + + // Act + const actual = await app.resolve(createEvent('error-tool', []), context); + + // Assess + expect(actual.response.function).toEqual('error-tool'); + expect(actual.response.functionResponse.responseBody.TEXT.body).toEqual( + 'Error when invoking tool: Error: Something went wrong' + ); + expect(console.error).toHaveBeenCalledWith( + 'An error occurred in tool error-tool.', + new Error('Something went wrong') + ); + }); + + it('returns a fully structured BedrockAgentFunctionResponse', async () => { + // Prepare + const app = new BedrockAgentFunctionResolver(); + + app.tool( + async (params, _options) => { + return `Hello, ${params.name}!`; + }, + { + name: 'greeting', + description: 'Greets a person by name', + } + ); + + const customSessionAttrs = { + sessionAttr: '12345', + }; + + const customPromptAttrs = { + promptAttr: 'promptAttr', + }; + + const customEvent = { + ...createEvent('greeting', [ + { + name: 'name', + type: 'string', + value: 'John', + }, + ]), + actionGroup: 'actionGroup', + sessionAttributes: customSessionAttrs, + promptSessionAttributes: customPromptAttrs, + }; + + // Act + const result = await app.resolve(customEvent, context); + + // Assess + expect(result).toEqual({ + messageVersion: '1.0', + response: { + actionGroup: 'actionGroup', + function: 'greeting', + functionResponse: { + responseBody: { + TEXT: { + body: '"Hello, John!"', + }, + }, + }, + }, + sessionAttributes: customSessionAttrs, + promptSessionAttributes: customPromptAttrs, + }); + }); +});