Skip to content

refactor: add extension point to payload contract handling #301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/Bridge/OpenAI/DallE/ContractExtension.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
<?php

declare(strict_types=1);

namespace PhpLlm\LlmChain\Bridge\OpenAI\DallE;

use PhpLlm\LlmChain\Bridge\OpenAI\DallE;
use PhpLlm\LlmChain\Model\Model;
use PhpLlm\LlmChain\Platform\Contract\Extension;

final class ContractExtension implements Extension
{
public function supports(Model $model): bool
{
return $model instanceof DallE;
}

public function registerTypes(): array
{
return [
'string' => 'handleInput',
];
}
Comment on lines +18 to +23
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

couldn't we do this using an attribute like the event listeners do?


public function handleInput(string $input): array
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
public function handleInput(string $input): array
#[AsExtension('string')]
public function handleInput(string $input): array

idea

{
return [
'prompt' => $input,
];
}
}
7 changes: 3 additions & 4 deletions src/Bridge/OpenAI/DallE/ModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,17 @@ public function __construct(
Assert::startsWith($apiKey, 'sk-', 'The API key must start with "sk-".');
}

public function supports(Model $model, array|string|object $input): bool
public function supports(Model $model): bool
{
return $model instanceof DallE;
}

public function request(Model $model, object|array|string $input, array $options = []): HttpResponse
public function request(Model $model, array $payload, array $options = []): HttpResponse
{
return $this->httpClient->request('POST', 'https://api.openai.com/v1/images/generations', [
'auth_bearer' => $this->apiKey,
'json' => \array_merge($options, [
'json' => \array_merge($options, $payload, [
'model' => $model->getName(),
'prompt' => $input,
]),
]);
}
Expand Down
31 changes: 31 additions & 0 deletions src/Bridge/OpenAI/Embeddings/ContractExtension.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
<?php

declare(strict_types=1);

namespace PhpLlm\LlmChain\Bridge\OpenAI\Embeddings;

use PhpLlm\LlmChain\Bridge\OpenAI\Embeddings;
use PhpLlm\LlmChain\Model\Model;
use PhpLlm\LlmChain\Platform\Contract\Extension;

final class ContractExtension implements Extension
{
public function supports(Model $model): bool
{
return $model instanceof Embeddings;
}

public function registerTypes(): array
{
return [
'string' => 'handleInput',
];
}

public function handleInput(string $input): array
{
return [
'input' => $input,
];
}
}
7 changes: 3 additions & 4 deletions src/Bridge/OpenAI/Embeddings/ModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,17 @@ public function __construct(
Assert::startsWith($apiKey, 'sk-', 'The API key must start with "sk-".');
}

public function supports(Model $model, array|string|object $input): bool
public function supports(Model $model): bool
{
return $model instanceof Embeddings;
}

public function request(Model $model, object|array|string $input, array $options = []): ResponseInterface
public function request(Model $model, array $payload, array $options = []): ResponseInterface
{
return $this->httpClient->request('POST', 'https://api.openai.com/v1/embeddings', [
'auth_bearer' => $this->apiKey,
'json' => array_merge($model->getOptions(), $options, [
'json' => array_merge($options, $payload, [
'model' => $model->getName(),
'input' => $input,
]),
]);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Bridge/OpenAI/Embeddings/ResponseConverter.php
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

final class ResponseConverter implements PlatformResponseConverter
{
public function supports(Model $model, array|string|object $input): bool
public function supports(Model $model): bool
{
return $model instanceof Embeddings;
}
Expand Down
9 changes: 5 additions & 4 deletions src/Bridge/OpenAI/GPT/ModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,26 @@

public function __construct(
HttpClientInterface $httpClient,
#[\SensitiveParameter] private string $apiKey,
#[\SensitiveParameter]
private string $apiKey,
) {
$this->httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient);
Assert::stringNotEmpty($apiKey, 'The API key must not be empty.');
Assert::startsWith($apiKey, 'sk-', 'The API key must start with "sk-".');
}

public function supports(Model $model, array|string|object $input): bool
public function supports(Model $model): bool
{
return $model instanceof GPT;
}

public function request(Model $model, object|array|string $input, array $options = []): ResponseInterface
public function request(Model $model, array $payload, array $options = []): ResponseInterface
{
return $this->httpClient->request('POST', 'https://api.openai.com/v1/chat/completions', [
'auth_bearer' => $this->apiKey,
'json' => array_merge($options, [
'model' => $model->getName(),
'messages' => $input,
'messages' => $payload,
]),
]);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Bridge/OpenAI/GPT/ResponseConverter.php
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

final class ResponseConverter implements PlatformResponseConverter
{
public function supports(Model $model, array|string|object $input): bool
public function supports(Model $model): bool
{
return $model instanceof GPT;
}
Expand Down
9 changes: 9 additions & 0 deletions src/Bridge/OpenAI/PlatformFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@

namespace PhpLlm\LlmChain\Bridge\OpenAI;

use PhpLlm\LlmChain\Bridge\OpenAI\DallE\ContractExtension as DallEContractExtension;
use PhpLlm\LlmChain\Bridge\OpenAI\DallE\ModelClient as DallEModelClient;
use PhpLlm\LlmChain\Bridge\OpenAI\Embeddings\ContractExtension as EmbeddingsContractExtension;
use PhpLlm\LlmChain\Bridge\OpenAI\Embeddings\ModelClient as EmbeddingsModelClient;
use PhpLlm\LlmChain\Bridge\OpenAI\Embeddings\ResponseConverter as EmbeddingsResponseConverter;
use PhpLlm\LlmChain\Bridge\OpenAI\GPT\ModelClient as GPTModelClient;
use PhpLlm\LlmChain\Bridge\OpenAI\GPT\ResponseConverter as GPTResponseConverter;
use PhpLlm\LlmChain\Bridge\OpenAI\Whisper\ContractExtension as WhisperContractExtension;
use PhpLlm\LlmChain\Bridge\OpenAI\Whisper\ModelClient as WhisperModelClient;
use PhpLlm\LlmChain\Bridge\OpenAI\Whisper\ResponseConverter as WhisperResponseConverter;
use PhpLlm\LlmChain\Platform;
use PhpLlm\LlmChain\Platform\Contract;
use Symfony\Component\HttpClient\EventSourceHttpClient;
use Symfony\Contracts\HttpClient\HttpClientInterface;

Expand Down Expand Up @@ -39,6 +43,11 @@ public static function create(
$dallEModelClient,
new WhisperResponseConverter(),
],
Contract::create(
new DallEContractExtension(),
new EmbeddingsContractExtension(),
new WhisperContractExtension(),
),
);
}
}
32 changes: 32 additions & 0 deletions src/Bridge/OpenAI/Whisper/ContractExtension.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
<?php

declare(strict_types=1);

namespace PhpLlm\LlmChain\Bridge\OpenAI\Whisper;

use PhpLlm\LlmChain\Bridge\OpenAI\Whisper;
use PhpLlm\LlmChain\Model\Message\Content\Audio;
use PhpLlm\LlmChain\Model\Model;
use PhpLlm\LlmChain\Platform\Contract\Extension;

final class ContractExtension implements Extension
{
public function supports(Model $model): bool
{
return $model instanceof Whisper;
}

public function registerTypes(): array
{
return [
Audio::class => 'handleAudioInput',
];
}

public function handleAudioInput(Audio $audio): array
{
return [
'file' => $audio->asResource(),
];
}
}
14 changes: 4 additions & 10 deletions src/Bridge/OpenAI/Whisper/ModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
namespace PhpLlm\LlmChain\Bridge\OpenAI\Whisper;

use PhpLlm\LlmChain\Bridge\OpenAI\Whisper;
use PhpLlm\LlmChain\Model\Message\Content\Audio;
use PhpLlm\LlmChain\Model\Model;
use PhpLlm\LlmChain\Platform\ModelClient as BaseModelClient;
use Symfony\Contracts\HttpClient\HttpClientInterface;
Expand All @@ -22,22 +21,17 @@ public function __construct(
Assert::stringNotEmpty($apiKey, 'The API key must not be empty.');
}

public function supports(Model $model, object|array|string $input): bool
public function supports(Model $model): bool
{
return $model instanceof Whisper && $input instanceof Audio;
return $model instanceof Whisper;
}

public function request(Model $model, object|array|string $input, array $options = []): ResponseInterface
public function request(Model $model, array $payload, array $options = []): ResponseInterface
{
assert($input instanceof Audio);

return $this->httpClient->request('POST', 'https://api.openai.com/v1/audio/transcriptions', [
'auth_bearer' => $this->apiKey,
'headers' => ['Content-Type' => 'multipart/form-data'],
'body' => array_merge($options, $model->getOptions(), [
'model' => $model->getName(),
'file' => $input->asResource(),
]),
'body' => array_merge($options, $payload, ['model' => $model->getName()]),
]);
}
}
5 changes: 2 additions & 3 deletions src/Bridge/OpenAI/Whisper/ResponseConverter.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
namespace PhpLlm\LlmChain\Bridge\OpenAI\Whisper;

use PhpLlm\LlmChain\Bridge\OpenAI\Whisper;
use PhpLlm\LlmChain\Model\Message\Content\Audio;
use PhpLlm\LlmChain\Model\Model;
use PhpLlm\LlmChain\Model\Response\ResponseInterface as LlmResponse;
use PhpLlm\LlmChain\Model\Response\TextResponse;
Expand All @@ -14,9 +13,9 @@

final class ResponseConverter implements BaseResponseConverter
{
public function supports(Model $model, object|array|string $input): bool
public function supports(Model $model): bool
{
return $model instanceof Whisper && $input instanceof Audio;
return $model instanceof Whisper;
}

public function convert(HttpResponse $response, array $options = []): LlmResponse
Expand Down
2 changes: 1 addition & 1 deletion src/Model/Message/AssistantMessage.php
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public function hasToolCalls(): bool
public function jsonSerialize(): array
{
$array = [
'role' => Role::Assistant,
'role' => Role::Assistant->value,
];

if (null !== $this->content) {
Expand Down
2 changes: 1 addition & 1 deletion src/Model/Message/SystemMessage.php
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public function getRole(): Role
public function jsonSerialize(): array
{
return [
'role' => Role::System,
'role' => Role::System->value,
'content' => $this->content,
];
}
Expand Down
2 changes: 1 addition & 1 deletion src/Model/Message/ToolCallMessage.php
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public function getRole(): Role
public function jsonSerialize(): array
{
return [
'role' => Role::ToolCall,
'role' => Role::ToolCall->value,
'content' => $this->content,
'tool_call_id' => $this->toolCall->id,
];
Expand Down
2 changes: 1 addition & 1 deletion src/Model/Message/UserMessage.php
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public function hasImageContent(): bool
*/
public function jsonSerialize(): array
{
$array = ['role' => Role::User];
$array = ['role' => Role::User->value];
if (1 === count($this->content) && $this->content[0] instanceof Text) {
$array['content'] = $this->content[0]->text;

Expand Down
30 changes: 17 additions & 13 deletions src/Platform.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
use PhpLlm\LlmChain\Model\Model;
use PhpLlm\LlmChain\Model\Response\AsyncResponse;
use PhpLlm\LlmChain\Model\Response\ResponseInterface;
use PhpLlm\LlmChain\Platform\Contract;
use PhpLlm\LlmChain\Platform\ModelClient;
use PhpLlm\LlmChain\Platform\ResponseConverter;
use Symfony\Contracts\HttpClient\ResponseInterface as HttpResponse;
Expand All @@ -28,44 +29,47 @@
* @param iterable<ModelClient> $modelClients
* @param iterable<ResponseConverter> $responseConverter
*/
public function __construct(iterable $modelClients, iterable $responseConverter)
{
public function __construct(
iterable $modelClients,
iterable $responseConverter,
private Contract $contract = new Contract([]),
) {
$this->modelClients = $modelClients instanceof \Traversable ? iterator_to_array($modelClients) : $modelClients;
$this->responseConverter = $responseConverter instanceof \Traversable ? iterator_to_array($responseConverter) : $responseConverter;
}

public function request(Model $model, array|string|object $input, array $options = []): ResponseInterface
{
$payload = $this->contract->convertRequestPayload($input, $model);
$options = array_merge($model->getOptions(), $options);

$response = $this->doRequest($model, $input, $options);
$response = $this->doRequest($model, $payload, $options);

return $this->convertResponse($model, $input, $response, $options);
return $this->convertResponse($model, $response, $options);
}

/**
* @param array<mixed>|string|object $input
* @param array<string, mixed> $options
* @param array<string, mixed> $payload
* @param array<string, mixed> $options
*/
private function doRequest(Model $model, array|string|object $input, array $options = []): HttpResponse
private function doRequest(Model $model, array $payload, array $options = []): HttpResponse
{
foreach ($this->modelClients as $modelClient) {
if ($modelClient->supports($model, $input)) {
return $modelClient->request($model, $input, $options);
if ($modelClient->supports($model)) {
return $modelClient->request($model, $payload, $options);
}
}

throw new RuntimeException('No response factory registered for model "'.$model::class.'" with given input.');
}

/**
* @param array<mixed>|string|object $input
* @param array<string, mixed> $options
* @param array<string, mixed> $options
*/
private function convertResponse(Model $model, object|array|string $input, HttpResponse $response, array $options): ResponseInterface
private function convertResponse(Model $model, HttpResponse $response, array $options): ResponseInterface
{
foreach ($this->responseConverter as $responseConverter) {
if ($responseConverter->supports($model, $input)) {
if ($responseConverter->supports($model)) {
return new AsyncResponse($responseConverter, $response, $options);
}
}
Expand Down
Loading
Loading