From f94bdba98ca4d73cd07dad157f95663755a01dfb Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Fri, 3 May 2024 10:36:45 -0400 Subject: [PATCH] Add `usageMetadata` to `GenerateContentResponse` (#159) --- .../GoogleAI/GenerateContentResponse.swift | 38 +++++++++++++- .../streaming-success-basic-reply-short.txt | 3 +- .../unary-success-basic-reply-short.json | 23 ++------- .../GoogleAITests/GenerativeModelTests.swift | 49 +++++++++++++++++-- 4 files changed, 86 insertions(+), 27 deletions(-) diff --git a/Sources/GoogleAI/GenerateContentResponse.swift b/Sources/GoogleAI/GenerateContentResponse.swift index 4b01522..683df0c 100644 --- a/Sources/GoogleAI/GenerateContentResponse.swift +++ b/Sources/GoogleAI/GenerateContentResponse.swift @@ -17,6 +17,18 @@ import Foundation /// The model's response to a generate content request. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) public struct GenerateContentResponse { + /// Token usage metadata for processing the generate content request. + public struct UsageMetadata { + /// The number of tokens in the request prompt. + public let promptTokenCount: Int + + /// The total number of tokens across the generated response candidates. + public let candidatesTokenCount: Int + + /// The total number of tokens in both the request and response. + public let totalTokenCount: Int + } + /// A list of candidate response content, ordered from best to worst. public let candidates: [CandidateResponse] @@ -24,6 +36,9 @@ public struct GenerateContentResponse { /// reason for blocking the request. public let promptFeedback: PromptFeedback? + /// Token usage metadata for processing the generate content request. + public let usageMetadata: UsageMetadata? + /// The response's content as text, if it exists. public var text: String? { guard let candidate = candidates.first else { @@ -51,9 +66,11 @@ public struct GenerateContentResponse { } /// Initializer for SwiftUI previews or tests. - public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback? = nil) { + public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback? = nil, + usageMetadata: UsageMetadata? = nil) { self.candidates = candidates self.promptFeedback = promptFeedback + self.usageMetadata = usageMetadata } } @@ -170,6 +187,7 @@ extension GenerateContentResponse: Decodable { enum CodingKeys: CodingKey { case candidates case promptFeedback + case usageMetadata } public init(from decoder: Decoder) throws { @@ -194,6 +212,24 @@ extension GenerateContentResponse: Decodable { candidates = [] } promptFeedback = try container.decodeIfPresent(PromptFeedback.self, forKey: .promptFeedback) + usageMetadata = try container.decodeIfPresent(UsageMetadata.self, forKey: .usageMetadata) + } +} + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +extension GenerateContentResponse.UsageMetadata: Decodable { + enum CodingKeys: CodingKey { + case promptTokenCount + case candidatesTokenCount + case totalTokenCount + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + promptTokenCount = try container.decodeIfPresent(Int.self, forKey: .promptTokenCount) ?? 0 + candidatesTokenCount = try container + .decodeIfPresent(Int.self, forKey: .candidatesTokenCount) ?? 0 + totalTokenCount = try container.decodeIfPresent(Int.self, forKey: .totalTokenCount) ?? 0 } } diff --git a/Tests/GoogleAITests/GenerateContentResponses/streaming-success-basic-reply-short.txt b/Tests/GoogleAITests/GenerateContentResponses/streaming-success-basic-reply-short.txt index a7f5476..0060677 100644 --- a/Tests/GoogleAITests/GenerateContentResponses/streaming-success-basic-reply-short.txt +++ b/Tests/GoogleAITests/GenerateContentResponses/streaming-success-basic-reply-short.txt @@ -1,2 +1 @@ -data: {"candidates": [{"content": {"parts": [{"text": "Cheyenne"}]},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"promptFeedback": {"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}} - +data: {"candidates": [{"content": {"parts": [{"text": "Mountain View, California"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"candidatesTokenCount": 4}} diff --git a/Tests/GoogleAITests/GenerateContentResponses/unary-success-basic-reply-short.json b/Tests/GoogleAITests/GenerateContentResponses/unary-success-basic-reply-short.json index 40a9a6d..dcfef87 100644 --- a/Tests/GoogleAITests/GenerateContentResponses/unary-success-basic-reply-short.json +++ b/Tests/GoogleAITests/GenerateContentResponses/unary-success-basic-reply-short.json @@ -4,7 +4,7 @@ "content": { "parts": [ { - "text": "Mountain View, California, United States" + "text": "Mountain View, California" } ], "role": "model" @@ -31,24 +31,7 @@ ] } ], - "promptFeedback": { - "safetyRatings": [ - { - "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", - "probability": "NEGLIGIBLE" - }, - { - "category": "HARM_CATEGORY_HATE_SPEECH", - "probability": "NEGLIGIBLE" - }, - { - "category": "HARM_CATEGORY_HARASSMENT", - "probability": "NEGLIGIBLE" - }, - { - "category": "HARM_CATEGORY_DANGEROUS_CONTENT", - "probability": "NEGLIGIBLE" - } - ] + "usageMetadata": { + "candidatesTokenCount": 4 } } diff --git a/Tests/GoogleAITests/GenerativeModelTests.swift b/Tests/GoogleAITests/GenerativeModelTests.swift index 1ce8816..9ed3401 100644 --- a/Tests/GoogleAITests/GenerativeModelTests.swift +++ b/Tests/GoogleAITests/GenerativeModelTests.swift @@ -82,11 +82,9 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(candidate.safetyRatings, safetyRatingsNegligible) XCTAssertEqual(candidate.content.parts.count, 1) let part = try XCTUnwrap(candidate.content.parts.first) - XCTAssertEqual(part.text, "Mountain View, California, United States") + XCTAssertEqual(part.text, "Mountain View, California") XCTAssertEqual(response.text, part.text) - let promptFeedback = try XCTUnwrap(response.promptFeedback) - XCTAssertNil(promptFeedback.blockReason) - XCTAssertEqual(promptFeedback.safetyRatings, safetyRatingsNegligible) + XCTAssertNil(response.promptFeedback) XCTAssertEqual(response.functionCalls, []) } @@ -256,6 +254,22 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(response.functionCalls, [functionCall]) } + func testGenerateContent_usageMetadata() async throws { + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "unary-success-basic-reply-short", + withExtension: "json" + ) + + let response = try await model.generateContent(testPrompt) + + let usageMetadata = try XCTUnwrap(response.usageMetadata) + // TODO(andrewheard): Re-run prompt when `promptTokenCount` and `totalTokenCount` added. + XCTAssertEqual(usageMetadata.promptTokenCount, 0) + XCTAssertEqual(usageMetadata.candidatesTokenCount, 4) + XCTAssertEqual(usageMetadata.totalTokenCount, 0) + } + func testGenerateContent_failure_invalidAPIKey() async throws { let expectedStatusCode = 400 MockURLProtocol @@ -756,6 +770,33 @@ final class GenerativeModelTests: XCTestCase { })) } + func testGenerateContentStream_usageMetadata() async throws { + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "streaming-success-basic-reply-short", + withExtension: "txt" + ) + var responses = [GenerateContentResponse]() + + let stream = model.generateContentStream(testPrompt) + for try await response in stream { + responses.append(response) + } + + for (index, response) in responses.enumerated() { + if index == responses.endIndex - 1 { + let usageMetadata = try XCTUnwrap(response.usageMetadata) + // TODO(andrewheard): Re-run prompt when `promptTokenCount` and `totalTokenCount` added. + XCTAssertEqual(usageMetadata.promptTokenCount, 0) + XCTAssertEqual(usageMetadata.candidatesTokenCount, 4) + XCTAssertEqual(usageMetadata.totalTokenCount, 0) + } else { + // Only the last streamed response contains usage metadata + XCTAssertNil(response.usageMetadata) + } + } + } + func testGenerateContentStream_errorMidStream() async throws { MockURLProtocol.requestHandler = try httpRequestHandler( forResource: "streaming-failure-error-mid-stream",