Skip to content

Commit

Permalink
Add usageMetadata to GenerateContentResponse (google-gemini#159)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored and G.Dev.Ssomsak committed Jun 21, 2024
1 parent 1401265 commit 49adf92
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 5 deletions.
38 changes: 37 additions & 1 deletion Sources/GoogleAI/GenerateContentResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,28 @@ import Foundation
/// The model's response to a generate content request.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
public struct GenerateContentResponse {
/// Token usage metadata for processing the generate content request.
public struct UsageMetadata {
/// The number of tokens in the request prompt.
public let promptTokenCount: Int

/// The total number of tokens across the generated response candidates.
public let candidatesTokenCount: Int

/// The total number of tokens in both the request and response.
public let totalTokenCount: Int
}

/// A list of candidate response content, ordered from best to worst.
public let candidates: [CandidateResponse]

/// A value containing the safety ratings for the response, or, if the request was blocked, a
/// reason for blocking the request.
public let promptFeedback: PromptFeedback?

/// Token usage metadata for processing the generate content request.
public let usageMetadata: UsageMetadata?

/// The response's content as text, if it exists.
public var text: String? {
guard let candidate = candidates.first else {
Expand Down Expand Up @@ -51,9 +66,11 @@ public struct GenerateContentResponse {
}

/// Initializer for SwiftUI previews or tests.
public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback? = nil) {
public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback? = nil,
usageMetadata: UsageMetadata? = nil) {
self.candidates = candidates
self.promptFeedback = promptFeedback
self.usageMetadata = usageMetadata
}
}

Expand Down Expand Up @@ -170,6 +187,7 @@ extension GenerateContentResponse: Decodable {
enum CodingKeys: CodingKey {
case candidates
case promptFeedback
case usageMetadata
}

public init(from decoder: Decoder) throws {
Expand All @@ -194,6 +212,24 @@ extension GenerateContentResponse: Decodable {
candidates = []
}
promptFeedback = try container.decodeIfPresent(PromptFeedback.self, forKey: .promptFeedback)
usageMetadata = try container.decodeIfPresent(UsageMetadata.self, forKey: .usageMetadata)
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension GenerateContentResponse.UsageMetadata: Decodable {
enum CodingKeys: CodingKey {
case promptTokenCount
case candidatesTokenCount
case totalTokenCount
}

public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
promptTokenCount = try container.decodeIfPresent(Int.self, forKey: .promptTokenCount) ?? 0
candidatesTokenCount = try container
.decodeIfPresent(Int.self, forKey: .candidatesTokenCount) ?? 0
totalTokenCount = try container.decodeIfPresent(Int.self, forKey: .totalTokenCount) ?? 0
}
}

Expand Down
49 changes: 45 additions & 4 deletions Tests/GoogleAITests/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,9 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(candidate.safetyRatings, safetyRatingsNegligible)
XCTAssertEqual(candidate.content.parts.count, 1)
let part = try XCTUnwrap(candidate.content.parts.first)
XCTAssertEqual(part.text, "Mountain View, California, United States")
XCTAssertEqual(part.text, "Mountain View, California")
XCTAssertEqual(response.text, part.text)
let promptFeedback = try XCTUnwrap(response.promptFeedback)
XCTAssertNil(promptFeedback.blockReason)
XCTAssertEqual(promptFeedback.safetyRatings, safetyRatingsNegligible)
XCTAssertNil(response.promptFeedback)
XCTAssertEqual(response.functionCalls, [])
}

Expand Down Expand Up @@ -256,6 +254,22 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(response.functionCalls, [functionCall])
}

func testGenerateContent_usageMetadata() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-success-basic-reply-short",
withExtension: "json"
)

let response = try await model.generateContent(testPrompt)

let usageMetadata = try XCTUnwrap(response.usageMetadata)
// TODO(andrewheard): Re-run prompt when `promptTokenCount` and `totalTokenCount` added.
XCTAssertEqual(usageMetadata.promptTokenCount, 0)
XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
XCTAssertEqual(usageMetadata.totalTokenCount, 0)
}

func testGenerateContent_failure_invalidAPIKey() async throws {
let expectedStatusCode = 400
MockURLProtocol
Expand Down Expand Up @@ -756,6 +770,33 @@ final class GenerativeModelTests: XCTestCase {
}))
}

func testGenerateContentStream_usageMetadata() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "streaming-success-basic-reply-short",
withExtension: "txt"
)
var responses = [GenerateContentResponse]()

let stream = model.generateContentStream(testPrompt)
for try await response in stream {
responses.append(response)
}

for (index, response) in responses.enumerated() {
if index == responses.endIndex - 1 {
let usageMetadata = try XCTUnwrap(response.usageMetadata)
// TODO(andrewheard): Re-run prompt when `promptTokenCount` and `totalTokenCount` added.
XCTAssertEqual(usageMetadata.promptTokenCount, 0)
XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
XCTAssertEqual(usageMetadata.totalTokenCount, 0)
} else {
// Only the last streamed response contains usage metadata
XCTAssertNil(response.usageMetadata)
}
}
}

func testGenerateContentStream_errorMidStream() async throws {
MockURLProtocol.requestHandler = try httpRequestHandler(
forResource: "streaming-failure-error-mid-stream",
Expand Down

0 comments on commit 49adf92

Please sign in to comment.