From c6ba1ceddb3bc4d0b0a9d2b12a88392c56d373ab Mon Sep 17 00:00:00 2001 From: Christopher Hertel Date: Mon, 5 May 2025 00:06:42 +0200 Subject: [PATCH] refactor: add extension point to payload contract handling --- src/Bridge/OpenAI/DallE/ContractExtension.php | 31 ++++++++ src/Bridge/OpenAI/DallE/ModelClient.php | 7 +- .../OpenAI/Embeddings/ContractExtension.php | 31 ++++++++ src/Bridge/OpenAI/Embeddings/ModelClient.php | 7 +- .../OpenAI/Embeddings/ResponseConverter.php | 2 +- src/Bridge/OpenAI/GPT/ModelClient.php | 9 +-- src/Bridge/OpenAI/GPT/ResponseConverter.php | 2 +- src/Bridge/OpenAI/PlatformFactory.php | 9 +++ .../OpenAI/Whisper/ContractExtension.php | 32 +++++++++ src/Bridge/OpenAI/Whisper/ModelClient.php | 14 ++-- .../OpenAI/Whisper/ResponseConverter.php | 5 +- src/Model/Message/AssistantMessage.php | 2 +- src/Model/Message/SystemMessage.php | 2 +- src/Model/Message/ToolCallMessage.php | 2 +- src/Model/Message/UserMessage.php | 2 +- src/Platform.php | 30 ++++---- src/Platform/Contract.php | 41 +++++++++++ src/Platform/Contract/Extension.php | 22 ++++++ src/Platform/Contract/InputNormalizer.php | 70 +++++++++++++++++++ src/Platform/ModelClient.php | 11 ++- src/Platform/ResponseConverter.php | 5 +- 21 files changed, 281 insertions(+), 55 deletions(-) create mode 100644 src/Bridge/OpenAI/DallE/ContractExtension.php create mode 100644 src/Bridge/OpenAI/Embeddings/ContractExtension.php create mode 100644 src/Bridge/OpenAI/Whisper/ContractExtension.php create mode 100644 src/Platform/Contract.php create mode 100644 src/Platform/Contract/Extension.php create mode 100644 src/Platform/Contract/InputNormalizer.php diff --git a/src/Bridge/OpenAI/DallE/ContractExtension.php b/src/Bridge/OpenAI/DallE/ContractExtension.php new file mode 100644 index 0000000..e975889 --- /dev/null +++ b/src/Bridge/OpenAI/DallE/ContractExtension.php @@ -0,0 +1,31 @@ + 'handleInput', + ]; + } + + public function handleInput(string $input): array + { + return [ + 'prompt' => $input, + ]; + } +} diff --git a/src/Bridge/OpenAI/DallE/ModelClient.php b/src/Bridge/OpenAI/DallE/ModelClient.php index 39552d9..32f7d4e 100644 --- a/src/Bridge/OpenAI/DallE/ModelClient.php +++ b/src/Bridge/OpenAI/DallE/ModelClient.php @@ -28,18 +28,17 @@ public function __construct( Assert::startsWith($apiKey, 'sk-', 'The API key must start with "sk-".'); } - public function supports(Model $model, array|string|object $input): bool + public function supports(Model $model): bool { return $model instanceof DallE; } - public function request(Model $model, object|array|string $input, array $options = []): HttpResponse + public function request(Model $model, array $payload, array $options = []): HttpResponse { return $this->httpClient->request('POST', 'https://api.openai.com/v1/images/generations', [ 'auth_bearer' => $this->apiKey, - 'json' => \array_merge($options, [ + 'json' => \array_merge($options, $payload, [ 'model' => $model->getName(), - 'prompt' => $input, ]), ]); } diff --git a/src/Bridge/OpenAI/Embeddings/ContractExtension.php b/src/Bridge/OpenAI/Embeddings/ContractExtension.php new file mode 100644 index 0000000..3143ba3 --- /dev/null +++ b/src/Bridge/OpenAI/Embeddings/ContractExtension.php @@ -0,0 +1,31 @@ + 'handleInput', + ]; + } + + public function handleInput(string $input): array + { + return [ + 'input' => $input, + ]; + } +} diff --git a/src/Bridge/OpenAI/Embeddings/ModelClient.php b/src/Bridge/OpenAI/Embeddings/ModelClient.php index a1bbbd8..4c92fed 100644 --- a/src/Bridge/OpenAI/Embeddings/ModelClient.php +++ b/src/Bridge/OpenAI/Embeddings/ModelClient.php @@ -22,18 +22,17 @@ public function __construct( Assert::startsWith($apiKey, 'sk-', 'The API key must start with "sk-".'); } - public function supports(Model $model, array|string|object $input): bool + public function supports(Model $model): bool { return $model instanceof Embeddings; } - public function request(Model $model, object|array|string $input, array $options = []): ResponseInterface + public function request(Model $model, array $payload, array $options = []): ResponseInterface { return $this->httpClient->request('POST', 'https://api.openai.com/v1/embeddings', [ 'auth_bearer' => $this->apiKey, - 'json' => array_merge($model->getOptions(), $options, [ + 'json' => array_merge($options, $payload, [ 'model' => $model->getName(), - 'input' => $input, ]), ]); } diff --git a/src/Bridge/OpenAI/Embeddings/ResponseConverter.php b/src/Bridge/OpenAI/Embeddings/ResponseConverter.php index d158ad9..4fb9cd5 100644 --- a/src/Bridge/OpenAI/Embeddings/ResponseConverter.php +++ b/src/Bridge/OpenAI/Embeddings/ResponseConverter.php @@ -14,7 +14,7 @@ final class ResponseConverter implements PlatformResponseConverter { - public function supports(Model $model, array|string|object $input): bool + public function supports(Model $model): bool { return $model instanceof Embeddings; } diff --git a/src/Bridge/OpenAI/GPT/ModelClient.php b/src/Bridge/OpenAI/GPT/ModelClient.php index 433d3d7..61376c3 100644 --- a/src/Bridge/OpenAI/GPT/ModelClient.php +++ b/src/Bridge/OpenAI/GPT/ModelClient.php @@ -18,25 +18,26 @@ public function __construct( HttpClientInterface $httpClient, - #[\SensitiveParameter] private string $apiKey, + #[\SensitiveParameter] + private string $apiKey, ) { $this->httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); Assert::stringNotEmpty($apiKey, 'The API key must not be empty.'); Assert::startsWith($apiKey, 'sk-', 'The API key must start with "sk-".'); } - public function supports(Model $model, array|string|object $input): bool + public function supports(Model $model): bool { return $model instanceof GPT; } - public function request(Model $model, object|array|string $input, array $options = []): ResponseInterface + public function request(Model $model, array $payload, array $options = []): ResponseInterface { return $this->httpClient->request('POST', 'https://api.openai.com/v1/chat/completions', [ 'auth_bearer' => $this->apiKey, 'json' => array_merge($options, [ 'model' => $model->getName(), - 'messages' => $input, + 'messages' => $payload, ]), ]); } diff --git a/src/Bridge/OpenAI/GPT/ResponseConverter.php b/src/Bridge/OpenAI/GPT/ResponseConverter.php index 247003e..9622d90 100644 --- a/src/Bridge/OpenAI/GPT/ResponseConverter.php +++ b/src/Bridge/OpenAI/GPT/ResponseConverter.php @@ -24,7 +24,7 @@ final class ResponseConverter implements PlatformResponseConverter { - public function supports(Model $model, array|string|object $input): bool + public function supports(Model $model): bool { return $model instanceof GPT; } diff --git a/src/Bridge/OpenAI/PlatformFactory.php b/src/Bridge/OpenAI/PlatformFactory.php index dbaf8df..65a8613 100644 --- a/src/Bridge/OpenAI/PlatformFactory.php +++ b/src/Bridge/OpenAI/PlatformFactory.php @@ -4,14 +4,18 @@ namespace PhpLlm\LlmChain\Bridge\OpenAI; +use PhpLlm\LlmChain\Bridge\OpenAI\DallE\ContractExtension as DallEContractExtension; use PhpLlm\LlmChain\Bridge\OpenAI\DallE\ModelClient as DallEModelClient; +use PhpLlm\LlmChain\Bridge\OpenAI\Embeddings\ContractExtension as EmbeddingsContractExtension; use PhpLlm\LlmChain\Bridge\OpenAI\Embeddings\ModelClient as EmbeddingsModelClient; use PhpLlm\LlmChain\Bridge\OpenAI\Embeddings\ResponseConverter as EmbeddingsResponseConverter; use PhpLlm\LlmChain\Bridge\OpenAI\GPT\ModelClient as GPTModelClient; use PhpLlm\LlmChain\Bridge\OpenAI\GPT\ResponseConverter as GPTResponseConverter; +use PhpLlm\LlmChain\Bridge\OpenAI\Whisper\ContractExtension as WhisperContractExtension; use PhpLlm\LlmChain\Bridge\OpenAI\Whisper\ModelClient as WhisperModelClient; use PhpLlm\LlmChain\Bridge\OpenAI\Whisper\ResponseConverter as WhisperResponseConverter; use PhpLlm\LlmChain\Platform; +use PhpLlm\LlmChain\Platform\Contract; use Symfony\Component\HttpClient\EventSourceHttpClient; use Symfony\Contracts\HttpClient\HttpClientInterface; @@ -39,6 +43,11 @@ public static function create( $dallEModelClient, new WhisperResponseConverter(), ], + Contract::create( + new DallEContractExtension(), + new EmbeddingsContractExtension(), + new WhisperContractExtension(), + ), ); } } diff --git a/src/Bridge/OpenAI/Whisper/ContractExtension.php b/src/Bridge/OpenAI/Whisper/ContractExtension.php new file mode 100644 index 0000000..07161d0 --- /dev/null +++ b/src/Bridge/OpenAI/Whisper/ContractExtension.php @@ -0,0 +1,32 @@ + 'handleAudioInput', + ]; + } + + public function handleAudioInput(Audio $audio): array + { + return [ + 'file' => $audio->asResource(), + ]; + } +} diff --git a/src/Bridge/OpenAI/Whisper/ModelClient.php b/src/Bridge/OpenAI/Whisper/ModelClient.php index ca512b4..ca32261 100644 --- a/src/Bridge/OpenAI/Whisper/ModelClient.php +++ b/src/Bridge/OpenAI/Whisper/ModelClient.php @@ -5,7 +5,6 @@ namespace PhpLlm\LlmChain\Bridge\OpenAI\Whisper; use PhpLlm\LlmChain\Bridge\OpenAI\Whisper; -use PhpLlm\LlmChain\Model\Message\Content\Audio; use PhpLlm\LlmChain\Model\Model; use PhpLlm\LlmChain\Platform\ModelClient as BaseModelClient; use Symfony\Contracts\HttpClient\HttpClientInterface; @@ -22,22 +21,17 @@ public function __construct( Assert::stringNotEmpty($apiKey, 'The API key must not be empty.'); } - public function supports(Model $model, object|array|string $input): bool + public function supports(Model $model): bool { - return $model instanceof Whisper && $input instanceof Audio; + return $model instanceof Whisper; } - public function request(Model $model, object|array|string $input, array $options = []): ResponseInterface + public function request(Model $model, array $payload, array $options = []): ResponseInterface { - assert($input instanceof Audio); - return $this->httpClient->request('POST', 'https://api.openai.com/v1/audio/transcriptions', [ 'auth_bearer' => $this->apiKey, 'headers' => ['Content-Type' => 'multipart/form-data'], - 'body' => array_merge($options, $model->getOptions(), [ - 'model' => $model->getName(), - 'file' => $input->asResource(), - ]), + 'body' => array_merge($options, $payload, ['model' => $model->getName()]), ]); } } diff --git a/src/Bridge/OpenAI/Whisper/ResponseConverter.php b/src/Bridge/OpenAI/Whisper/ResponseConverter.php index 7cfdf3c..3a3b5c7 100644 --- a/src/Bridge/OpenAI/Whisper/ResponseConverter.php +++ b/src/Bridge/OpenAI/Whisper/ResponseConverter.php @@ -5,7 +5,6 @@ namespace PhpLlm\LlmChain\Bridge\OpenAI\Whisper; use PhpLlm\LlmChain\Bridge\OpenAI\Whisper; -use PhpLlm\LlmChain\Model\Message\Content\Audio; use PhpLlm\LlmChain\Model\Model; use PhpLlm\LlmChain\Model\Response\ResponseInterface as LlmResponse; use PhpLlm\LlmChain\Model\Response\TextResponse; @@ -14,9 +13,9 @@ final class ResponseConverter implements BaseResponseConverter { - public function supports(Model $model, object|array|string $input): bool + public function supports(Model $model): bool { - return $model instanceof Whisper && $input instanceof Audio; + return $model instanceof Whisper; } public function convert(HttpResponse $response, array $options = []): LlmResponse diff --git a/src/Model/Message/AssistantMessage.php b/src/Model/Message/AssistantMessage.php index aabca47..472d660 100644 --- a/src/Model/Message/AssistantMessage.php +++ b/src/Model/Message/AssistantMessage.php @@ -37,7 +37,7 @@ public function hasToolCalls(): bool public function jsonSerialize(): array { $array = [ - 'role' => Role::Assistant, + 'role' => Role::Assistant->value, ]; if (null !== $this->content) { diff --git a/src/Model/Message/SystemMessage.php b/src/Model/Message/SystemMessage.php index b914c3b..31a6403 100644 --- a/src/Model/Message/SystemMessage.php +++ b/src/Model/Message/SystemMessage.php @@ -24,7 +24,7 @@ public function getRole(): Role public function jsonSerialize(): array { return [ - 'role' => Role::System, + 'role' => Role::System->value, 'content' => $this->content, ]; } diff --git a/src/Model/Message/ToolCallMessage.php b/src/Model/Message/ToolCallMessage.php index 20a9767..c3f9eb1 100644 --- a/src/Model/Message/ToolCallMessage.php +++ b/src/Model/Message/ToolCallMessage.php @@ -29,7 +29,7 @@ public function getRole(): Role public function jsonSerialize(): array { return [ - 'role' => Role::ToolCall, + 'role' => Role::ToolCall->value, 'content' => $this->content, 'tool_call_id' => $this->toolCall->id, ]; diff --git a/src/Model/Message/UserMessage.php b/src/Model/Message/UserMessage.php index ce6d2b2..79f1c06 100644 --- a/src/Model/Message/UserMessage.php +++ b/src/Model/Message/UserMessage.php @@ -58,7 +58,7 @@ public function hasImageContent(): bool */ public function jsonSerialize(): array { - $array = ['role' => Role::User]; + $array = ['role' => Role::User->value]; if (1 === count($this->content) && $this->content[0] instanceof Text) { $array['content'] = $this->content[0]->text; diff --git a/src/Platform.php b/src/Platform.php index 542f924..0572a88 100644 --- a/src/Platform.php +++ b/src/Platform.php @@ -8,6 +8,7 @@ use PhpLlm\LlmChain\Model\Model; use PhpLlm\LlmChain\Model\Response\AsyncResponse; use PhpLlm\LlmChain\Model\Response\ResponseInterface; +use PhpLlm\LlmChain\Platform\Contract; use PhpLlm\LlmChain\Platform\ModelClient; use PhpLlm\LlmChain\Platform\ResponseConverter; use Symfony\Contracts\HttpClient\ResponseInterface as HttpResponse; @@ -28,30 +29,34 @@ * @param iterable $modelClients * @param iterable $responseConverter */ - public function __construct(iterable $modelClients, iterable $responseConverter) - { + public function __construct( + iterable $modelClients, + iterable $responseConverter, + private Contract $contract = new Contract([]), + ) { $this->modelClients = $modelClients instanceof \Traversable ? iterator_to_array($modelClients) : $modelClients; $this->responseConverter = $responseConverter instanceof \Traversable ? iterator_to_array($responseConverter) : $responseConverter; } public function request(Model $model, array|string|object $input, array $options = []): ResponseInterface { + $payload = $this->contract->convertRequestPayload($input, $model); $options = array_merge($model->getOptions(), $options); - $response = $this->doRequest($model, $input, $options); + $response = $this->doRequest($model, $payload, $options); - return $this->convertResponse($model, $input, $response, $options); + return $this->convertResponse($model, $response, $options); } /** - * @param array|string|object $input - * @param array $options + * @param array $payload + * @param array $options */ - private function doRequest(Model $model, array|string|object $input, array $options = []): HttpResponse + private function doRequest(Model $model, array $payload, array $options = []): HttpResponse { foreach ($this->modelClients as $modelClient) { - if ($modelClient->supports($model, $input)) { - return $modelClient->request($model, $input, $options); + if ($modelClient->supports($model)) { + return $modelClient->request($model, $payload, $options); } } @@ -59,13 +64,12 @@ private function doRequest(Model $model, array|string|object $input, array $opti } /** - * @param array|string|object $input - * @param array $options + * @param array $options */ - private function convertResponse(Model $model, object|array|string $input, HttpResponse $response, array $options): ResponseInterface + private function convertResponse(Model $model, HttpResponse $response, array $options): ResponseInterface { foreach ($this->responseConverter as $responseConverter) { - if ($responseConverter->supports($model, $input)) { + if ($responseConverter->supports($model)) { return new AsyncResponse($responseConverter, $response, $options); } } diff --git a/src/Platform/Contract.php b/src/Platform/Contract.php new file mode 100644 index 0000000..261f371 --- /dev/null +++ b/src/Platform/Contract.php @@ -0,0 +1,41 @@ +normalizer->normalize($input, null, [ + self::MODEL => $model, + ]); + } + + public static function create(Extension ...$extensions): self + { + return new self( + new Serializer( + [new InputNormalizer($extensions), new JsonSerializableNormalizer(), new ObjectNormalizer()], + [new JsonEncoder()] + ) + ); + } +} diff --git a/src/Platform/Contract/Extension.php b/src/Platform/Contract/Extension.php new file mode 100644 index 0000000..a0e4dd4 --- /dev/null +++ b/src/Platform/Contract/Extension.php @@ -0,0 +1,22 @@ + + */ + public function registerTypes(): array; +} diff --git a/src/Platform/Contract/InputNormalizer.php b/src/Platform/Contract/InputNormalizer.php new file mode 100644 index 0000000..01ae3f7 --- /dev/null +++ b/src/Platform/Contract/InputNormalizer.php @@ -0,0 +1,70 @@ +extensions = $extensions instanceof \Traversable ? iterator_to_array($extensions) : $extensions; + } + + public function supportsNormalization(mixed $data, ?string $format = null, array $context = []): bool + { + if (!isset($context[Contract::MODEL]) || !$context[Contract::MODEL] instanceof Model) { + return false; + } + + try { + $this->getHandler($context[Contract::MODEL], $data); + } catch (RuntimeException) { + return false; + } + + return true; + } + + public function normalize(mixed $data, ?string $format = null, array $context = []): array|string|int|float|bool|\ArrayObject|null + { + return ($this->getHandler($context[Contract::MODEL], $data))($data); + } + + public function getSupportedTypes(?string $format): array + { + return [ + '*' => true, + ]; + } + + private function getHandler(Model $model, mixed $data): \Closure + { + foreach ($this->extensions as $extension) { + if ($extension->supports($model)) { + $types = $extension->registerTypes(); + foreach ($types as $type => $handler) { + if (is_subclass_of($data, $type) + || (is_string($data) && 'string' === $type) + || (is_int($data) && 'int' === $type) + || (is_float($data) && 'float' === $type) + || $data instanceof $type) { + return $extension->{$handler}(...); + } + } + } + } + + throw new RuntimeException('No handler found for the data.'); + } +} diff --git a/src/Platform/ModelClient.php b/src/Platform/ModelClient.php index 09881fc..b587e41 100644 --- a/src/Platform/ModelClient.php +++ b/src/Platform/ModelClient.php @@ -9,14 +9,11 @@ interface ModelClient { - /** - * @param array|string|object $input - */ - public function supports(Model $model, array|string|object $input): bool; + public function supports(Model $model): bool; /** - * @param array|string|object $input - * @param array $options + * @param array $payload + * @param array $options */ - public function request(Model $model, array|string|object $input, array $options = []): ResponseInterface; + public function request(Model $model, array $payload, array $options = []): ResponseInterface; } diff --git a/src/Platform/ResponseConverter.php b/src/Platform/ResponseConverter.php index a6a0326..78d62c5 100644 --- a/src/Platform/ResponseConverter.php +++ b/src/Platform/ResponseConverter.php @@ -10,10 +10,7 @@ interface ResponseConverter { - /** - * @param array|string|object $input - */ - public function supports(Model $model, array|string|object $input): bool; + public function supports(Model $model): bool; /** * @param array $options