Skip to content

Commit

Permalink
VertexAI - Add logic for text input/output (#1189)
Browse files Browse the repository at this point in the history
* VertexAI - Add logic for text input/output

* Remove debug logs

* Update GenerativeModel.cs

* Address feedback
  • Loading branch information
a-maurice authored Feb 24, 2025
1 parent ba46a89 commit f5863f8
Show file tree
Hide file tree
Showing 6 changed files with 553 additions and 64 deletions.
49 changes: 48 additions & 1 deletion vertexai/src/Candidate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;

namespace Firebase.VertexAI {

Expand All @@ -32,11 +33,57 @@ public enum FinishReason {
MalformedFunctionCall,
}

/// <summary>
/// A struct representing a possible reply to a content generation prompt.
/// Each content generation prompt may produce multiple candidate responses.
/// </summary>
public readonly struct Candidate {
  // Backing store; may be null when the struct is default-constructed, so the
  // public property falls back to an empty collection.
  private readonly ReadOnlyCollection<SafetyRating> _safetyRatings;

  /// <summary>
  /// The response’s content.
  /// </summary>
  public ModelContent Content { get; }

  /// <summary>
  /// The safety rating of the response content.
  /// </summary>
  public IEnumerable<SafetyRating> SafetyRatings =>
      _safetyRatings ?? new ReadOnlyCollection<SafetyRating>(new List<SafetyRating>());

  /// <summary>
  /// The reason the model stopped generating content, if it exists;
  /// for example, if the model generated a predefined stop sequence.
  /// </summary>
  public FinishReason? FinishReason { get; }

  /// <summary>
  /// Cited works in the model’s response content, if it exists.
  /// </summary>
  public CitationMetadata? CitationMetadata { get; }

  // Hidden constructor, users don't need to make this, though they still technically can.
  internal Candidate(ModelContent content, List<SafetyRating> safetyRatings,
      FinishReason? finishReason, CitationMetadata? citationMetadata) {
    Content = content;
    _safetyRatings = new ReadOnlyCollection<SafetyRating>(safetyRatings ?? new List<SafetyRating>());
    FinishReason = finishReason;
    CitationMetadata = citationMetadata;
  }

  /// <summary>
  /// Deserializes a Candidate from the JSON dictionary returned by the backend.
  /// </summary>
  /// <exception cref="VertexAISerializationException">
  /// Thrown when 'content' is present but is not a JSON object.</exception>
  internal static Candidate FromJson(Dictionary<string, object> jsonDict) {
    ModelContent content = new();
    if (jsonDict.TryGetValue("content", out object contentObj)) {
      if (contentObj is not Dictionary<string, object> contentDict) {
        throw new VertexAISerializationException("Invalid JSON format: 'content' is not a dictionary.");
      }
      // We expect this to be another dictionary to convert
      content = ModelContent.FromJson(contentDict);
    }

    // TODO: Parse SafetyRatings, FinishReason, and CitationMetadata
    return new Candidate(content, null, null, null);
  }
}

}
79 changes: 73 additions & 6 deletions vertexai/src/GenerateContentResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,87 @@
*/

using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using Google.MiniJSON;

namespace Firebase.VertexAI {

/// <summary>
/// The model's response to a generate content request.
/// </summary>
public readonly struct GenerateContentResponse {
  // Backing store; may be null when the struct is default-constructed, so the
  // public property falls back to an empty collection.
  private readonly ReadOnlyCollection<Candidate> _candidates;

  /// <summary>
  /// A list of candidate response content, ordered from best to worst.
  /// </summary>
  public IEnumerable<Candidate> Candidates =>
      _candidates ?? new ReadOnlyCollection<Candidate>(new List<Candidate>());

  /// <summary>
  /// A value containing the safety ratings for the response, or,
  /// if the request was blocked, a reason for blocking the request.
  /// </summary>
  public PromptFeedback? PromptFeedback { get; }

  /// <summary>
  /// Token usage metadata for processing the generate content request.
  /// </summary>
  public UsageMetadata? UsageMetadata { get; }

  /// <summary>
  /// The response's content as text, if it exists.
  /// </summary>
  public string Text {
    get {
      // Concatenate all of the text parts from the first candidate.
      return string.Join(" ",
          Candidates.FirstOrDefault().Content.Parts
              .OfType<ModelContent.TextPart>().Select(tp => tp.Text));
    }
  }

  /// <summary>
  /// Returns function calls found in any `Part`s of the first candidate of the response, if any.
  /// </summary>
  public IEnumerable<ModelContent.FunctionCallPart> FunctionCalls {
    get {
      return Candidates.FirstOrDefault().Content.Parts.OfType<ModelContent.FunctionCallPart>();
    }
  }

  // Hidden constructor, users don't need to make this, though they still technically can.
  internal GenerateContentResponse(List<Candidate> candidates, PromptFeedback? promptFeedback,
      UsageMetadata? usageMetadata) {
    _candidates = new ReadOnlyCollection<Candidate>(candidates ?? new List<Candidate>());
    PromptFeedback = promptFeedback;
    UsageMetadata = usageMetadata;
  }

  /// <summary>
  /// Deserializes a response from the raw JSON string returned by the backend.
  /// </summary>
  internal static GenerateContentResponse FromJson(string jsonString) {
    return FromJson(Json.Deserialize(jsonString) as Dictionary<string, object>);
  }

  /// <summary>
  /// Deserializes a response from an already-parsed JSON dictionary.
  /// </summary>
  /// <exception cref="VertexAISerializationException">
  /// Thrown when 'candidates' is present but is not a JSON list.</exception>
  internal static GenerateContentResponse FromJson(Dictionary<string, object> jsonDict) {
    // Parse the Candidates
    List<Candidate> candidates = new();
    if (jsonDict.TryGetValue("candidates", out object candidatesObject)) {
      if (candidatesObject is not List<object> listOfCandidateObjects) {
        throw new VertexAISerializationException("Invalid JSON format: 'candidates' is not a list.");
      }

      // Entries that are not JSON objects are silently skipped.
      candidates = listOfCandidateObjects
          .Select(o => o as Dictionary<string, object>)
          .Where(dict => dict != null)
          .Select(Candidate.FromJson)
          .ToList();
    }

    // TODO: Parse PromptFeedback and UsageMetadata
    return new GenerateContentResponse(candidates, null, null);
  }
}

public enum BlockReason {
Expand Down
168 changes: 147 additions & 21 deletions vertexai/src/GenerativeModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,64 +14,190 @@
* limitations under the License.
*/

// For now, using this to hide some functions causing problems with the build.
#define HIDE_IASYNCENUMERABLE

using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Text;
using System.Threading.Tasks;
using Google.MiniJSON;

namespace Firebase.VertexAI {

/// <summary>
/// A type that represents a remote multimodal model (like Gemini), with the ability to generate
/// content based on various input types.
/// </summary>
public class GenerativeModel {
  private readonly FirebaseApp _firebaseApp;

  // Various setting fields provided by the user.
  private readonly string _location;
  private readonly string _modelName;
  private readonly GenerationConfig? _generationConfig;
  private readonly SafetySetting[] _safetySettings;
  private readonly Tool[] _tools;
  private readonly ToolConfig? _toolConfig;
  private readonly ModelContent? _systemInstruction;
  private readonly RequestOptions? _requestOptions;

  // Shared client for all requests made by this model instance.
  private readonly HttpClient _httpClient;

  // Note: No public constructor, get one through VertexAI.GetGenerativeModel
  internal GenerativeModel(FirebaseApp firebaseApp,
                           string location,
                           string modelName,
                           GenerationConfig? generationConfig = null,
                           SafetySetting[] safetySettings = null,
                           Tool[] tools = null,
                           ToolConfig? toolConfig = null,
                           ModelContent? systemInstruction = null,
                           RequestOptions? requestOptions = null) {
    _firebaseApp = firebaseApp;
    _location = location;
    _modelName = modelName;
    _generationConfig = generationConfig;
    _safetySettings = safetySettings;
    _tools = tools;
    _toolConfig = toolConfig;
    _systemInstruction = systemInstruction;
    _requestOptions = requestOptions;

    // Create a HttpClient using the timeout requested, or the default one.
    _httpClient = new HttpClient() {
      Timeout = requestOptions?.Timeout ?? RequestOptions.DefaultTimeout
    };
  }

  #region Public API
  /// <summary>
  /// Generates new content from input `ModelContent` given to the model as a prompt.
  /// </summary>
  /// <param name="content">The input(s) given to the model as a prompt.</param>
  /// <returns>The generated content response from the model.</returns>
  /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
  public Task<GenerateContentResponse> GenerateContentAsync(
      params ModelContent[] content) {
    return GenerateContentAsync((IEnumerable<ModelContent>)content);
  }

  /// <summary>
  /// Generates new content from input text given to the model as a prompt.
  /// </summary>
  /// <param name="text">The text given to the model as a prompt.</param>
  /// <returns>The generated content response from the model.</returns>
  /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
  public Task<GenerateContentResponse> GenerateContentAsync(
      string text) {
    return GenerateContentAsync(new ModelContent[] { ModelContent.Text(text) });
  }

  /// <summary>
  /// Generates new content from input `ModelContent` given to the model as a prompt.
  /// </summary>
  /// <param name="content">The input(s) given to the model as a prompt.</param>
  /// <returns>The generated content response from the model.</returns>
  /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
  public Task<GenerateContentResponse> GenerateContentAsync(
      IEnumerable<ModelContent> content) {
    return GenerateContentAsyncInternal(content);
  }

  // The build logic isn't able to resolve IAsyncEnumerable for some reason, even
  // though it is usable in Unity 2021.3. Will need to investigate further.
#if !HIDE_IASYNCENUMERABLE
  /// <summary>
  /// Generates a stream of content responses from input `ModelContent` given as a prompt.
  /// </summary>
  public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
      params ModelContent[] content) {
    return GenerateContentStreamAsync((IEnumerable<ModelContent>)content);
  }

  /// <summary>
  /// Generates a stream of content responses from input text given as a prompt.
  /// </summary>
  public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
      string text) {
    return GenerateContentStreamAsync(new ModelContent[] { ModelContent.Text(text) });
  }

  /// <summary>
  /// Generates a stream of content responses from input `ModelContent` given as a prompt.
  /// </summary>
  public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
      IEnumerable<ModelContent> content) {
    return GenerateContentStreamAsyncInternal(content);
  }
#endif

  /// <summary>
  /// Counts the tokens that the given `ModelContent` input(s) would use as a prompt.
  /// </summary>
  public Task<CountTokensResponse> CountTokensAsync(
      params ModelContent[] content) {
    return CountTokensAsync((IEnumerable<ModelContent>)content);
  }

  /// <summary>
  /// Counts the tokens that the given text would use as a prompt.
  /// </summary>
  public Task<CountTokensResponse> CountTokensAsync(
      string text) {
    return CountTokensAsync(new ModelContent[] { ModelContent.Text(text) });
  }

  /// <summary>
  /// Counts the tokens that the given `ModelContent` input(s) would use as a prompt.
  /// </summary>
  public Task<CountTokensResponse> CountTokensAsync(
      IEnumerable<ModelContent> content) {
    return CountTokensAsyncInternal(content);
  }

  /// <summary>
  /// Starts a new chat session, optionally seeded with the given history.
  /// </summary>
  public Chat StartChat(params ModelContent[] history) {
    return StartChat((IEnumerable<ModelContent>)history);
  }

  /// <summary>
  /// Starts a new chat session, optionally seeded with the given history.
  /// </summary>
  public Chat StartChat(IEnumerable<ModelContent> history) {
    // TODO: Implementation
    throw new NotImplementedException();
  }
  #endregion

  private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
      IEnumerable<ModelContent> content) {
    string bodyJson = ModelContentsToJson(content);

    // Dispose the request and response to release the underlying connection resources.
    using HttpRequestMessage request = new(HttpMethod.Post, GetURL() + ":generateContent");

    // Set the request headers
    request.Headers.Add("x-goog-api-key", _firebaseApp.Options.ApiKey);
    request.Headers.Add("x-goog-api-client", "genai-csharp/0.1.0");

    // Set the content
    request.Content = new StringContent(bodyJson, Encoding.UTF8, "application/json");

    using HttpResponseMessage response = await _httpClient.SendAsync(request);
    // TODO: Convert any timeout exception into a VertexAI equivalent
    // TODO: Convert any HttpRequestExceptions, see:
    // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpclient.sendasync?view=net-9.0
    // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpresponsemessage.ensuresuccessstatuscode?view=net-9.0
    response.EnsureSuccessStatusCode();

    string result = await response.Content.ReadAsStringAsync();

    return GenerateContentResponse.FromJson(result);
  }

#if !HIDE_IASYNCENUMERABLE
  private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsyncInternal(
      IEnumerable<ModelContent> content) {
    // TODO: Implementation
    await Task.CompletedTask;
    yield return new GenerateContentResponse();
    throw new NotImplementedException();
  }
#endif

  private async Task<CountTokensResponse> CountTokensAsyncInternal(
      IEnumerable<ModelContent> content) {
    // TODO: Implementation
    await Task.CompletedTask;
    throw new NotImplementedException();
  }

  // Builds the base URL for this model's REST endpoint.
  private string GetURL() {
    return "https://firebaseml.googleapis.com/v2beta" +
        "/projects/" + _firebaseApp.Options.ProjectId +
        "/locations/" + _location +
        "/publishers/google/models/" + _modelName;
  }

  // Serializes the given contents into the JSON request body.
  private string ModelContentsToJson(IEnumerable<ModelContent> contents) {
    Dictionary<string, object> jsonDict = new() {
      // Convert the Contents into a list of Json dictionaries
      ["contents"] = contents.Select(c => c.ToJson()).ToList()
    };
    // TODO: All the other settings

    return Json.Serialize(jsonDict);
  }
}

}
Loading

0 comments on commit f5863f8

Please sign in to comment.