From b62cac861c7d87d3834eff08521850b261ef94de Mon Sep 17 00:00:00 2001 From: David Gustys Date: Fri, 18 Oct 2024 10:50:34 +0300 Subject: [PATCH 1/2] feat: Add system promt support for beta --- src/Client.php | 18 ++++- .../Resources/GenerativeModelContract.php | 2 + .../GenerateContentRequest.php | 11 ++- src/Resources/GenerativeModel.php | 9 +++ .../Resources/GenerativeModelTestResource.php | 5 ++ tests/Resources/GenerativeModel.php | 67 +++++++++++++++++++ 6 files changed, 107 insertions(+), 5 deletions(-) diff --git a/src/Client.php b/src/Client.php index 45ac983..ab7d8b3 100644 --- a/src/Client.php +++ b/src/Client.php @@ -7,6 +7,8 @@ use Gemini\Contracts\ClientContract; use Gemini\Contracts\Resources\GenerativeModelContract; use Gemini\Contracts\TransporterContract; +use Gemini\Data\Content; +use Gemini\Data\GenerationConfig; use Gemini\Data\Model; use Gemini\Enums\ModelType; use Gemini\Resources\ChatSession; @@ -29,9 +31,19 @@ public function models(): Models return new Models(transporter: $this->transporter); } - public function generativeModel(ModelType|string $model): GenerativeModel - { - return new GenerativeModel(transporter: $this->transporter, model: $model); + public function generativeModel( + ModelType|string $model, + array $safetySettings = [], + ?GenerationConfig $generationConfig = null, + ?Content $systemInstruction = null + ): GenerativeModel { + return new GenerativeModel( + transporter: $this->transporter, + model: $model, + safetySettings: $safetySettings, + generationConfig: $generationConfig, + systemInstruction: $systemInstruction + ); } public function geminiPro(): GenerativeModel diff --git a/src/Contracts/Resources/GenerativeModelContract.php b/src/Contracts/Resources/GenerativeModelContract.php index 191a8de..d965b08 100644 --- a/src/Contracts/Resources/GenerativeModelContract.php +++ b/src/Contracts/Resources/GenerativeModelContract.php @@ -29,6 +29,8 @@ public function generateContent(string|Blob|array|Content ...$parts): GenerateCo */ public function streamGenerateContent(string|Blob|array|Content ...$parts): StreamResponse; + public function withSystemInstruction(Content $systemInstruction): self; + /** * @param array $history */ diff --git a/src/Requests/GenerativeModel/GenerateContentRequest.php b/src/Requests/GenerativeModel/GenerateContentRequest.php index 01d299c..0ec32d7 100644 --- a/src/Requests/GenerativeModel/GenerateContentRequest.php +++ b/src/Requests/GenerativeModel/GenerateContentRequest.php @@ -28,7 +28,8 @@ public function __construct( protected readonly string $model, protected readonly array $parts, protected readonly array $safetySettings = [], - protected readonly ?GenerationConfig $generationConfig = null + protected readonly ?GenerationConfig $generationConfig = null, + protected readonly ?Content $systemInstruction = null ) {} public function resolveEndpoint(): string @@ -43,7 +44,7 @@ public function resolveEndpoint(): string */ protected function defaultBody(): array { - return [ + $body = [ 'contents' => array_map( static fn (Content $content): array => $content->toArray(), $this->partsToContents(...$this->parts) @@ -54,5 +55,11 @@ protected function defaultBody(): array ), 'generationConfig' => $this->generationConfig?->toArray(), ]; + + if ($this->systemInstruction !== null) { + $body['system_instruction'] = $this->systemInstruction->toArray(); + } + + return $body; } } diff --git a/src/Resources/GenerativeModel.php b/src/Resources/GenerativeModel.php index b844314..c995465 100644 --- a/src/Resources/GenerativeModel.php +++ b/src/Resources/GenerativeModel.php @@ -34,6 +34,7 @@ public function __construct( ModelType|string $model, public array $safetySettings = [], public ?GenerationConfig $generationConfig = null, + public ?Content $systemInstruction = null, ) { $this->model = $this->parseModel(model: $model); } @@ -52,6 +53,13 @@ public function withGenerationConfig(GenerationConfig $generationConfig): self return $this; } + public function withSystemInstruction(Content $systemInstruction): self + { + $this->systemInstruction = $systemInstruction; + + return $this; + } + /** * Runs a model's tokenizer on input content and returns the token count. * @@ -83,6 +91,7 @@ public function generateContent(string|Blob|array|Content ...$parts): GenerateCo parts: $parts, safetySettings: $this->safetySettings, generationConfig: $this->generationConfig, + systemInstruction: $this->systemInstruction, ) ); diff --git a/src/Testing/Resources/GenerativeModelTestResource.php b/src/Testing/Resources/GenerativeModelTestResource.php index cc00a34..acfb6a5 100644 --- a/src/Testing/Resources/GenerativeModelTestResource.php +++ b/src/Testing/Resources/GenerativeModelTestResource.php @@ -42,4 +42,9 @@ public function startChat(array $history = []): ChatSession { return $this->record(method: __FUNCTION__, args: func_get_args(), model: $this->model); } + + public function withSystemInstruction(Content $systemInstruction): self + { + return $this->record(method: __FUNCTION__, args: func_get_args(), model: $this->model); + } } diff --git a/tests/Resources/GenerativeModel.php b/tests/Resources/GenerativeModel.php index 61e3b89..5d857f7 100644 --- a/tests/Resources/GenerativeModel.php +++ b/tests/Resources/GenerativeModel.php @@ -1,6 +1,8 @@ toBeInstanceOf(ChatSession::class); }); + +test('generative model with system instruction', function () { + $modelType = ModelType::GEMINI_PRO; + $systemInstruction = 'You are a helpful assistant.'; + $userMessage = 'Hello'; + + $mockTransporter = Mockery::mock(\Gemini\Contracts\TransporterContract::class); + $mockTransporter->shouldReceive('request') + ->once() + ->andReturnUsing(function ($request) use (&$capturedRequest) { + $capturedRequest = $request; + + return new ResponseDTO(GenerateContentResponse::fake()->toArray()); + }); + + $client = new Client($mockTransporter); + $model = $client->generativeModel(model: $modelType) + ->withSystemInstruction(Content::parse($systemInstruction)); + + $result = $model->generateContent($userMessage); + + expect($result)->toBeInstanceOf(GenerateContentResponse::class); + + expect($capturedRequest) + ->toBeInstanceOf(\Gemini\Requests\GenerativeModel\GenerateContentRequest::class) + ->and($capturedRequest->resolveEndpoint())->toBe("{$modelType->value}:generateContent"); + + $body = $capturedRequest->body(); + + expect($body) + ->toHaveKey('contents') + ->toHaveKey('system_instruction') + ->and($body['contents'][0]['parts'][0]['text'])->toBe($userMessage) + ->and($body['system_instruction']['parts'][0]['text'])->toBe($systemInstruction); + + expect($model) + ->toHaveProperty('systemInstruction') + ->and($model->systemInstruction)->toBeInstanceOf(Content::class) + ->and($model->systemInstruction->parts[0]->text)->toBe($systemInstruction); +}); + +test('system instruction is included in the request', function () { + $modelType = ModelType::GEMINI_PRO; + $systemInstruction = 'You are a helpful assistant.'; + + $mockTransporter = Mockery::mock(\Gemini\Contracts\TransporterContract::class); + $mockTransporter->shouldReceive('request') + ->once() + ->withArgs(function (\Gemini\Requests\GenerativeModel\GenerateContentRequest $request) use ($systemInstruction) { + $body = $request->body(); + + return $body['contents'][0]['parts'][0]['text'] === 'Hello' && + $body['system_instruction']['parts'][0]['text'] === $systemInstruction; + }) + ->andReturn(new ResponseDTO(GenerateContentResponse::fake()->toArray())); + + $client = new \Gemini\Client($mockTransporter); + + $parsedSystemInstruction = Content::parse($systemInstruction); + $generativeModel = $client->generativeModel(model: $modelType) + ->withSystemInstruction($parsedSystemInstruction); + + $generativeModel->generateContent('Hello'); +}); From 0ac39e593e0eda3b7234a3b5242bc2c66fa83cc1 Mon Sep 17 00:00:00 2001 From: Vytautas Smilingis Date: Tue, 29 Oct 2024 13:55:42 +0100 Subject: [PATCH 2/2] Adjusted PR according to comments --- src/Client.php | 19 +++---------------- .../GenerateContentRequest.php | 9 ++------- tests/Resources/GenerativeModel.php | 6 +++--- 3 files changed, 8 insertions(+), 26 deletions(-) diff --git a/src/Client.php b/src/Client.php index ab7d8b3..fc5cb3f 100644 --- a/src/Client.php +++ b/src/Client.php @@ -7,9 +7,6 @@ use Gemini\Contracts\ClientContract; use Gemini\Contracts\Resources\GenerativeModelContract; use Gemini\Contracts\TransporterContract; -use Gemini\Data\Content; -use Gemini\Data\GenerationConfig; -use Gemini\Data\Model; use Gemini\Enums\ModelType; use Gemini\Resources\ChatSession; use Gemini\Resources\EmbeddingModel; @@ -31,19 +28,9 @@ public function models(): Models return new Models(transporter: $this->transporter); } - public function generativeModel( - ModelType|string $model, - array $safetySettings = [], - ?GenerationConfig $generationConfig = null, - ?Content $systemInstruction = null - ): GenerativeModel { - return new GenerativeModel( - transporter: $this->transporter, - model: $model, - safetySettings: $safetySettings, - generationConfig: $generationConfig, - systemInstruction: $systemInstruction - ); + public function generativeModel(ModelType|string $model): GenerativeModel + { + return new GenerativeModel(transporter: $this->transporter, model: $model); } public function geminiPro(): GenerativeModel diff --git a/src/Requests/GenerativeModel/GenerateContentRequest.php b/src/Requests/GenerativeModel/GenerateContentRequest.php index 0ec32d7..4e35a87 100644 --- a/src/Requests/GenerativeModel/GenerateContentRequest.php +++ b/src/Requests/GenerativeModel/GenerateContentRequest.php @@ -44,7 +44,7 @@ public function resolveEndpoint(): string */ protected function defaultBody(): array { - $body = [ + return [ 'contents' => array_map( static fn (Content $content): array => $content->toArray(), $this->partsToContents(...$this->parts) @@ -54,12 +54,7 @@ protected function defaultBody(): array $this->safetySettings ?? [] ), 'generationConfig' => $this->generationConfig?->toArray(), + 'systemInstruction' => $this->systemInstruction?->toArray(), ]; - - if ($this->systemInstruction !== null) { - $body['system_instruction'] = $this->systemInstruction->toArray(); - } - - return $body; } } diff --git a/tests/Resources/GenerativeModel.php b/tests/Resources/GenerativeModel.php index 5d857f7..e1eb40b 100644 --- a/tests/Resources/GenerativeModel.php +++ b/tests/Resources/GenerativeModel.php @@ -212,9 +212,9 @@ expect($body) ->toHaveKey('contents') - ->toHaveKey('system_instruction') + ->toHaveKey('systemInstruction') ->and($body['contents'][0]['parts'][0]['text'])->toBe($userMessage) - ->and($body['system_instruction']['parts'][0]['text'])->toBe($systemInstruction); + ->and($body['systemInstruction']['parts'][0]['text'])->toBe($systemInstruction); expect($model) ->toHaveProperty('systemInstruction') @@ -233,7 +233,7 @@ $body = $request->body(); return $body['contents'][0]['parts'][0]['text'] === 'Hello' && - $body['system_instruction']['parts'][0]['text'] === $systemInstruction; + $body['systemInstruction']['parts'][0]['text'] === $systemInstruction; }) ->andReturn(new ResponseDTO(GenerateContentResponse::fake()->toArray()));