Skip to content

Commit

Permalink
Add usageMetadata to GenerateContentResponse (#159)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored May 3, 2024
1 parent 16e68be commit f94bdba
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 27 deletions.
38 changes: 37 additions & 1 deletion Sources/GoogleAI/GenerateContentResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,28 @@ import Foundation
/// The model's response to a generate content request.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
public struct GenerateContentResponse {
/// Token usage metadata for processing the generate content request.
public struct UsageMetadata {
/// The number of tokens in the request prompt.
public let promptTokenCount: Int

/// The total number of tokens across the generated response candidates.
public let candidatesTokenCount: Int

/// The total number of tokens in both the request and response.
public let totalTokenCount: Int
}

/// A list of candidate response content, ordered from best to worst.
public let candidates: [CandidateResponse]

/// A value containing the safety ratings for the response, or, if the request was blocked, a
/// reason for blocking the request.
public let promptFeedback: PromptFeedback?

/// Token usage metadata for processing the generate content request.
public let usageMetadata: UsageMetadata?

/// The response's content as text, if it exists.
public var text: String? {
guard let candidate = candidates.first else {
Expand Down Expand Up @@ -51,9 +66,11 @@ public struct GenerateContentResponse {
}

/// Initializer for SwiftUI previews or tests.
public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback? = nil) {
public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback? = nil,
usageMetadata: UsageMetadata? = nil) {
self.candidates = candidates
self.promptFeedback = promptFeedback
self.usageMetadata = usageMetadata
}
}

Expand Down Expand Up @@ -170,6 +187,7 @@ extension GenerateContentResponse: Decodable {
enum CodingKeys: CodingKey {
case candidates
case promptFeedback
case usageMetadata
}

public init(from decoder: Decoder) throws {
Expand All @@ -194,6 +212,24 @@ extension GenerateContentResponse: Decodable {
candidates = []
}
promptFeedback = try container.decodeIfPresent(PromptFeedback.self, forKey: .promptFeedback)
usageMetadata = try container.decodeIfPresent(UsageMetadata.self, forKey: .usageMetadata)
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension GenerateContentResponse.UsageMetadata: Decodable {
enum CodingKeys: CodingKey {
case promptTokenCount
case candidatesTokenCount
case totalTokenCount
}

public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
promptTokenCount = try container.decodeIfPresent(Int.self, forKey: .promptTokenCount) ?? 0
candidatesTokenCount = try container
.decodeIfPresent(Int.self, forKey: .candidatesTokenCount) ?? 0
totalTokenCount = try container.decodeIfPresent(Int.self, forKey: .totalTokenCount) ?? 0
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
data: {"candidates": [{"content": {"parts": [{"text": "Cheyenne"}]},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"promptFeedback": {"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}}

data: {"candidates": [{"content": {"parts": [{"text": "Mountain View, California"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"candidatesTokenCount": 4}}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"content": {
"parts": [
{
"text": "Mountain View, California, United States"
"text": "Mountain View, California"
}
],
"role": "model"
Expand All @@ -31,24 +31,7 @@
]
}
],
"promptFeedback": {
"safetyRatings": [
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE"
}
]
"usageMetadata": {
"candidatesTokenCount": 4
}
}
49 changes: 45 additions & 4 deletions Tests/GoogleAITests/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,9 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(candidate.safetyRatings, safetyRatingsNegligible)
XCTAssertEqual(candidate.content.parts.count, 1)
let part = try XCTUnwrap(candidate.content.parts.first)
XCTAssertEqual(part.text, "Mountain View, California, United States")
XCTAssertEqual(part.text, "Mountain View, California")
XCTAssertEqual(response.text, part.text)
let promptFeedback = try XCTUnwrap(response.promptFeedback)
XCTAssertNil(promptFeedback.blockReason)
XCTAssertEqual(promptFeedback.safetyRatings, safetyRatingsNegligible)
XCTAssertNil(response.promptFeedback)
XCTAssertEqual(response.functionCalls, [])
}

Expand Down Expand Up @@ -256,6 +254,22 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(response.functionCalls, [functionCall])
}

func testGenerateContent_usageMetadata() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-success-basic-reply-short",
withExtension: "json"
)

let response = try await model.generateContent(testPrompt)

let usageMetadata = try XCTUnwrap(response.usageMetadata)
// TODO(andrewheard): Re-run prompt when `promptTokenCount` and `totalTokenCount` added.
XCTAssertEqual(usageMetadata.promptTokenCount, 0)
XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
XCTAssertEqual(usageMetadata.totalTokenCount, 0)
}

func testGenerateContent_failure_invalidAPIKey() async throws {
let expectedStatusCode = 400
MockURLProtocol
Expand Down Expand Up @@ -756,6 +770,33 @@ final class GenerativeModelTests: XCTestCase {
}))
}

func testGenerateContentStream_usageMetadata() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "streaming-success-basic-reply-short",
withExtension: "txt"
)
var responses = [GenerateContentResponse]()

let stream = model.generateContentStream(testPrompt)
for try await response in stream {
responses.append(response)
}

for (index, response) in responses.enumerated() {
if index == responses.endIndex - 1 {
let usageMetadata = try XCTUnwrap(response.usageMetadata)
// TODO(andrewheard): Re-run prompt when `promptTokenCount` and `totalTokenCount` added.
XCTAssertEqual(usageMetadata.promptTokenCount, 0)
XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
XCTAssertEqual(usageMetadata.totalTokenCount, 0)
} else {
// Only the last streamed response contains usage metadata
XCTAssertNil(response.usageMetadata)
}
}
}

func testGenerateContentStream_errorMidStream() async throws {
MockURLProtocol.requestHandler = try httpRequestHandler(
forResource: "streaming-failure-error-mid-stream",
Expand Down

0 comments on commit f94bdba

Please sign in to comment.