diff --git a/Sources/GoogleAI/GenerateContentResponse.swift b/Sources/GoogleAI/GenerateContentResponse.swift index 4b01522..683df0c 100644 --- a/Sources/GoogleAI/GenerateContentResponse.swift +++ b/Sources/GoogleAI/GenerateContentResponse.swift @@ -17,6 +17,18 @@ import Foundation /// The model's response to a generate content request. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) public struct GenerateContentResponse { + /// Token usage metadata for processing the generate content request. + public struct UsageMetadata { + /// The number of tokens in the request prompt. + public let promptTokenCount: Int + + /// The total number of tokens across the generated response candidates. + public let candidatesTokenCount: Int + + /// The total number of tokens in both the request and response. + public let totalTokenCount: Int + } + /// A list of candidate response content, ordered from best to worst. public let candidates: [CandidateResponse] @@ -24,6 +36,9 @@ public struct GenerateContentResponse { /// reason for blocking the request. public let promptFeedback: PromptFeedback? + /// Token usage metadata for processing the generate content request. + public let usageMetadata: UsageMetadata? + /// The response's content as text, if it exists. public var text: String? { guard let candidate = candidates.first else { @@ -51,9 +66,11 @@ public struct GenerateContentResponse { } /// Initializer for SwiftUI previews or tests. - public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback? = nil) { + public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback? = nil, + usageMetadata: UsageMetadata? = nil) { self.candidates = candidates self.promptFeedback = promptFeedback + self.usageMetadata = usageMetadata } } @@ -170,6 +187,7 @@ extension GenerateContentResponse: Decodable { enum CodingKeys: CodingKey { case candidates case promptFeedback + case usageMetadata } public init(from decoder: Decoder) throws { @@ -194,6 +212,24 @@ extension GenerateContentResponse: Decodable { candidates = [] } promptFeedback = try container.decodeIfPresent(PromptFeedback.self, forKey: .promptFeedback) + usageMetadata = try container.decodeIfPresent(UsageMetadata.self, forKey: .usageMetadata) + } +} + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +extension GenerateContentResponse.UsageMetadata: Decodable { + enum CodingKeys: CodingKey { + case promptTokenCount + case candidatesTokenCount + case totalTokenCount + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + promptTokenCount = try container.decodeIfPresent(Int.self, forKey: .promptTokenCount) ?? 0 + candidatesTokenCount = try container + .decodeIfPresent(Int.self, forKey: .candidatesTokenCount) ?? 0 + totalTokenCount = try container.decodeIfPresent(Int.self, forKey: .totalTokenCount) ?? 0 } } diff --git a/Tests/GoogleAITests/GenerativeModelTests.swift b/Tests/GoogleAITests/GenerativeModelTests.swift index 1ce8816..9ed3401 100644 --- a/Tests/GoogleAITests/GenerativeModelTests.swift +++ b/Tests/GoogleAITests/GenerativeModelTests.swift @@ -82,11 +82,9 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(candidate.safetyRatings, safetyRatingsNegligible) XCTAssertEqual(candidate.content.parts.count, 1) let part = try XCTUnwrap(candidate.content.parts.first) - XCTAssertEqual(part.text, "Mountain View, California, United States") + XCTAssertEqual(part.text, "Mountain View, California") XCTAssertEqual(response.text, part.text) - let promptFeedback = try XCTUnwrap(response.promptFeedback) - XCTAssertNil(promptFeedback.blockReason) - XCTAssertEqual(promptFeedback.safetyRatings, safetyRatingsNegligible) + XCTAssertNil(response.promptFeedback) XCTAssertEqual(response.functionCalls, []) } @@ -256,6 +254,22 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(response.functionCalls, [functionCall]) } + func testGenerateContent_usageMetadata() async throws { + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "unary-success-basic-reply-short", + withExtension: "json" + ) + + let response = try await model.generateContent(testPrompt) + + let usageMetadata = try XCTUnwrap(response.usageMetadata) + // TODO(andrewheard): Re-run prompt when `promptTokenCount` and `totalTokenCount` added. + XCTAssertEqual(usageMetadata.promptTokenCount, 0) + XCTAssertEqual(usageMetadata.candidatesTokenCount, 4) + XCTAssertEqual(usageMetadata.totalTokenCount, 0) + } + func testGenerateContent_failure_invalidAPIKey() async throws { let expectedStatusCode = 400 MockURLProtocol @@ -756,6 +770,33 @@ final class GenerativeModelTests: XCTestCase { })) } + func testGenerateContentStream_usageMetadata() async throws { + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "streaming-success-basic-reply-short", + withExtension: "txt" + ) + var responses = [GenerateContentResponse]() + + let stream = model.generateContentStream(testPrompt) + for try await response in stream { + responses.append(response) + } + + for (index, response) in responses.enumerated() { + if index == responses.endIndex - 1 { + let usageMetadata = try XCTUnwrap(response.usageMetadata) + // TODO(andrewheard): Re-run prompt when `promptTokenCount` and `totalTokenCount` added. + XCTAssertEqual(usageMetadata.promptTokenCount, 0) + XCTAssertEqual(usageMetadata.candidatesTokenCount, 4) + XCTAssertEqual(usageMetadata.totalTokenCount, 0) + } else { + // Only the last streamed response contains usage metadata + XCTAssertNil(response.usageMetadata) + } + } + } + func testGenerateContentStream_errorMidStream() async throws { MockURLProtocol.requestHandler = try httpRequestHandler( forResource: "streaming-failure-error-mid-stream",