diff --git a/examples/google/toolcall.php b/examples/google/toolcall.php new file mode 100644 index 00000000..2e47b3a5 --- /dev/null +++ b/examples/google/toolcall.php @@ -0,0 +1,31 @@ +loadEnv(dirname(__DIR__, 2).'/.env'); + +if (empty($_ENV['GOOGLE_API_KEY'])) { + echo 'Please set the GOOGLE_API_KEY environment variable.'.\PHP_EOL; + exit(1); +} + +$platform = PlatformFactory::create($_ENV['GOOGLE_API_KEY']); +$llm = new Gemini(Gemini::GEMINI_2_FLASH); + +$toolbox = Toolbox::create(new Clock()); +$processor = new ChainProcessor($toolbox); +$chain = new Chain($platform, $llm, [$processor], [$processor]); + +$messages = new MessageBag(Message::ofUser('What time is it?')); +$response = $chain->call($messages); + +echo $response->getContent().\PHP_EOL; diff --git a/src/Platform/Bridge/Google/Contract/AssistantMessageNormalizer.php b/src/Platform/Bridge/Google/Contract/AssistantMessageNormalizer.php index bdc62cf8..0ea58691 100644 --- a/src/Platform/Bridge/Google/Contract/AssistantMessageNormalizer.php +++ b/src/Platform/Bridge/Google/Contract/AssistantMessageNormalizer.php @@ -35,8 +35,23 @@ protected function supportsModel(Model $model): bool */ public function normalize(mixed $data, ?string $format = null, array $context = []): array { - return [ - ['text' => $data->content], - ]; + $normalized = []; + + if (isset($data->content)) { + $normalized['text'] = $data->content; + } + + if (isset($data->toolCalls[0])) { + $normalized['functionCall'] = [ + 'id' => $data->toolCalls[0]->id, + 'name' => $data->toolCalls[0]->name, + ]; + + if ($data->toolCalls[0]->arguments) { + $normalized['functionCall']['args'] = $data->toolCalls[0]->arguments; + } + } + + return [$normalized]; } } diff --git a/src/Platform/Bridge/Google/Contract/ToolCallMessageNormalizer.php b/src/Platform/Bridge/Google/Contract/ToolCallMessageNormalizer.php new file mode 100644 index 00000000..4c182e04 --- /dev/null +++ b/src/Platform/Bridge/Google/Contract/ToolCallMessageNormalizer.php @@ -0,0 +1,56 @@ + + */ +final class ToolCallMessageNormalizer extends ModelContractNormalizer implements NormalizerAwareInterface +{ + use NormalizerAwareTrait; + + protected function supportedDataClass(): string + { + return ToolCallMessage::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Gemini; + } + + /** + * @param ToolCallMessage $data + * + * @return array{ + * functionResponse: array{ + * id: string, + * name: string, + * response: array + * } + * }[] + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + $responseContent = json_validate($data->content) ? json_decode($data->content, true) : $data->content; + + return [[ + 'functionResponse' => array_filter([ + 'id' => $data->toolCall->id, + 'name' => $data->toolCall->name, + 'response' => \is_array($responseContent) ? $responseContent : [ + 'rawResponse' => $responseContent, // Gemini expects the response to be an object, but not everyone uses objects as their responses. + ], + ]), + ]]; + } +} diff --git a/src/Platform/Bridge/Google/Contract/ToolNormalizer.php b/src/Platform/Bridge/Google/Contract/ToolNormalizer.php new file mode 100644 index 00000000..50e62f5a --- /dev/null +++ b/src/Platform/Bridge/Google/Contract/ToolNormalizer.php @@ -0,0 +1,54 @@ + + * + * @phpstan-import-type JsonSchema from Factory + */ +final class ToolNormalizer extends ModelContractNormalizer +{ + protected function supportedDataClass(): string + { + return Tool::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Gemini; + } + + /** + * @param Tool $data + * + * @return array{ + * functionDeclarations: array{ + * name: string, + * description: string, + * parameters: JsonSchema|array{type: 'object'} + * }[] + * } + */ + public function normalize(mixed $data, ?string $format = null, array $context = []): array + { + $parameters = $data->parameters; + unset($parameters['additionalProperties']); + + return [ + 'functionDeclarations' => [ + [ + 'description' => $data->description, + 'name' => $data->name, + 'parameters' => $parameters, + ], + ], + ]; + } +} diff --git a/src/Platform/Bridge/Google/Gemini.php b/src/Platform/Bridge/Google/Gemini.php index 7e6d2e2e..92d96e8c 100644 --- a/src/Platform/Bridge/Google/Gemini.php +++ b/src/Platform/Bridge/Google/Gemini.php @@ -27,6 +27,7 @@ public function __construct(string $name = self::GEMINI_2_PRO, array $options = Capability::INPUT_MESSAGES, Capability::INPUT_IMAGE, Capability::OUTPUT_STREAMING, + Capability::TOOL_CALLING, ]; parent::__construct($name, $capabilities, $options); diff --git a/src/Platform/Bridge/Google/ModelHandler.php b/src/Platform/Bridge/Google/ModelHandler.php index fc755eb4..aed97dfa 100644 --- a/src/Platform/Bridge/Google/ModelHandler.php +++ b/src/Platform/Bridge/Google/ModelHandler.php @@ -7,9 +7,13 @@ use PhpLlm\LlmChain\Platform\Exception\RuntimeException; use PhpLlm\LlmChain\Platform\Model; use PhpLlm\LlmChain\Platform\ModelClientInterface; +use PhpLlm\LlmChain\Platform\Response\Choice; +use PhpLlm\LlmChain\Platform\Response\ChoiceResponse; use PhpLlm\LlmChain\Platform\Response\ResponseInterface as LlmResponse; use PhpLlm\LlmChain\Platform\Response\StreamResponse; use PhpLlm\LlmChain\Platform\Response\TextResponse; +use PhpLlm\LlmChain\Platform\Response\ToolCall; +use PhpLlm\LlmChain\Platform\Response\ToolCallResponse; use PhpLlm\LlmChain\Platform\ResponseConverterInterface; use Symfony\Component\HttpClient\EventSourceHttpClient; use Symfony\Contracts\HttpClient\Exception\ClientExceptionInterface; @@ -52,6 +56,12 @@ public function request(Model $model, array|string $payload, array $options = [] $generationConfig = ['generationConfig' => $options]; unset($generationConfig['generationConfig']['stream']); + unset($generationConfig['generationConfig']['tools']); + + if (isset($options['tools'])) { + $generationConfig['tools'] = $options['tools']; + unset($options['tools']); + } return $this->httpClient->request('POST', $url, [ 'headers' => [ @@ -76,11 +86,22 @@ public function convert(ResponseInterface $response, array $options = []): LlmRe $data = $response->toArray(); - if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) { + if (!isset($data['candidates'][0]['content']['parts'][0])) { throw new RuntimeException('Response does not contain any content'); } - return new TextResponse($data['candidates'][0]['content']['parts'][0]['text']); + /** @var Choice[] $choices */ + $choices = array_map($this->convertChoice(...), $data['candidates']); + + if (1 !== \count($choices)) { + return new ChoiceResponse(...$choices); + } + + if ($choices[0]->hasToolCall()) { + return new ToolCallResponse(...$choices[0]->getToolCalls()); + } + + return new TextResponse($choices[0]->getContent()); } private function convertStream(ResponseInterface $response): \Generator @@ -114,12 +135,68 @@ private function convertStream(ResponseInterface $response): \Generator throw new RuntimeException('Failed to decode JSON response', 0, $e); } - if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) { + /** @var Choice[] $choices */ + $choices = array_map($this->convertChoice(...), $data['candidates'] ?? []); + + if (!$choices) { continue; } - yield $data['candidates'][0]['content']['parts'][0]['text']; + if (1 !== \count($choices)) { + yield new ChoiceResponse(...$choices); + continue; + } + + if ($choices[0]->hasToolCall()) { + yield new ToolCallResponse(...$choices[0]->getToolCalls()); + } + + if ($choices[0]->hasContent()) { + yield $choices[0]->getContent(); + } } } } + + /** + * @param array{ + * finishReason?: string, + * content: array{ + * parts: array{ + * functionCall?: array{ + * id: string, + * name: string, + * args: mixed[] + * }, + * text?: string + * }[] + * } + * } $choice + */ + private function convertChoice(array $choice): Choice + { + $contentPart = $choice['content']['parts'][0] ?? []; + + if (isset($contentPart['functionCall'])) { + return new Choice(toolCalls: [$this->convertToolCall($contentPart['functionCall'])]); + } + + if (isset($contentPart['text'])) { + return new Choice($contentPart['text']); + } + + throw new RuntimeException(\sprintf('Unsupported finish reason "%s".', $choice['finishReason'])); + } + + /** + * @param array{ + * id: string, + * name: string, + * args: mixed[] + * } $toolCall + */ + private function convertToolCall(array $toolCall): ToolCall + { + return new ToolCall($toolCall['id'] ?? '', $toolCall['name'], $toolCall['args']); + } } diff --git a/src/Platform/Bridge/Google/PlatformFactory.php b/src/Platform/Bridge/Google/PlatformFactory.php index 143edba9..0665557e 100644 --- a/src/Platform/Bridge/Google/PlatformFactory.php +++ b/src/Platform/Bridge/Google/PlatformFactory.php @@ -6,6 +6,8 @@ use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\AssistantMessageNormalizer; use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\MessageBagNormalizer; +use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\ToolCallMessageNormalizer; +use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\ToolNormalizer; use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\UserMessageNormalizer; use PhpLlm\LlmChain\Platform\Contract; use PhpLlm\LlmChain\Platform\Platform; @@ -28,6 +30,8 @@ public static function create( return new Platform([$responseHandler], [$responseHandler], Contract::create( new AssistantMessageNormalizer(), new MessageBagNormalizer(), + new ToolNormalizer(), + new ToolCallMessageNormalizer(), new UserMessageNormalizer(), )); } diff --git a/tests/Platform/Bridge/Google/Contract/AssistantMessageNormalizerTest.php b/tests/Platform/Bridge/Google/Contract/AssistantMessageNormalizerTest.php index 0283fee4..e5c57f1b 100644 --- a/tests/Platform/Bridge/Google/Contract/AssistantMessageNormalizerTest.php +++ b/tests/Platform/Bridge/Google/Contract/AssistantMessageNormalizerTest.php @@ -9,7 +9,9 @@ use PhpLlm\LlmChain\Platform\Contract; use PhpLlm\LlmChain\Platform\Message\AssistantMessage; use PhpLlm\LlmChain\Platform\Model; +use PhpLlm\LlmChain\Platform\Response\ToolCall; use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\DataProvider; use PHPUnit\Framework\Attributes\Small; use PHPUnit\Framework\Attributes\Test; use PHPUnit\Framework\Attributes\UsesClass; @@ -20,6 +22,7 @@ #[UsesClass(Gemini::class)] #[UsesClass(AssistantMessage::class)] #[UsesClass(Model::class)] +#[UsesClass(ToolCall::class)] final class AssistantMessageNormalizerTest extends TestCase { #[Test] @@ -41,14 +44,33 @@ public function getSupportedTypes(): void self::assertSame([AssistantMessage::class => true], $normalizer->getSupportedTypes(null)); } + #[DataProvider('normalizeDataProvider')] #[Test] - public function normalize(): void + public function normalize(AssistantMessage $message, array $expectedOutput): void { $normalizer = new AssistantMessageNormalizer(); - $message = new AssistantMessage('Great to meet you. What would you like to know?'); $normalized = $normalizer->normalize($message); - self::assertSame([['text' => 'Great to meet you. What would you like to know?']], $normalized); + self::assertSame($expectedOutput, $normalized); + } + + /** + * @return iterable + */ + public static function normalizeDataProvider(): iterable + { + yield 'assistant message' => [ + new AssistantMessage('Great to meet you. What would you like to know?'), + [['text' => 'Great to meet you. What would you like to know?']], + ]; + yield 'function call' => [ + new AssistantMessage(toolCalls: [new ToolCall('id1', 'name1', ['arg1' => '123'])]), + [['functionCall' => ['id' => 'id1', 'name' => 'name1', 'args' => ['arg1' => '123']]]], + ]; + yield 'function call without parameters' => [ + new AssistantMessage(toolCalls: [new ToolCall('id1', 'name1')]), + [['functionCall' => ['id' => 'id1', 'name' => 'name1']]], + ]; } } diff --git a/tests/Platform/Bridge/Google/Contract/ToolCallMessageNormalizerTest.php b/tests/Platform/Bridge/Google/Contract/ToolCallMessageNormalizerTest.php new file mode 100644 index 00000000..8f6cf798 --- /dev/null +++ b/tests/Platform/Bridge/Google/Contract/ToolCallMessageNormalizerTest.php @@ -0,0 +1,95 @@ +supportsNormalization(new ToolCallMessage(new ToolCall('', '', []), ''), context: [ + Contract::CONTEXT_MODEL => new Gemini(), + ])); + self::assertFalse($normalizer->supportsNormalization('not a tool call')); + } + + #[Test] + public function getSupportedTypes(): void + { + $normalizer = new ToolCallMessageNormalizer(); + + $expected = [ + ToolCallMessage::class => true, + ]; + + self::assertSame($expected, $normalizer->getSupportedTypes(null)); + } + + #[Test] + #[DataProvider('normalizeDataProvider')] + public function normalize(ToolCallMessage $message, array $expected): void + { + $normalizer = new ToolCallMessageNormalizer(); + + $normalized = $normalizer->normalize($message); + + self::assertEquals($expected, $normalized); + } + + /** + * @return iterable + */ + public static function normalizeDataProvider(): iterable + { + yield 'scalar' => [ + new ToolCallMessage( + new ToolCall('id1', 'name1', ['foo' => 'bar']), + 'true', + ), + [[ + 'functionResponse' => [ + 'id' => 'id1', + 'name' => 'name1', + 'response' => ['rawResponse' => 'true'], + ], + ]], + ]; + + yield 'structured response' => [ + new ToolCallMessage( + new ToolCall('id1', 'name1', ['foo' => 'bar']), + '{"structured":"response"}', + ), + [[ + 'functionResponse' => [ + 'id' => 'id1', + 'name' => 'name1', + 'response' => ['structured' => 'response'], + ], + ]], + ]; + } +} diff --git a/tests/Platform/Bridge/Google/Contract/ToolNormalizerTest.php b/tests/Platform/Bridge/Google/Contract/ToolNormalizerTest.php new file mode 100644 index 00000000..7dccb164 --- /dev/null +++ b/tests/Platform/Bridge/Google/Contract/ToolNormalizerTest.php @@ -0,0 +1,131 @@ +supportsNormalization(new Tool(new ExecutionReference(ToolNoParams::class), 'test', 'test'), context: [ + Contract::CONTEXT_MODEL => new Gemini(), + ])); + self::assertFalse($normalizer->supportsNormalization('not a tool')); + } + + #[Test] + public function getSupportedTypes(): void + { + $normalizer = new ToolNormalizer(); + + $expected = [ + Tool::class => true, + ]; + + self::assertSame($expected, $normalizer->getSupportedTypes(null)); + } + + #[Test] + #[DataProvider('normalizeDataProvider')] + public function normalize(Tool $tool, array $expected): void + { + $normalizer = new ToolNormalizer(); + + $normalized = $normalizer->normalize($tool); + + self::assertEquals($expected, $normalized); + } + + /** + * @return iterable + */ + public static function normalizeDataProvider(): iterable + { + yield 'call with params' => [ + new Tool( + new ExecutionReference(ToolRequiredParams::class, 'bar'), + 'tool_required_params', + 'A tool with required parameters', + [ + 'type' => 'object', + 'properties' => [ + 'text' => [ + 'type' => 'string', + 'description' => 'Text parameter', + ], + 'number' => [ + 'type' => 'integer', + 'description' => 'Number parameter', + ], + ], + 'required' => ['text', 'number'], + 'additionalProperties' => false, + ], + ), + [ + 'functionDeclarations' => [ + [ + 'description' => 'A tool with required parameters', + 'name' => 'tool_required_params', + 'parameters' => [ + 'type' => 'object', + 'properties' => [ + 'text' => [ + 'type' => 'string', + 'description' => 'Text parameter', + ], + 'number' => [ + 'type' => 'integer', + 'description' => 'Number parameter', + ], + ], + 'required' => ['text', 'number'], + ], + ], + ], + ], + ]; + + yield 'call without params' => [ + new Tool( + new ExecutionReference(ToolNoParams::class, 'bar'), + 'tool_no_params', + 'A tool without parameters', + null, + ), + [ + 'functionDeclarations' => [ + [ + 'description' => 'A tool without parameters', + 'name' => 'tool_no_params', + 'parameters' => null, + ], + ], + ], + ]; + } +}