Skip to content

Commit

Permalink
[Vertex AI] Use struct instead of enum for HarmProbability (fir…
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored and MojtabaHs committed Oct 17, 2024
1 parent e0d8a74 commit 56ff8a8
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 31 deletions.
7 changes: 4 additions & 3 deletions FirebaseVertexAI/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@
as input. (#13767)
- [changed] **Breaking Change**: All initializers for `ModelContent` now require
the label `parts: `. (#13832)
- [changed] **Breaking Change**: `HarmCategory` is now a struct instead of an
enum type and the `unknown` case has been removed; in a `switch` statement,
use the `default:` case to cover unknown or unhandled categories. (#13728)
- [changed] **Breaking Change**: `HarmCategory` and `HarmProbability` are now
structs instead of enums types and the `unknown` cases have been removed; in a
`switch` statement, use the `default:` case to cover unknown or unhandled
categories or probabilities. (#13728, #13854)
- [changed] The default request timeout is now 180 seconds instead of the
platform-default value of 60 seconds for a `URLRequest`; this timeout may
still be customized in `RequestOptions`. (#13722)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ private extension HarmCategory {
case .hateSpeech: "Hate speech"
case .sexuallyExplicit: "Sexually explicit"
case .civicIntegrity: "Civic integrity"
default:
"Unknown HarmCategory: \(rawValue)"
default: "Unknown HarmCategory: \(rawValue)"
}
}
}
Expand All @@ -39,7 +38,7 @@ private extension SafetyRating.HarmProbability {
case .low: "Low"
case .medium: "Medium"
case .negligible: "Negligible"
case .unknown: "Unknown"
default: "Unknown HarmProbability: \(rawValue)"
}
}
}
Expand Down
79 changes: 56 additions & 23 deletions FirebaseVertexAI/Sources/Safety.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,66 @@ public struct SafetyRating: Equatable, Hashable, Sendable {
self.probability = probability
}

/// The probability that a given model output falls under a harmful content category. This does
/// not indicate the severity of harm for a piece of content.
public enum HarmProbability: String, Sendable {
/// Unknown. A new server value that isn't recognized by the SDK.
case unknown = "UNKNOWN"
/// The probability that a given model output falls under a harmful content category.
///
/// > Note: This does not indicate the severity of harm for a piece of content.
public struct HarmProbability: Sendable, Equatable, Hashable {
enum Kind: String {
case negligible = "NEGLIGIBLE"
case low = "LOW"
case medium = "MEDIUM"
case high = "HIGH"
}

/// The probability is zero or close to zero. For benign content, the probability across all
/// categories will be this value.
case negligible = "NEGLIGIBLE"
/// The probability is zero or close to zero.
///
/// For benign content, the probability across all categories will be this value.
public static var negligible: HarmProbability {
return self.init(kind: .negligible)
}

/// The probability is small but non-zero.
case low = "LOW"
public static var low: HarmProbability {
return self.init(kind: .low)
}

/// The probability is moderate.
case medium = "MEDIUM"
public static var medium: HarmProbability {
return self.init(kind: .medium)
}

/// The probability is high.
///
/// The content described is very likely harmful.
public static var high: HarmProbability {
return self.init(kind: .high)
}

/// Returns the raw string representation of the `HarmProbability` value.
///
/// > Note: This value directly corresponds to the values in the [REST
/// > API](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/GenerateContentResponse#SafetyRating).
public let rawValue: String

/// The probability is high. The content described is very likely harmful.
case high = "HIGH"
init(kind: Kind) {
rawValue = kind.rawValue
}

init(rawValue: String) {
if Kind(rawValue: rawValue) == nil {
VertexLog.error(
code: .generateContentResponseUnrecognizedHarmProbability,
"""
Unrecognized HarmProbability with value "\(rawValue)":
- Check for updates to the SDK as support for "\(rawValue)" may have been added; see \
release notes at https://firebase.google.com/support/release-notes/ios
- Search for "\(rawValue)" in the Firebase Apple SDK Issue Tracker at \
https://github.com/firebase/firebase-ios-sdk/issues and file a Bug Report if none found
"""
)
}
self.rawValue = rawValue
}
}
}

Expand Down Expand Up @@ -163,17 +205,8 @@ public struct HarmCategory: Sendable, Equatable, Hashable {
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension SafetyRating.HarmProbability: Decodable {
public init(from decoder: Decoder) throws {
let value = try decoder.singleValueContainer().decode(String.self)
guard let decodedProbability = SafetyRating.HarmProbability(rawValue: value) else {
VertexLog.error(
code: .generateContentResponseUnrecognizedHarmProbability,
"Unrecognized HarmProbability with value \"\(value)\"."
)
self = .unknown
return
}

self = decodedProbability
let rawValue = try decoder.singleValueContainer().decode(String.self)
self = SafetyRating.HarmProbability(rawValue: rawValue)
}
}

Expand Down
7 changes: 5 additions & 2 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,10 @@ final class GenerativeModelTests: XCTestCase {
func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
let expectedSafetyRatings = [
SafetyRating(category: .harassment, probability: .medium),
SafetyRating(category: .dangerousContent, probability: .unknown),
SafetyRating(
category: .dangerousContent,
probability: SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY")
),
SafetyRating(category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"), probability: .high),
]
MockURLProtocol
Expand Down Expand Up @@ -974,7 +977,7 @@ final class GenerativeModelTests: XCTestCase {
)
let unknownSafetyRating = SafetyRating(
category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
probability: .unknown
probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM")
)

var foundUnknownSafetyRating = false
Expand Down

0 comments on commit 56ff8a8

Please sign in to comment.