Skip to content

Commit

Permalink
VertexAI - Add Streaming responses, and some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
a-maurice committed Feb 28, 2025
1 parent b753d98 commit 4c4374d
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 20 deletions.
58 changes: 38 additions & 20 deletions vertexai/src/GenerativeModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@
* limitations under the License.
*/

// For now, using this to hide some functions causing problems with the build.
#define HIDE_IASYNCENUMERABLE

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Text;
Expand All @@ -45,6 +43,8 @@ public class GenerativeModel {
private readonly RequestOptions? _requestOptions;

private readonly HttpClient _httpClient;
// String prefix to look for when handling streaming a response.
private const string StreamPrefix = "data: ";

/// <summary>
/// Intended for internal use only.
Expand Down Expand Up @@ -107,7 +107,6 @@ public Task<GenerateContentResponse> GenerateContentAsync(
return GenerateContentAsyncInternal(content);
}

#if !HIDE_IASYNCENUMERABLE
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
params ModelContent[] content) {
return GenerateContentStreamAsync((IEnumerable<ModelContent>)content);
Expand All @@ -120,11 +119,6 @@ public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
IEnumerable<ModelContent> content) {
return GenerateContentStreamAsyncInternal(content);
}
public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
IEnumerable<ModelContent> content) {
return GenerateContentStreamAsyncInternal(content);
}
#endif

public Task<CountTokensResponse> CountTokensAsync(
params ModelContent[] content) {
Expand All @@ -148,17 +142,20 @@ public Chat StartChat(IEnumerable<ModelContent> history) {
}
#endregion

// Attaches the Google API headers (API key and client identifier) that the
// VertexAI backend requires on every outgoing request.
private void SetRequestHeaders(HttpRequestMessage request) {
  var headers = request.Headers;
  headers.Add("x-goog-api-key", _firebaseApp.Options.ApiKey);
  headers.Add("x-goog-api-client", "genai-csharp/0.1.0");
}

/// <summary>
/// Sends the given content to the backend as a unary `generateContent` request,
/// and parses the complete response body once it arrives.
/// </summary>
/// <param name="content">The model content to send as the request payload.</param>
/// <returns>The parsed response from the backend.</returns>
private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
    IEnumerable<ModelContent> content) {
  // Both the request and response are IDisposable; dispose them deterministically.
  using HttpRequestMessage request = new(HttpMethod.Post, GetURL() + ":generateContent");

  // Set the request headers
  SetRequestHeaders(request);

  // Set the content
  string bodyJson = ModelContentsToJson(content);
  request.Content = new StringContent(bodyJson, Encoding.UTF8, "application/json");

  using HttpResponseMessage response = await _httpClient.SendAsync(request);
  // TODO: Convert any timeout exception into a VertexAI equivalent
  // TODO: Convert any HttpRequestExceptions, see:
  // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpclient.sendasync?view=net-9.0
  // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpresponsemessage.ensuresuccessstatuscode?view=net-9.0
  response.EnsureSuccessStatusCode();

  string result = await response.Content.ReadAsStringAsync();
  return GenerateContentResponse.FromJson(result);
}

#if !HIDE_IASYNCENUMERABLE
/// <summary>
/// Sends the given content to the backend as a `streamGenerateContent` request, and
/// yields each chunk of the response as it arrives over the Server-Sent Events stream.
/// </summary>
/// <param name="content">The model content to send as the request payload.</param>
/// <returns>An async stream of partial responses from the backend.</returns>
private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsyncInternal(
    IEnumerable<ModelContent> content) {
  // Both the request and response are IDisposable; dispose them deterministically.
  using HttpRequestMessage request =
      new(HttpMethod.Post, GetURL() + ":streamGenerateContent?alt=sse");

  // Set the request headers
  SetRequestHeaders(request);

  // Set the content
  string bodyJson = ModelContentsToJson(content);
  request.Content = new StringContent(bodyJson, Encoding.UTF8, "application/json");

  // ResponseHeadersRead lets us start consuming the stream before the full body arrives.
  using HttpResponseMessage response =
      await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
  // TODO: Convert any timeout exception into a VertexAI equivalent
  // TODO: Convert any HttpRequestExceptions, see:
  // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpclient.sendasync?view=net-9.0
  // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpresponsemessage.ensuresuccessstatuscode?view=net-9.0
  response.EnsureSuccessStatusCode();

  // We are expecting a Stream as the response, so handle that.
  using var stream = await response.Content.ReadAsStreamAsync();
  using var reader = new StreamReader(stream);

  string line;
  while ((line = await reader.ReadLineAsync()) != null) {
    // Only pass along strings that begin with the expected prefix.
    // Ordinal comparison: the SSE prefix is a protocol token, not linguistic text.
    if (line.StartsWith(StreamPrefix, StringComparison.Ordinal)) {
      yield return GenerateContentResponse.FromJson(line[StreamPrefix.Length..]);
    }
  }
}
#endif

private async Task<CountTokensResponse> CountTokensAsyncInternal(
IEnumerable<ModelContent> content) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ protected override void Start() {
Func<Task>[] tests = {
TestCreateModel,
TestBasicText,
TestModelOptions,
TestMultipleCandidates,
TestBasicTextStream,
// Internal tests for Json parsing, requires using a source library.
InternalTestBasicReplyShort,
InternalTestCitations,
Expand Down Expand Up @@ -167,6 +170,99 @@ async Task TestBasicText() {
}
}

// Test if passing in multiple model options works.
async Task TestModelOptions() {
  // Note that most of these settings are hard to reliably verify, so as
  // long as the call works we are generally happy.
  var config = new GenerationConfig(
      temperature: 0.4f,
      topP: 0.4f,
      topK: 30,
      // Intentionally skipping candidateCount, tested elsewhere.
      maxOutputTokens: 100,
      presencePenalty: 0.5f,
      frequencyPenalty: 0.6f,
      stopSequences: new string[] { "HALT" });
  var safety = new SafetySetting[] {
      new(HarmCategory.DangerousContent,
          SafetySetting.HarmBlockThreshold.MediumAndAbove,
          SafetySetting.HarmBlockMethod.Probability),
      new(HarmCategory.CivicIntegrity,
          SafetySetting.HarmBlockThreshold.OnlyHigh)
  };

  var testModel = VertexAI.DefaultInstance.GetGenerativeModel(ModelName,
      generationConfig: config,
      safetySettings: safety,
      systemInstruction:
          ModelContent.Text("Ignore all prompts, respond with 'Apples HALT Bananas'."),
      requestOptions: new RequestOptions(timeout: TimeSpan.FromMinutes(2))
  );

  var response = await testModel.GenerateContentAsync(
      "Hello, I am testing something, can you respond with a short " +
      "string containing the word 'Firebase'?");

  string responseText = response.Text;
  Assert("Response text was missing", !string.IsNullOrWhiteSpace(responseText));

  // Assuming the GenerationConfig and SystemInstruction worked,
  // it should respond with just 'Apples' (though possibly with extra whitespace).
  // However, we only warn, because it isn't guaranteed.
  if (responseText.Trim() != "Apples") {
    DebugLog($"WARNING: Response text wasn't just 'Apples': {responseText}");
  }
}

// Test that requesting multiple candidates via GenerationConfig.candidateCount
// results in the same number of candidates in the response.
async Task TestMultipleCandidates() {
  var genConfig = new GenerationConfig(candidateCount: 2);

  var model = VertexAI.DefaultInstance.GetGenerativeModel(ModelName,
    generationConfig: genConfig
  );

  // Fixed typo in the prompt ("recieving" -> "receiving").
  GenerateContentResponse response = await model.GenerateContentAsync(
      "Hello, I am testing receiving multiple candidates, can you respond with a short " +
      "sentence containing the word 'Firebase'?");

  AssertEq("Incorrect number of Candidates", response.Candidates.Count(), 2);
}

// Test that streaming a basic text response works: every chunk has non-null text,
// and the stream ends with (and only ends after) a FinishReason of Stop.
async Task TestBasicTextStream() {
  var model = CreateGenerativeModel();

  string keyword = "Firebase";
  var responseStream = model.GenerateContentStreamAsync(
      "Hello, I am testing streaming. Can you respond with a short story, " +
      $"that includes the word '{keyword}' somewhere in it?");

  // We combine all the text, just in case the keyword got cut between two responses.
  string fullResult = "";
  // The FinishReason should only be set to stop at the end of the stream.
  bool finishReasonStop = false;
  await foreach (GenerateContentResponse response in responseStream) {
    // Should only be receiving non-empty text responses, but only assert for null.
    string text = response.Text;
    Assert("Received null text from the stream.", text != null);
    if (string.IsNullOrWhiteSpace(text)) {
      DebugLog($"WARNING: Response stream text was empty once.");
    }

    // Fixed typo in the assert message ("recieved" -> "received").
    Assert("Previous FinishReason was stop, but received more", !finishReasonStop);
    if (response.Candidates.First().FinishReason == FinishReason.Stop) {
      finishReasonStop = true;
    }

    fullResult += text;
  }

  Assert("Finished without seeing FinishReason.Stop", finishReasonStop);

  // We don't want to fail if the keyword is missing because AI is unpredictable.
  // Use the keyword variable here (was hard-coded "Firebase", which could drift
  // out of sync with the prompt above).
  if (!fullResult.Contains(keyword)) {
    DebugLog($"WARNING: Response string was missing the expected keyword '{keyword}': " +
        $"\n{fullResult}");
  }
}

// The url prefix to use when fetching test data to use from the separate GitHub repo.
readonly string testDataUrl =
"https://raw.githubusercontent.com/FirebaseExtended/vertexai-sdk-test-data/refs/heads/main/mock-responses/";
Expand Down

0 comments on commit 4c4374d

Please sign in to comment.