Skip to content

Commit

Permalink
Merge pull request #57 from whatagraph/main
Browse files Browse the repository at this point in the history
Align `GenerativeModelTestResource` with `GenerativeModel`
  • Loading branch information
aydinfatih authored Dec 29, 2024
2 parents 771bf85 + 16c3244 commit 6dca9a5
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/Contracts/Resources/GenerativeModelContract.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

use Gemini\Data\Blob;
use Gemini\Data\Content;
use Gemini\Data\GenerationConfig;
use Gemini\Data\SafetySetting;
use Gemini\Resources\ChatSession;
use Gemini\Responses\GenerativeModel\CountTokensResponse;
use Gemini\Responses\GenerativeModel\GenerateContentResponse;
Expand Down Expand Up @@ -33,4 +35,8 @@ public function streamGenerateContent(string|Blob|array|Content ...$parts): Stre
* @param array<Content> $history
*/
public function startChat(array $history = []): ChatSession;

public function withSafetySetting(SafetySetting $safetySetting): self;

public function withGenerationConfig(GenerationConfig $generationConfig): self;
}
80 changes: 80 additions & 0 deletions src/Testing/ClientFake.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
use Gemini\Contracts\ResponseContract;
use Gemini\Enums\ModelType;
use Gemini\Responses\StreamResponse;
use Gemini\Testing\FunctionCalls\TestFunctionCall;
use Gemini\Testing\Requests\TestRequest;
use Gemini\Testing\Resources\ChatSessionTestResource;
use Gemini\Testing\Resources\EmbeddingModelTestResource;
Expand All @@ -23,6 +24,11 @@ class ClientFake implements ClientContract
*/
private array $requests = [];

/**
* @var array<array-key, TestFunctionCall>
*/
private array $functionCalls = [];

/**
* @param array<array-key, ResponseContract> $responses
*/
Expand Down Expand Up @@ -125,6 +131,80 @@ public function record(TestRequest $request): ResponseContract|StreamResponse
return $response;
}

public function assertFunctionCalled(string $resource, ModelType|string|null $model = null, callable|int|null $callback = null): void
{
if (is_int($callback)) {
$this->assertFunctionCalledTimes(resource: $resource, model: $model, times: $callback);

return;
}

PHPUnit::assertTrue(
$this->functionCalled(resource: $resource, model: $model, callback: $callback) !== [],
"The expected [{$resource}] function was not called."
);
}

private function assertFunctionCalledTimes(string $resource, ModelType|string|null $model = null, int $times = 1): void
{
$count = count($this->functionCalled(resource: $resource, model: $model));

PHPUnit::assertSame(
$times, $count,
"The expected [{$resource}] resource was called {$count} times instead of {$times} times."
);
}

/**
* @return mixed[]
*/
private function functionCalled(string $resource, ModelType|string|null $model = null, ?callable $callback = null): array
{
if (! $this->hasFunctionCalled(resource: $resource, model: $model)) {
return [];
}

$callback = $callback ?: fn (): bool => true;

return array_filter($this->resourcesOfFunctionCalls(type: $resource), fn (TestFunctionCall $functionCall) => $callback($functionCall->method(), $functionCall->args()));
}

private function hasFunctionCalled(string $resource, ModelType|string|null $model = null): bool
{
return $this->resourcesOfFunctionCalls(type: $resource, model: $model) !== [];
}

public function assertFunctionNotCalled(string $resource, ModelType|string|null $model = null, ?callable $callback = null): void
{
PHPUnit::assertCount(
0, $this->functionCalled(resource: $resource, model: $model, callback: $callback),
"The unexpected [{$resource}] function was called."
);
}

public function assertNoFunctionsCalled(): void
{
$resourceNames = implode(
separator: ', ',
array: array_map(fn (TestFunctionCall $functionCall): string => $functionCall->resource(), $this->functionCalls)
);

PHPUnit::assertEmpty($this->functionCalls, 'The following functions were called unexpectedly: '.$resourceNames);
}

/**
* @return array<array-key, TestFunctionCall>
*/
private function resourcesOfFunctionCalls(string $type, ModelType|string|null $model = null): array
{
return array_filter($this->functionCalls, fn (TestFunctionCall $functionCall): bool => $functionCall->resource() === $type && ($model === null || $functionCall->model() === $model));
}

public function recordFunctionCall(TestFunctionCall $call): void
{
$this->functionCalls[] = $call;
}

public function models(): ModelTestResource
{
return new ModelTestResource(fake: $this);
Expand Down
38 changes: 38 additions & 0 deletions src/Testing/FunctionCalls/TestFunctionCall.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
<?php

declare(strict_types=1);

namespace Gemini\Testing\FunctionCalls;

use Gemini\Enums\ModelType;

final class TestFunctionCall
{
/**
* @param array<string, mixed> $args
*/
public function __construct(protected string $resource, protected string $method, protected array $args, protected ModelType|string|null $model = null) {}

public function resource(): string
{
return $this->resource;
}

public function method(): string
{
return $this->method;
}

/**
* @return array<string, mixed>
*/
public function args(): array
{
return $this->args;
}

public function model(): ModelType|string|null
{
return $this->model;
}
}
16 changes: 16 additions & 0 deletions src/Testing/Resources/Concerns/Testable.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
use Gemini\Enums\ModelType;
use Gemini\Responses\StreamResponse;
use Gemini\Testing\ClientFake;
use Gemini\Testing\FunctionCalls\TestFunctionCall;
use Gemini\Testing\Requests\TestRequest;

trait Testable
Expand All @@ -30,4 +31,19 @@ public function assertNotSent(callable|int|null $callback = null): void
{
$this->fake->assertNotSent(resource: $this->resource(), model: $this->model, callback: $callback);
}

public function recordFunctionCall(string $method, array $args = [], ModelType|string|null $model = null): void
{
$this->fake->recordFunctionCall(new TestFunctionCall(resource: $this->resource(), method: $method, args: $args, model: $model));
}

public function assertFunctionCalled(callable|int|null $callback = null): void
{
$this->fake->assertFunctionCalled(resource: $this->resource(), model: $this->model, callback: $callback);
}

public function assertFunctionNotCalled(callable|int|null $callback = null): void
{
$this->fake->assertFunctionNotCalled(resource: $this->resource(), model: $this->model, callback: $callback);
}
}
16 changes: 16 additions & 0 deletions src/Testing/Resources/GenerativeModelTestResource.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
use Gemini\Contracts\Resources\GenerativeModelContract;
use Gemini\Data\Blob;
use Gemini\Data\Content;
use Gemini\Data\GenerationConfig;
use Gemini\Data\SafetySetting;
use Gemini\Resources\ChatSession;
use Gemini\Resources\GenerativeModel;
use Gemini\Responses\GenerativeModel\CountTokensResponse;
Expand Down Expand Up @@ -42,4 +44,18 @@ public function startChat(array $history = []): ChatSession
{
return $this->record(method: __FUNCTION__, args: func_get_args(), model: $this->model);
}

public function withSafetySetting(SafetySetting $safetySetting): self
{
$this->recordFunctionCall(method: __FUNCTION__, args: func_get_args(), model: $this->model);

return $this;
}

public function withGenerationConfig(GenerationConfig $generationConfig): self
{
$this->recordFunctionCall(method: __FUNCTION__, args: func_get_args(), model: $this->model);

return $this;
}
}
49 changes: 49 additions & 0 deletions tests/Testing/Resources/GenerativeModelTestResource.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

declare(strict_types=1);

use Gemini\Data\GenerationConfig;
use Gemini\Data\SafetySetting;
use Gemini\Enums\HarmBlockThreshold;
use Gemini\Enums\HarmCategory;
use Gemini\Responses\GenerativeModel\CountTokensResponse;
use Gemini\Responses\GenerativeModel\GenerateContentResponse;
use Gemini\Testing\ClientFake;
Expand Down Expand Up @@ -44,3 +48,48 @@
$parameters[0] === 'Hello';
});
});

it('records a "withSafetySetting" function call', function () {
$fake = new ClientFake;

$safetySetting = new SafetySetting(HarmCategory::HARM_CATEGORY_DANGEROUS, HarmBlockThreshold::BLOCK_ONLY_HIGH);

$fake->geminiPro()->withSafetySetting($safetySetting);

$fake->geminiPro()->assertFunctionCalled(function (string $method, array $parameters) use ($safetySetting) {
return $method === 'withSafetySetting' &&
$parameters[0] === $safetySetting;
});
});

it('records a "withGenerationConfig" function call', function () {
$fake = new ClientFake;

$generationConfig = new GenerationConfig;

$fake->geminiPro()->withGenerationConfig($generationConfig);

$fake->geminiPro()->assertFunctionCalled(function (string $method, array $parameters) use ($generationConfig) {
return $method === 'withGenerationConfig' &&
$parameters[0] === $generationConfig;
});
});

it('records both content request and function call', function () {
$fake = new ClientFake([
GenerateContentResponse::fake(),
]);

$generationConfig = new GenerationConfig;

$fake->geminiPro()->withGenerationConfig($generationConfig)->generateContent('Hello');

$fake->geminiPro()->assertSent(function (string $method, array $parameters) {
return $method === 'generateContent' &&
$parameters[0] === 'Hello';
});
$fake->geminiPro()->assertFunctionCalled(function (string $method, array $parameters) use ($generationConfig) {
return $method === 'withGenerationConfig' &&
$parameters[0] === $generationConfig;
});
});

0 comments on commit 6dca9a5

Please sign in to comment.