diff --git a/src/Contracts/Resources/GenerativeModelContract.php b/src/Contracts/Resources/GenerativeModelContract.php index 191a8de..dd97e5d 100644 --- a/src/Contracts/Resources/GenerativeModelContract.php +++ b/src/Contracts/Resources/GenerativeModelContract.php @@ -6,6 +6,8 @@ use Gemini\Data\Blob; use Gemini\Data\Content; +use Gemini\Data\GenerationConfig; +use Gemini\Data\SafetySetting; use Gemini\Resources\ChatSession; use Gemini\Responses\GenerativeModel\CountTokensResponse; use Gemini\Responses\GenerativeModel\GenerateContentResponse; @@ -33,4 +35,8 @@ public function streamGenerateContent(string|Blob|array|Content ...$parts): Stre * @param array $history */ public function startChat(array $history = []): ChatSession; + + public function withSafetySetting(SafetySetting $safetySetting): self; + + public function withGenerationConfig(GenerationConfig $generationConfig): self; } diff --git a/src/Testing/ClientFake.php b/src/Testing/ClientFake.php index c491906..a4b914f 100644 --- a/src/Testing/ClientFake.php +++ b/src/Testing/ClientFake.php @@ -8,6 +8,7 @@ use Gemini\Contracts\ResponseContract; use Gemini\Enums\ModelType; use Gemini\Responses\StreamResponse; +use Gemini\Testing\FunctionCalls\TestFunctionCall; use Gemini\Testing\Requests\TestRequest; use Gemini\Testing\Resources\ChatSessionTestResource; use Gemini\Testing\Resources\EmbeddingModelTestResource; @@ -23,6 +24,11 @@ class ClientFake implements ClientContract */ private array $requests = []; + /** + * @var array + */ + private array $functionCalls = []; + /** * @param array $responses */ @@ -125,6 +131,80 @@ public function record(TestRequest $request): ResponseContract|StreamResponse return $response; } + public function assertFunctionCalled(string $resource, ModelType|string|null $model = null, callable|int|null $callback = null): void + { + if (is_int($callback)) { + $this->assertFunctionCalledTimes(resource: $resource, model: $model, times: $callback); + + return; + } + + PHPUnit::assertTrue( + $this->functionCalled(resource: $resource, model: $model, callback: $callback) !== [], + "The expected [{$resource}] function was not called." + ); + } + + private function assertFunctionCalledTimes(string $resource, ModelType|string|null $model = null, int $times = 1): void + { + $count = count($this->functionCalled(resource: $resource, model: $model)); + + PHPUnit::assertSame( + $times, $count, + "The expected [{$resource}] resource was called {$count} times instead of {$times} times." + ); + } + + /** + * @return mixed[] + */ + private function functionCalled(string $resource, ModelType|string|null $model = null, ?callable $callback = null): array + { + if (! $this->hasFunctionCalled(resource: $resource, model: $model)) { + return []; + } + + $callback = $callback ?: fn (): bool => true; + + return array_filter($this->resourcesOfFunctionCalls(type: $resource), fn (TestFunctionCall $functionCall) => $callback($functionCall->method(), $functionCall->args())); + } + + private function hasFunctionCalled(string $resource, ModelType|string|null $model = null): bool + { + return $this->resourcesOfFunctionCalls(type: $resource, model: $model) !== []; + } + + public function assertFunctionNotCalled(string $resource, ModelType|string|null $model = null, ?callable $callback = null): void + { + PHPUnit::assertCount( + 0, $this->functionCalled(resource: $resource, model: $model, callback: $callback), + "The unexpected [{$resource}] function was called." + ); + } + + public function assertNoFunctionsCalled(): void + { + $resourceNames = implode( + separator: ', ', + array: array_map(fn (TestFunctionCall $functionCall): string => $functionCall->resource(), $this->functionCalls) + ); + + PHPUnit::assertEmpty($this->functionCalls, 'The following functions were called unexpectedly: '.$resourceNames); + } + + /** + * @return array + */ + private function resourcesOfFunctionCalls(string $type, ModelType|string|null $model = null): array + { + return array_filter($this->functionCalls, fn (TestFunctionCall $functionCall): bool => $functionCall->resource() === $type && ($model === null || $functionCall->model() === $model)); + } + + public function recordFunctionCall(TestFunctionCall $call): void + { + $this->functionCalls[] = $call; + } + public function models(): ModelTestResource { return new ModelTestResource(fake: $this); diff --git a/src/Testing/FunctionCalls/TestFunctionCall.php b/src/Testing/FunctionCalls/TestFunctionCall.php new file mode 100644 index 0000000..44ef1c4 --- /dev/null +++ b/src/Testing/FunctionCalls/TestFunctionCall.php @@ -0,0 +1,38 @@ + $args + */ + public function __construct(protected string $resource, protected string $method, protected array $args, protected ModelType|string|null $model = null) {} + + public function resource(): string + { + return $this->resource; + } + + public function method(): string + { + return $this->method; + } + + /** + * @return array + */ + public function args(): array + { + return $this->args; + } + + public function model(): ModelType|string|null + { + return $this->model; + } +} diff --git a/src/Testing/Resources/Concerns/Testable.php b/src/Testing/Resources/Concerns/Testable.php index 89e9701..b6912c2 100644 --- a/src/Testing/Resources/Concerns/Testable.php +++ b/src/Testing/Resources/Concerns/Testable.php @@ -8,6 +8,7 @@ use Gemini\Enums\ModelType; use Gemini\Responses\StreamResponse; use Gemini\Testing\ClientFake; +use Gemini\Testing\FunctionCalls\TestFunctionCall; use Gemini\Testing\Requests\TestRequest; trait Testable @@ -30,4 +31,19 @@ public function assertNotSent(callable|int|null $callback = null): void { $this->fake->assertNotSent(resource: $this->resource(), model: $this->model, callback: $callback); } + + public function recordFunctionCall(string $method, array $args = [], ModelType|string|null $model = null): void + { + $this->fake->recordFunctionCall(new TestFunctionCall(resource: $this->resource(), method: $method, args: $args, model: $model)); + } + + public function assertFunctionCalled(callable|int|null $callback = null): void + { + $this->fake->assertFunctionCalled(resource: $this->resource(), model: $this->model, callback: $callback); + } + + public function assertFunctionNotCalled(callable|int|null $callback = null): void + { + $this->fake->assertFunctionNotCalled(resource: $this->resource(), model: $this->model, callback: $callback); + } } diff --git a/src/Testing/Resources/GenerativeModelTestResource.php b/src/Testing/Resources/GenerativeModelTestResource.php index cc00a34..7f92723 100644 --- a/src/Testing/Resources/GenerativeModelTestResource.php +++ b/src/Testing/Resources/GenerativeModelTestResource.php @@ -7,6 +7,8 @@ use Gemini\Contracts\Resources\GenerativeModelContract; use Gemini\Data\Blob; use Gemini\Data\Content; +use Gemini\Data\GenerationConfig; +use Gemini\Data\SafetySetting; use Gemini\Resources\ChatSession; use Gemini\Resources\GenerativeModel; use Gemini\Responses\GenerativeModel\CountTokensResponse; @@ -42,4 +44,18 @@ public function startChat(array $history = []): ChatSession { return $this->record(method: __FUNCTION__, args: func_get_args(), model: $this->model); } + + public function withSafetySetting(SafetySetting $safetySetting): self + { + $this->recordFunctionCall(method: __FUNCTION__, args: func_get_args(), model: $this->model); + + return $this; + } + + public function withGenerationConfig(GenerationConfig $generationConfig): self + { + $this->recordFunctionCall(method: __FUNCTION__, args: func_get_args(), model: $this->model); + + return $this; + } } diff --git a/tests/Testing/Resources/GenerativeModelTestResource.php b/tests/Testing/Resources/GenerativeModelTestResource.php index d84dfc6..a50383b 100644 --- a/tests/Testing/Resources/GenerativeModelTestResource.php +++ b/tests/Testing/Resources/GenerativeModelTestResource.php @@ -2,6 +2,10 @@ declare(strict_types=1); +use Gemini\Data\GenerationConfig; +use Gemini\Data\SafetySetting; +use Gemini\Enums\HarmBlockThreshold; +use Gemini\Enums\HarmCategory; use Gemini\Responses\GenerativeModel\CountTokensResponse; use Gemini\Responses\GenerativeModel\GenerateContentResponse; use Gemini\Testing\ClientFake; @@ -44,3 +48,48 @@ $parameters[0] === 'Hello'; }); }); + +it('records a "withSafetySetting" function call', function () { + $fake = new ClientFake; + + $safetySetting = new SafetySetting(HarmCategory::HARM_CATEGORY_DANGEROUS, HarmBlockThreshold::BLOCK_ONLY_HIGH); + + $fake->geminiPro()->withSafetySetting($safetySetting); + + $fake->geminiPro()->assertFunctionCalled(function (string $method, array $parameters) use ($safetySetting) { + return $method === 'withSafetySetting' && + $parameters[0] === $safetySetting; + }); +}); + +it('records a "withGenerationConfig" function call', function () { + $fake = new ClientFake; + + $generationConfig = new GenerationConfig; + + $fake->geminiPro()->withGenerationConfig($generationConfig); + + $fake->geminiPro()->assertFunctionCalled(function (string $method, array $parameters) use ($generationConfig) { + return $method === 'withGenerationConfig' && + $parameters[0] === $generationConfig; + }); +}); + +it('records both content request and function call', function () { + $fake = new ClientFake([ + GenerateContentResponse::fake(), + ]); + + $generationConfig = new GenerationConfig; + + $fake->geminiPro()->withGenerationConfig($generationConfig)->generateContent('Hello'); + + $fake->geminiPro()->assertSent(function (string $method, array $parameters) { + return $method === 'generateContent' && + $parameters[0] === 'Hello'; + }); + $fake->geminiPro()->assertFunctionCalled(function (string $method, array $parameters) use ($generationConfig) { + return $method === 'withGenerationConfig' && + $parameters[0] === $generationConfig; + }); +});