/// REST method names for the `generateContent` family of backend endpoints.
///
/// The raw value is appended verbatim to the model resource path when building
/// the request URL.
enum APIMethod: String {
  /// Single-shot (unary) content generation.
  case generateContent
  /// Server-sent-events streaming content generation.
  case streamGenerateContent
  /// Token counting for a prompt.
  case countTokens
}
public var history: [ModelContent] { get { - historyLock.withLock { _history } + return _history.history } set { - historyLock.withLock { _history = newValue } - } - } - - private func appendHistory(contentsOf: [ModelContent]) { - historyLock.withLock { - _history.append(contentsOf: contentsOf) - } - } - - private func appendHistory(_ newElement: ModelContent) { - historyLock.withLock { - _history.append(newElement) + _history.history = newValue } } @@ -87,8 +73,8 @@ public final class Chat: Sendable { let toAdd = ModelContent(role: "model", parts: reply.parts) // Append the request and successful result to history, then return the value. - appendHistory(contentsOf: newContent) - appendHistory(toAdd) + _history.append(contentsOf: newContent) + _history.append(toAdd) return result } @@ -136,63 +122,16 @@ public final class Chat: Sendable { } // Save the request. - appendHistory(contentsOf: newContent) + _history.append(contentsOf: newContent) // Aggregate the content to add it to the history before we finish. - let aggregated = self.aggregatedChunks(aggregatedContent) - self.appendHistory(aggregated) + let aggregated = self._history.aggregatedChunks(aggregatedContent) + self._history.append(aggregated) continuation.finish() } } } - private func aggregatedChunks(_ chunks: [ModelContent]) -> ModelContent { - var parts: [InternalPart] = [] - var combinedText = "" - var combinedThoughts = "" - - func flush() { - if !combinedThoughts.isEmpty { - parts.append(InternalPart(.text(combinedThoughts), isThought: true, thoughtSignature: nil)) - combinedThoughts = "" - } - if !combinedText.isEmpty { - parts.append(InternalPart(.text(combinedText), isThought: nil, thoughtSignature: nil)) - combinedText = "" - } - } - - // Loop through all the parts, aggregating the text. - for part in chunks.flatMap({ $0.internalParts }) { - // Only text parts may be combined. 
- if case let .text(text) = part.data, part.thoughtSignature == nil { - // Thought summaries must not be combined with regular text. - if part.isThought ?? false { - // If we were combining regular text, flush it before handling "thoughts". - if !combinedText.isEmpty { - flush() - } - combinedThoughts += text - } else { - // If we were combining "thoughts", flush it before handling regular text. - if !combinedThoughts.isEmpty { - flush() - } - combinedText += text - } - } else { - // This is a non-combinable part (not text), flush any pending text. - flush() - parts.append(part) - } - } - - // Flush any remaining text. - flush() - - return ModelContent(role: "model", parts: parts) - } - /// Populates the `role` field with `user` if it doesn't exist. Required in chat sessions. private func populateContentRole(_ content: ModelContent) -> ModelContent { if content.role != nil { diff --git a/FirebaseAI/Sources/FirebaseAI.swift b/FirebaseAI/Sources/FirebaseAI.swift index f9ff5ea0424..2c041a876f1 100644 --- a/FirebaseAI/Sources/FirebaseAI.swift +++ b/FirebaseAI/Sources/FirebaseAI.swift @@ -136,6 +136,28 @@ public final class FirebaseAI: Sendable { ) } + /// Initializes a new `TemplateGenerativeModel`. + /// + /// - Returns: A new `TemplateGenerativeModel` instance. + public func templateGenerativeModel() -> TemplateGenerativeModel { + return TemplateGenerativeModel( + generativeAIService: GenerativeAIService(firebaseInfo: firebaseInfo, + urlSession: GenAIURLSession.default), + apiConfig: apiConfig + ) + } + + /// Initializes a new `TemplateImagenModel`. + /// + /// - Returns: A new `TemplateImagenModel` instance. + public func templateImagenModel() -> TemplateImagenModel { + return TemplateImagenModel( + generativeAIService: GenerativeAIService(firebaseInfo: firebaseInfo, + urlSession: GenAIURLSession.default), + apiConfig: apiConfig + ) + } + /// **[Public Preview]** Initializes a ``LiveGenerativeModel`` with the given parameters. 
/// /// > Warning: Using the Firebase AI Logic SDKs with the Gemini Live API is in Public diff --git a/FirebaseAI/Sources/GenerateContentRequest.swift b/FirebaseAI/Sources/GenerateContentRequest.swift index 21acd502a75..ddf8d7d28cb 100644 --- a/FirebaseAI/Sources/GenerateContentRequest.swift +++ b/FirebaseAI/Sources/GenerateContentRequest.swift @@ -60,15 +60,6 @@ extension GenerateContentRequest: Encodable { } } -@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension GenerateContentRequest { - enum APIMethod: String { - case generateContent - case streamGenerateContent - case countTokens - } -} - @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) extension GenerateContentRequest: GenerativeAIRequest { typealias Response = GenerateContentResponse diff --git a/FirebaseAI/Sources/GenerateImagesRequest.swift b/FirebaseAI/Sources/GenerateImagesRequest.swift new file mode 100644 index 00000000000..80f2dbbf8a5 --- /dev/null +++ b/FirebaseAI/Sources/GenerateImagesRequest.swift @@ -0,0 +1,58 @@ + +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
/// Request payload for the Imagen `templatePredict` endpoint.
///
/// Only the template `variables` are encoded in the body (under the `"inputs"`
/// key); the template name and project/location are carried in the URL.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public class GenerateImagesRequest: @unchecked Sendable, GenerativeAIRequest {
  // NOTE(review): the generic parameter of `ImagenGenerationResponse` may have
  // been lost in this diff's rendering — confirm against the original file.
  public typealias Response = ImagenGenerationResponse

  /// Per-request configuration (e.g. timeout).
  public let options: RequestOptions

  /// Backend service and API-version configuration used to build the URL.
  let apiConfig: APIConfig

  /// Server-side prompt template name; a `.prompt` suffix is appended if absent.
  let template: String

  /// Variables substituted into the template, encoded under the `"inputs"` key.
  let variables: [String: TemplateVariable]

  /// Firebase project that owns the template.
  let projectID: String

  /// Endpoint URL of the form
  /// `<endpoint>/<version>/projects/<id>[/locations/<loc>]/templates/<name>.prompt:templatePredict`.
  public var url: URL {
    var path = "\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)"
    path += "/projects/\(projectID)"
    if case let .vertexAI(_, location) = apiConfig.service {
      path += "/locations/\(location)"
    }
    let templateName = template.hasSuffix(".prompt") ? template : template + ".prompt"
    path += "/templates/\(templateName):\(ImageAPIMethod.generateImages.rawValue)"
    // Force-unwrap mirrors the original implementation; assumes the path
    // components contain no URL-invalid characters — TODO confirm.
    return URL(string: path)!
  }

  init(template: String, variables: [String: TemplateVariable], projectID: String,
       apiConfig: APIConfig, options: RequestOptions) {
    self.template = template
    self.variables = variables
    self.projectID = projectID
    self.apiConfig = apiConfig
    self.options = options
  }

  enum CodingKeys: String, CodingKey {
    case variables = "inputs"
  }

  public func encode(to encoder: Encoder) throws {
    var container = encoder.container(keyedBy: CodingKeys.self)
    try container.encode(variables, forKey: .variables)
  }
}
/// Thread-safe storage for a chat transcript, shared by `Chat` and
/// `TemplateChatSession` so both expose identical history semantics.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class History: Sendable {
  /// Guards all access to `storage`; `nonisolated(unsafe)` is sound because
  /// every read and write below goes through this lock.
  private let lock = NSLock()
  private nonisolated(unsafe) var storage: [ModelContent] = []

  /// The previous content from the chat that has been successfully sent and
  /// received from the model. Provided to the model as context for each
  /// message sent in the discussion.
  public var history: [ModelContent] {
    get { lock.withLock { storage } }
    set { lock.withLock { storage = newValue } }
  }

  init(history: [ModelContent]) {
    self.history = history
  }

  /// Appends multiple entries to the transcript under the lock.
  func append(contentsOf newElements: [ModelContent]) {
    lock.withLock { storage.append(contentsOf: newElements) }
  }

  /// Appends a single entry to the transcript under the lock.
  func append(_ newElement: ModelContent) {
    lock.withLock { storage.append(newElement) }
  }

  /// Collapses streamed chunks into a single `ModelContent` with role "model".
  ///
  /// Consecutive plain-text parts are merged into one text part and consecutive
  /// thought-summary parts into another; parts carrying a thought signature,
  /// and all non-text parts, are preserved unmerged in their original order.
  func aggregatedChunks(_ chunks: [ModelContent]) -> ModelContent {
    var merged: [InternalPart] = []
    var textRun = ""
    var thoughtRun = ""

    // Emits any buffered thought text first, then buffered plain text,
    // matching the order used by the original implementation.
    func drainBuffers() {
      if !thoughtRun.isEmpty {
        merged.append(InternalPart(.text(thoughtRun), isThought: true, thoughtSignature: nil))
        thoughtRun = ""
      }
      if !textRun.isEmpty {
        merged.append(InternalPart(.text(textRun), isThought: nil, thoughtSignature: nil))
        textRun = ""
      }
    }

    for part in chunks.flatMap(\.internalParts) {
      // Only unsigned text parts may be combined; everything else is emitted
      // as-is after flushing any pending text.
      guard case let .text(text) = part.data, part.thoughtSignature == nil else {
        drainBuffers()
        merged.append(part)
        continue
      }
      if part.isThought ?? false {
        // Switching from plain text to thoughts: flush the plain-text run first.
        if !textRun.isEmpty { drainBuffers() }
        thoughtRun += text
      } else {
        // Switching from thoughts to plain text: flush the thought run first.
        if !thoughtRun.isEmpty { drainBuffers() }
        textRun += text
      }
    }

    // Flush whatever text remains after the final part.
    drainBuffers()

    return ModelContent(role: "model", parts: merged)
  }
}
/// REST method name for template-driven Imagen requests.
enum ImageAPIMethod: String {
  /// Image generation via a server-side template; the backend names this
  /// endpoint `templatePredict`, so the raw value differs from the case name.
  case generateImages = "templatePredict"
}
/// A chat session that allows for conversation with a model driven by a
/// server-side prompt template.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public final class TemplateChatSession: Sendable {
  private let model: TemplateGenerativeModel
  private let template: String
  private let _history: History

  init(model: TemplateGenerativeModel, template: String, history: [ModelContent]) {
    self.model = model
    self.template = template
    _history = History(history: history)
  }

  /// The conversation transcript accumulated so far (thread-safe via `History`).
  public var history: [ModelContent] {
    get { _history.history }
    set { _history.history = newValue }
  }

  /// Sends a message to the model and returns the response.
  ///
  /// On success, the user message and the model's first candidate are appended
  /// to ``history``; on failure, history is left untouched.
  public func sendMessage(_ message: any PartsRepresentable,
                          variables: [String: Any],
                          options: RequestOptions = RequestOptions()) async throws
    -> GenerateContentResponse {
    let typedVariables = try variables.mapValues { try TemplateVariable(value: $0) }
    let userContent = populateContentRole(ModelContent(parts: message.partsValue))
    let response = try await model.generateContentWithHistory(
      history: _history.history + [userContent],
      template: template,
      variables: typedVariables,
      options: options
    )
    // Record the exchange only after the request succeeded.
    _history.append(userContent)
    if let candidate = response.candidates.first {
      _history.append(candidate.content)
    }
    return response
  }

  /// Sends a message to the model and streams back response chunks.
  ///
  /// History is updated only after the stream completes without error.
  public func sendMessageStream(_ message: any PartsRepresentable,
                                variables: [String: Any],
                                options: RequestOptions = RequestOptions()) throws
    // NOTE(review): the stream generics were stripped in this diff's rendering;
    // restored as <GenerateContentResponse, Error> — confirm against original.
    -> AsyncThrowingStream<GenerateContentResponse, Error> {
    let typedVariables = try variables.mapValues { try TemplateVariable(value: $0) }
    let userContent = populateContentRole(ModelContent(parts: message.partsValue))
    let stream = try model.generateContentStreamWithHistory(
      history: _history.history + [userContent],
      template: template,
      variables: typedVariables,
      options: options
    )
    return AsyncThrowingStream { continuation in
      Task {
        var streamedContent: [ModelContent] = []
        do {
          for try await chunk in stream {
            // Collect streamed content; populated whenever there is no error.
            if let content = chunk.candidates.first?.content {
              streamedContent.append(content)
            }
            // Pass the chunk through to the caller.
            continuation.yield(chunk)
          }
        } catch {
          // Rethrow the underlying stream's error; add nothing to history.
          continuation.finish(throwing: error)
          return
        }

        // Save the request, then the aggregated model reply.
        _history.append(userContent)
        _history.append(_history.aggregatedChunks(streamedContent))
        continuation.finish()
      }
    }
  }

  /// Defaults the `role` to "user" when absent; required in chat sessions.
  private func populateContentRole(_ content: ModelContent) -> ModelContent {
    guard content.role == nil else { return content }
    return ModelContent(role: "user", parts: content.parts)
  }
}
/// Request payload for the `templateGenerateContent` endpoint.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct TemplateGenerateContentRequest: Sendable {
  /// Server-side prompt template name; `.prompt` is appended when missing.
  let template: String
  /// Variables substituted into the template (encoded as `"inputs"`).
  let variables: [String: TemplateVariable]
  /// Prior conversation turns sent as context.
  let history: [ModelContent]
  /// Firebase project that owns the template.
  let projectID: String
  /// Whether to request a server-sent-events streaming response.
  let stream: Bool
  /// Backend service and API-version configuration.
  let apiConfig: APIConfig
  /// Per-request options (e.g. timeout).
  let options: RequestOptions
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension TemplateGenerateContentRequest: Encodable {
  enum CodingKeys: String, CodingKey {
    case variables = "inputs"
    case history
  }

  // Only the template inputs and history go in the body; everything else is
  // carried by the URL.
  func encode(to encoder: any Encoder) throws {
    var container = encoder.container(keyedBy: CodingKeys.self)
    try container.encode(variables, forKey: .variables)
    try container.encode(history, forKey: .history)
  }
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension TemplateGenerateContentRequest: GenerativeAIRequest {
  typealias Response = GenerateContentResponse

  /// Endpoint URL of the form
  /// `<endpoint>/<version>/projects/<id>[/locations/<loc>]/templates/<name>.prompt:templateGenerateContent[?alt=sse]`.
  var url: URL {
    var path = "\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)"
    path += "/projects/\(projectID)"
    if case let .vertexAI(_, location) = apiConfig.service {
      path += "/locations/\(location)"
    }
    let templateName = template.hasSuffix(".prompt") ? template : template + ".prompt"
    // NOTE(review): the method name is hard-coded here while the Imagen request
    // reads it from `ImageAPIMethod` — consider a shared constant for consistency.
    path += "/templates/\(templateName):templateGenerateContent"
    if stream {
      // SSE framing for streaming responses.
      path += "?alt=sse"
    }
    return URL(string: path)!
  }
}
/// A type that represents a remote multimodal model (like Gemini), with the ability to generate
/// content based on various input types, driven by server-side prompt templates.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public final class TemplateGenerativeModel: Sendable {
  /// Transport used to execute requests against the backend.
  let generativeAIService: GenerativeAIService

  /// Backend service and API-version configuration.
  let apiConfig: APIConfig

  init(generativeAIService: GenerativeAIService, apiConfig: APIConfig) {
    self.generativeAIService = generativeAIService
    self.apiConfig = apiConfig
  }

  /// Generates content from a prompt template and variables.
  ///
  /// - Parameters:
  ///   - template: The prompt template to use.
  ///   - variables: A dictionary of variables to substitute into the template.
  ///   - options: Per-request options such as timeout.
  /// - Returns: The content generated by the model.
  /// - Throws: A ``GenerateContentError`` if the request failed.
  public func generateContent(template: String,
                              variables: [String: Any],
                              options: RequestOptions = RequestOptions()) async throws
    -> GenerateContentResponse {
    let templateVariables = try variables.mapValues { try TemplateVariable(value: $0) }
    return try await generateContentWithHistory(
      history: [],
      template: template,
      variables: templateVariables,
      options: options
    )
  }

  /// Generates content from a prompt template, variables, and history.
  ///
  /// - Parameters:
  ///   - history: The conversation history to use.
  ///   - template: The prompt template to use.
  ///   - variables: A dictionary of variables to substitute into the template.
  ///   - options: Per-request options such as timeout.
  /// - Returns: The content generated by the model.
  /// - Throws: A ``GenerateContentError`` if the request failed.
  func generateContentWithHistory(history: [ModelContent], template: String,
                                  variables: [String: TemplateVariable],
                                  options: RequestOptions = RequestOptions()) async throws
    -> GenerateContentResponse {
    let request = TemplateGenerateContentRequest(
      template: template,
      variables: variables,
      history: history,
      projectID: generativeAIService.firebaseInfo.projectID,
      stream: false,
      apiConfig: apiConfig,
      options: options
    )
    let response: GenerateContentResponse = try await generativeAIService
      .loadRequest(request: request)
    return response
  }

  /// Streams content generated from a prompt template and variables.
  ///
  /// - Parameters:
  ///   - template: The prompt template to use.
  ///   - variables: A dictionary of variables to substitute into the template.
  ///   - options: Per-request options such as timeout.
  /// - Returns: A stream of partial ``GenerateContentResponse`` chunks.
  /// - Throws: An error if the variables cannot be converted.
  // NOTE(review): the stream generics were stripped in this diff's rendering;
  // restored as <GenerateContentResponse, Error> — confirm against original.
  public func generateContentStream(template: String,
                                    variables: [String: Any],
                                    options: RequestOptions = RequestOptions()) throws
    -> AsyncThrowingStream<GenerateContentResponse, Error> {
    let templateVariables = try variables.mapValues { try TemplateVariable(value: $0) }
    // Delegate to the history-aware variant instead of duplicating request
    // construction (mirrors how generateContent delegates above).
    return try generateContentStreamWithHistory(
      history: [],
      template: template,
      variables: templateVariables,
      options: options
    )
  }

  /// Streams content generated from a prompt template, variables, and history.
  func generateContentStreamWithHistory(history: [ModelContent], template: String,
                                        variables: [String: TemplateVariable],
                                        options: RequestOptions = RequestOptions()) throws
    -> AsyncThrowingStream<GenerateContentResponse, Error> {
    let request = TemplateGenerateContentRequest(
      template: template,
      variables: variables,
      history: history,
      projectID: generativeAIService.firebaseInfo.projectID,
      stream: true,
      apiConfig: apiConfig,
      options: options
    )
    return generativeAIService.loadRequestStream(request: request)
  }

  /// Creates a new chat conversation using this model with the provided history and template.
  ///
  /// - Parameters:
  ///   - template: The prompt template to use.
  ///   - history: The conversation history to use.
  /// - Returns: A new ``TemplateChatSession`` instance.
  public func startChat(template: String,
                        history: [ModelContent] = []) -> TemplateChatSession {
    return TemplateChatSession(
      model: self,
      template: template,
      history: history
    )
  }
}
/// A type that represents a remote image generation model (like Imagen), with
/// the ability to generate images based on various input types.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public final class TemplateImagenModel: Sendable {
  /// Transport used to execute requests against the backend.
  let generativeAIService: GenerativeAIService

  /// Backend service and API-version configuration.
  let apiConfig: APIConfig

  init(generativeAIService: GenerativeAIService, apiConfig: APIConfig) {
    self.generativeAIService = generativeAIService
    self.apiConfig = apiConfig
  }

  /// Generates images from a prompt template and variables.
  ///
  /// - Parameters:
  ///   - template: The prompt template to use.
  ///   - variables: A dictionary of variables to substitute into the template.
  ///   - options: Per-request options such as timeout.
  /// - Returns: The images generated by the model.
  /// - Throws: An error if the variables cannot be converted or the request failed.
  // NOTE(review): the generic parameter of `ImagenGenerationResponse` may have
  // been lost in this diff's rendering — confirm against the original file.
  public func generateImages(template: String,
                             variables: [String: Any],
                             options: RequestOptions = RequestOptions()) async throws
    -> ImagenGenerationResponse {
    let typedVariables = try variables.mapValues { try TemplateVariable(value: $0) }
    let request = GenerateImagesRequest(
      template: template,
      variables: typedVariables,
      projectID: generativeAIService.firebaseInfo.projectID,
      apiConfig: apiConfig,
      options: options
    )
    return try await generativeAIService.loadRequest(request: request)
  }
}
/// A strongly-typed, JSON-like value that can be substituted into a server-side
/// prompt template.
///
/// Encodes transparently: each case writes its underlying value directly, with
/// no case-name wrapper in the resulting JSON.
enum TemplateVariable: Encodable, Sendable {
  case string(String)
  case int(Int)
  case double(Double)
  case bool(Bool)
  case array([TemplateVariable])
  case dictionary([String: TemplateVariable])

  /// Converts a loosely-typed value (e.g. from a `[String: Any]` literal) into
  /// a `TemplateVariable`, recursing into arrays and dictionaries.
  ///
  /// - Parameter value: A `String`, `Int`, `Double`, `Float`, `Bool`, nested
  ///   `[Any]` / `[String: Any]` of the same, or an existing `TemplateVariable`.
  /// - Throws: `EncodingError.invalidValue` for any other type.
  init(value: Any) throws {
    switch value {
    case let value as TemplateVariable:
      // Already converted — pass through unchanged (previously this threw).
      self = value
    case let value as String:
      self = .string(value)
    // NOTE(review): values bridged through NSNumber (e.g. produced by
    // JSONSerialization) may match Int before Bool here; confirm callers pass
    // native Swift literals.
    case let value as Int:
      self = .int(value)
    case let value as Double:
      self = .double(value)
    case let value as Float:
      // Widen Float to Double so there is a single floating-point case.
      self = .double(Double(value))
    case let value as Bool:
      self = .bool(value)
    case let value as [Any]:
      self = try .array(value.map { try TemplateVariable(value: $0) })
    case let value as [String: Any]:
      self = try .dictionary(value.mapValues { try TemplateVariable(value: $0) })
    default:
      throw EncodingError.invalidValue(
        value,
        EncodingError.Context(
          codingPath: [],
          // Name the offending type so failures are actionable.
          debugDescription: "Unsupported template variable type: \(type(of: value))"
        )
      )
    }
  }

  /// Encodes the underlying value into a single-value container.
  func encode(to encoder: Encoder) throws {
    var container = encoder.singleValueContainer()
    switch self {
    case let .string(value):
      try container.encode(value)
    case let .int(value):
      try container.encode(value)
    case let .double(value):
      try container.encode(value)
    case let .bool(value):
      try container.encode(value)
    case let .array(value):
      try container.encode(value)
    case let .dictionary(value):
      try container.encode(value)
    }
  }
}
/// Integration tests for server-side prompt templates; these hit a live
/// Vertex AI backend and require the corresponding templates to exist.
final class ServerPromptTemplateIntegrationTests: XCTestCase {
  override func setUp() {
    super.setUp()
    // A failed step invalidates the rest of a test, so stop at first failure.
    continueAfterFailure = false
  }

  /// Encodes a small system image as a base64 JPEG string for the "media" template.
  ///
  /// - NOTE(review): `UIImage` requires UIKit; confirm the test target imports
  ///   it (the import is not visible in this diff).
  private func base64TestImage() throws -> String {
    let image = try XCTUnwrap(UIImage(systemName: "photo"), "Missing system image.")
    let imageBytes = try XCTUnwrap(
      image.jpegData(compressionQuality: 0.8),
      "Could not get image data."
    )
    return imageBytes.base64EncodedString()
  }

  func testGenerateContentWithText() async throws {
    let model = FirebaseAI.firebaseAI(backend: .vertexAI()).templateGenerativeModel()
    let userName = "paul"
    let response = try await model.generateContent(
      template: "greeting",
      variables: [
        "name": userName,
        "language": "Spanish",
      ]
    )
    let text = try XCTUnwrap(response.text)
    XCTAssertTrue(text.contains("Paul"))
  }

  func testGenerateContentStream() async throws {
    let model = FirebaseAI.firebaseAI(backend: .vertexAI()).templateGenerativeModel()
    let userName = "paul"
    let stream = try model.generateContentStream(
      template: "greeting",
      variables: [
        "name": userName,
        "language": "English",
      ]
    )
    var resultText = ""
    for try await response in stream {
      resultText += response.text ?? ""
    }
    XCTAssertTrue(resultText.contains("Paul"))
  }

  func testGenerateImages() async throws {
    let imagenModel = FirebaseAI.firebaseAI(backend: .vertexAI()).templateImagenModel()
    let response = try await imagenModel.generateImages(
      template: "generate_images",
      variables: [
        "prompt": "A cat picture",
      ]
    )
    // Presumably the "generate_images" template is configured server-side to
    // return 3 images — verify against the template definition.
    XCTAssertEqual(response.images.count, 3)
  }

  func testGenerateContentWithMedia() async throws {
    let model = FirebaseAI.firebaseAI(backend: .vertexAI()).templateGenerativeModel()
    let base64Image = try base64TestImage()

    let response = try await model.generateContent(
      template: "media",
      variables: [
        "imageData": [
          "isInline": true,
          "mimeType": "image/jpeg",
          "contents": base64Image,
        ],
      ]
    )
    let text = try XCTUnwrap(response.text)
    XCTAssertFalse(text.isEmpty)
  }

  func testGenerateContentStreamWithMedia() async throws {
    let model = FirebaseAI.firebaseAI(backend: .vertexAI()).templateGenerativeModel()
    let base64Image = try base64TestImage()

    let stream = try model.generateContentStream(
      template: "media",
      variables: [
        "imageData": [
          "isInline": true,
          "mimeType": "image/jpeg",
          "contents": base64Image,
        ],
      ]
    )
    var resultText = ""
    for try await response in stream {
      resultText += response.text ?? ""
    }
    XCTAssertFalse(resultText.isEmpty)
  }

  func testChat() async throws {
    let model = FirebaseAI.firebaseAI(backend: .vertexAI()).templateGenerativeModel()
    let initialHistory = [
      ModelContent(role: "user", parts: "Hello!"),
      ModelContent(role: "model", parts: "Hi there! How can I help?"),
    ]
    let chatSession = model.startChat(template: "chat_history", history: initialHistory)

    let userMessage = "What's the weather like?"

    let response = try await chatSession.sendMessage(
      userMessage,
      variables: ["message": userMessage]
    )
    let text = try XCTUnwrap(response.text)
    XCTAssertFalse(text.isEmpty)
    // Initial 2 turns + the new user message + the model reply.
    XCTAssertEqual(chatSession.history.count, 4)
    let sentPart = try XCTUnwrap(chatSession.history[2].parts.first as? TextPart)
    XCTAssertEqual(sentPart.text, userMessage)
  }

  func testChatStream() async throws {
    let model = FirebaseAI.firebaseAI(backend: .vertexAI()).templateGenerativeModel()
    let initialHistory = [
      ModelContent(role: "user", parts: "Hello!"),
      ModelContent(role: "model", parts: "Hi there! How can I help?"),
    ]
    let chatSession = model.startChat(template: "chat_history", history: initialHistory)

    let userMessage = "What's the weather like?"

    let stream = try chatSession.sendMessageStream(
      userMessage,
      variables: ["message": userMessage]
    )
    var resultText = ""
    for try await response in stream {
      resultText += response.text ?? ""
    }
    XCTAssertFalse(resultText.isEmpty)
    // Initial 2 turns + the new user message + the aggregated model reply.
    XCTAssertEqual(chatSession.history.count, 4)
    let sentPart = try XCTUnwrap(chatSession.history[2].parts.first as? TextPart)
    XCTAssertEqual(sentPart.text, userMessage)
  }
}
+ + override func setUp() { + super.setUp() + let configuration = URLSessionConfiguration.default + configuration.protocolClasses = [MockURLProtocol.self] + urlSession = URLSession(configuration: configuration) + let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo() + let generativeAIService = GenerativeAIService( + firebaseInfo: firebaseInfo, + urlSession: urlSession + ) + let apiConfig = APIConfig(service: .googleAI(endpoint: .firebaseProxyProd), version: .v1beta) + model = TemplateGenerativeModel(generativeAIService: generativeAIService, apiConfig: apiConfig) + } + + func testSendMessage() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-success-basic-reply-short", + withExtension: "json", + subdirectory: "mock-responses/googleai", + isTemplateRequest: true + ) + let chat = model.startChat(template: "test-template") + let response = try await chat.sendMessage("Hello", variables: ["name": "test"]) + XCTAssertEqual(chat.history.count, 2) + XCTAssertEqual(chat.history[0].role, "user") + XCTAssertEqual((chat.history[0].parts.first as? TextPart)?.text, "Hello") + XCTAssertEqual(chat.history[1].role, "model") + XCTAssertEqual( + (chat.history[1].parts.first as? 
TextPart)?.text, + "Google's headquarters, also known as the Googleplex, is located in **Mountain View, California**.\n" + ) + XCTAssertEqual(response.candidates.count, 1) + } + + func testSendMessageStream() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "streaming-success-basic-reply-short", + withExtension: "txt", + subdirectory: "mock-responses/googleai", + isTemplateRequest: true + ) + let chat = model.startChat(template: "test-template") + let stream = try chat.sendMessageStream("Hello", variables: ["name": "test"]) + + let content = try await GenerativeModelTestUtil.collectTextFromStream(stream) + + XCTAssertEqual(content, "The capital of Wyoming is **Cheyenne**.\n") + XCTAssertEqual(chat.history.count, 2) + XCTAssertEqual(chat.history[0].role, "user") + XCTAssertEqual((chat.history[0].parts.first as? TextPart)?.text, "Hello") + XCTAssertEqual(chat.history[1].role, "model") + } +} diff --git a/FirebaseAI/Tests/Unit/TemplateGenerativeModelTests.swift b/FirebaseAI/Tests/Unit/TemplateGenerativeModelTests.swift new file mode 100644 index 00000000000..b191ccf54cb --- /dev/null +++ b/FirebaseAI/Tests/Unit/TemplateGenerativeModelTests.swift @@ -0,0 +1,72 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +@testable import FirebaseAI +import FirebaseCore +import XCTest + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +final class TemplateGenerativeModelTests: XCTestCase { + var urlSession: URLSession! + var model: TemplateGenerativeModel! + + override func setUp() { + super.setUp() + let configuration = URLSessionConfiguration.default + configuration.protocolClasses = [MockURLProtocol.self] + urlSession = URLSession(configuration: configuration) + let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo() + let generativeAIService = GenerativeAIService( + firebaseInfo: firebaseInfo, + urlSession: urlSession + ) + let apiConfig = APIConfig(service: .googleAI(endpoint: .firebaseProxyProd), version: .v1beta) + model = TemplateGenerativeModel(generativeAIService: generativeAIService, apiConfig: apiConfig) + } + + func testGenerateContent() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-success-basic-reply-short", + withExtension: "json", + subdirectory: "mock-responses/googleai", + isTemplateRequest: true + ) + + let response = try await model.generateContent( + template: "test-template", + variables: ["name": "test"] + ) + XCTAssertEqual( + response.text, + "Google's headquarters, also known as the Googleplex, is located in **Mountain View, California**.\n" + ) + } + + func testGenerateContentStream() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "streaming-success-basic-reply-short", + withExtension: "txt", + subdirectory: "mock-responses/googleai", + isTemplateRequest: true + ) + + let stream = try model.generateContentStream( + template: "test-template", + variables: ["name": "test"] + ) + + let content = try await GenerativeModelTestUtil.collectTextFromStream(stream) + XCTAssertEqual(content, "The capital of Wyoming is **Cheyenne**.\n") + } +} diff --git 
a/FirebaseAI/Tests/Unit/TemplateImagenModelTests.swift b/FirebaseAI/Tests/Unit/TemplateImagenModelTests.swift new file mode 100644 index 00000000000..f350d4ae980 --- /dev/null +++ b/FirebaseAI/Tests/Unit/TemplateImagenModelTests.swift @@ -0,0 +1,52 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +@testable import FirebaseAI +import XCTest + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +final class TemplateImagenModelTests: XCTestCase { + var urlSession: URLSession! + var model: TemplateImagenModel! ￼
+ + override func setUp() { + super.setUp() + let configuration = URLSessionConfiguration.default + configuration.protocolClasses = [MockURLProtocol.self] + urlSession = URLSession(configuration: configuration) + let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo() + let generativeAIService = GenerativeAIService( + firebaseInfo: firebaseInfo, + urlSession: urlSession + ) + let apiConfig = APIConfig(service: .googleAI(endpoint: .firebaseProxyProd), version: .v1beta) + model = TemplateImagenModel(generativeAIService: generativeAIService, apiConfig: apiConfig) + } + + func testGenerateImages() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-success-generate-images-base64", + withExtension: "json", + subdirectory: "mock-responses/vertexai", + isTemplateRequest: true + ) + + let response = try await model.generateImages( + template: "test-template", + variables: ["prompt": "a cat picture"] + ) + XCTAssertEqual(response.images.count, 4) + XCTAssertNotNil(response.images.first?.data) + } +} diff --git a/FirebaseAI/Tests/Unit/TestUtilities/GenerativeModelTestUtil.swift b/FirebaseAI/Tests/Unit/TestUtilities/GenerativeModelTestUtil.swift index ee4f47bc5b0..8ef28d33114 100644 --- a/FirebaseAI/Tests/Unit/TestUtilities/GenerativeModelTestUtil.swift +++ b/FirebaseAI/Tests/Unit/TestUtilities/GenerativeModelTestUtil.swift @@ -30,10 +30,12 @@ enum GenerativeModelTestUtil { timeout: TimeInterval = RequestOptions().timeout, appCheckToken: String? = nil, authToken: String? = nil, - dataCollection: Bool = true) throws -> ((URLRequest) throws -> ( - URLResponse, - AsyncLineSequence? - )) { + dataCollection: Bool = true, + isTemplateRequest: Bool = false) throws + -> ((URLRequest) throws -> ( + URLResponse, + AsyncLineSequence? + )) { // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see // https://developer.apple.com/documentation/foundation/urlprotocol for details. 
#if os(watchOS) @@ -45,7 +47,14 @@ enum GenerativeModelTestUtil { ) return { request in let requestURL = try XCTUnwrap(request.url) - XCTAssertEqual(requestURL.path.occurrenceCount(of: "models/"), 1) + if isTemplateRequest { + XCTAssertEqual( + requestURL.path.occurrenceCount(of: "templates/test-template.prompt:template"), + 1 + ) + } else { + XCTAssertEqual(requestURL.path.occurrenceCount(of: "models/"), 1) + } XCTAssertEqual(request.timeoutInterval, timeout) let apiClientTags = try XCTUnwrap(request.value(forHTTPHeaderField: "x-goog-api-client")) .components(separatedBy: " ") @@ -79,6 +88,19 @@ enum GenerativeModelTestUtil { #endif // os(watchOS) } + static func collectTextFromStream(_ stream: AsyncThrowingStream< + GenerateContentResponse, + Error + >) async throws -> String { + var content = "" + for try await response in stream { + if let text = response.text { + content += text + } + } + return content + } + static func nonHTTPRequestHandler() throws -> ((URLRequest) -> ( URLResponse, AsyncLineSequence?