Skip to content

Commit

Permalink
Merge pull request #45 from whatagraph/beta
Browse files Browse the repository at this point in the history
feat: Add system promt support for beta
  • Loading branch information
aydinfatih authored Oct 29, 2024
2 parents b35d97f + 0ac39e5 commit 0d023f0
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 2 deletions.
1 change: 0 additions & 1 deletion src/Client.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
use Gemini\Contracts\ClientContract;
use Gemini\Contracts\Resources\GenerativeModelContract;
use Gemini\Contracts\TransporterContract;
use Gemini\Data\Model;
use Gemini\Enums\ModelType;
use Gemini\Resources\ChatSession;
use Gemini\Resources\EmbeddingModel;
Expand Down
2 changes: 2 additions & 0 deletions src/Contracts/Resources/GenerativeModelContract.php
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ public function generateContent(string|Blob|array|Content ...$parts): GenerateCo
*/
public function streamGenerateContent(string|Blob|array|Content ...$parts): StreamResponse;

public function withSystemInstruction(Content $systemInstruction): self;

/**
* @param array<Content> $history
*/
Expand Down
4 changes: 3 additions & 1 deletion src/Requests/GenerativeModel/GenerateContentRequest.php
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ public function __construct(
protected readonly string $model,
protected readonly array $parts,
protected readonly array $safetySettings = [],
protected readonly ?GenerationConfig $generationConfig = null
protected readonly ?GenerationConfig $generationConfig = null,
protected readonly ?Content $systemInstruction = null
) {}

public function resolveEndpoint(): string
Expand All @@ -53,6 +54,7 @@ protected function defaultBody(): array
$this->safetySettings ?? []
),
'generationConfig' => $this->generationConfig?->toArray(),
'systemInstruction' => $this->systemInstruction?->toArray(),
];
}
}
9 changes: 9 additions & 0 deletions src/Resources/GenerativeModel.php
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public function __construct(
ModelType|string $model,
public array $safetySettings = [],
public ?GenerationConfig $generationConfig = null,
public ?Content $systemInstruction = null,
) {
$this->model = $this->parseModel(model: $model);
}
Expand All @@ -52,6 +53,13 @@ public function withGenerationConfig(GenerationConfig $generationConfig): self
return $this;
}

public function withSystemInstruction(Content $systemInstruction): self
{
$this->systemInstruction = $systemInstruction;

return $this;
}

/**
* Runs a model's tokenizer on input content and returns the token count.
*
Expand Down Expand Up @@ -83,6 +91,7 @@ public function generateContent(string|Blob|array|Content ...$parts): GenerateCo
parts: $parts,
safetySettings: $this->safetySettings,
generationConfig: $this->generationConfig,
systemInstruction: $this->systemInstruction,
)
);

Expand Down
5 changes: 5 additions & 0 deletions src/Testing/Resources/GenerativeModelTestResource.php
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,9 @@ public function startChat(array $history = []): ChatSession
{
return $this->record(method: __FUNCTION__, args: func_get_args(), model: $this->model);
}

public function withSystemInstruction(Content $systemInstruction): self
{
return $this->record(method: __FUNCTION__, args: func_get_args(), model: $this->model);
}
}
67 changes: 67 additions & 0 deletions tests/Resources/GenerativeModel.php
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
<?php

use Gemini\Client;
use Gemini\Data\Candidate;
use Gemini\Data\Content;
use Gemini\Data\GenerationConfig;
use Gemini\Data\PromptFeedback;
use Gemini\Data\SafetySetting;
Expand All @@ -13,6 +15,7 @@
use Gemini\Responses\GenerativeModel\CountTokensResponse;
use Gemini\Responses\GenerativeModel\GenerateContentResponse;
use Gemini\Responses\StreamResponse;
use Gemini\Transporters\DTOs\ResponseDTO;
use GuzzleHttp\Psr7\Response;
use GuzzleHttp\Psr7\Stream;

Expand Down Expand Up @@ -178,3 +181,67 @@
expect($result)
->toBeInstanceOf(ChatSession::class);
});

test('generative model with system instruction', function () {
$modelType = ModelType::GEMINI_PRO;
$systemInstruction = 'You are a helpful assistant.';
$userMessage = 'Hello';

$mockTransporter = Mockery::mock(\Gemini\Contracts\TransporterContract::class);
$mockTransporter->shouldReceive('request')
->once()
->andReturnUsing(function ($request) use (&$capturedRequest) {
$capturedRequest = $request;

return new ResponseDTO(GenerateContentResponse::fake()->toArray());
});

$client = new Client($mockTransporter);
$model = $client->generativeModel(model: $modelType)
->withSystemInstruction(Content::parse($systemInstruction));

$result = $model->generateContent($userMessage);

expect($result)->toBeInstanceOf(GenerateContentResponse::class);

expect($capturedRequest)
->toBeInstanceOf(\Gemini\Requests\GenerativeModel\GenerateContentRequest::class)
->and($capturedRequest->resolveEndpoint())->toBe("{$modelType->value}:generateContent");

$body = $capturedRequest->body();

expect($body)
->toHaveKey('contents')
->toHaveKey('systemInstruction')
->and($body['contents'][0]['parts'][0]['text'])->toBe($userMessage)
->and($body['systemInstruction']['parts'][0]['text'])->toBe($systemInstruction);

expect($model)
->toHaveProperty('systemInstruction')
->and($model->systemInstruction)->toBeInstanceOf(Content::class)
->and($model->systemInstruction->parts[0]->text)->toBe($systemInstruction);
});

test('system instruction is included in the request', function () {
$modelType = ModelType::GEMINI_PRO;
$systemInstruction = 'You are a helpful assistant.';

$mockTransporter = Mockery::mock(\Gemini\Contracts\TransporterContract::class);
$mockTransporter->shouldReceive('request')
->once()
->withArgs(function (\Gemini\Requests\GenerativeModel\GenerateContentRequest $request) use ($systemInstruction) {
$body = $request->body();

return $body['contents'][0]['parts'][0]['text'] === 'Hello' &&
$body['systemInstruction']['parts'][0]['text'] === $systemInstruction;
})
->andReturn(new ResponseDTO(GenerateContentResponse::fake()->toArray()));

$client = new \Gemini\Client($mockTransporter);

$parsedSystemInstruction = Content::parse($systemInstruction);
$generativeModel = $client->generativeModel(model: $modelType)
->withSystemInstruction($parsedSystemInstruction);

$generativeModel->generateContent('Hello');
});

0 comments on commit 0d023f0

Please sign in to comment.