diff --git a/firebase-functions/api.txt b/firebase-functions/api.txt index a9a05c703a8..c9d483b547c 100644 --- a/firebase-functions/api.txt +++ b/firebase-functions/api.txt @@ -82,6 +82,8 @@ package com.google.firebase.functions { public final class HttpsCallableReference { method public com.google.android.gms.tasks.Task call(); method public com.google.android.gms.tasks.Task call(Object? data); + method @NonNull public org.reactivestreams.Publisher stream(@Nullable Object data = null); + method @NonNull public org.reactivestreams.Publisher stream(); method public long getTimeout(); method public void setTimeout(long timeout, java.util.concurrent.TimeUnit units); method public com.google.firebase.functions.HttpsCallableReference withTimeout(long timeout, java.util.concurrent.TimeUnit units); diff --git a/firebase-functions/firebase-functions.gradle.kts b/firebase-functions/firebase-functions.gradle.kts index 7ec958bdd79..08a797112b9 100644 --- a/firebase-functions/firebase-functions.gradle.kts +++ b/firebase-functions/firebase-functions.gradle.kts @@ -112,6 +112,8 @@ dependencies { implementation(libs.okhttp) implementation(libs.playservices.base) implementation(libs.playservices.basement) + implementation(libs.reactive.streams) + api(libs.playservices.tasks) kapt(libs.autovalue) @@ -131,6 +133,7 @@ dependencies { androidTestImplementation(libs.truth) androidTestImplementation(libs.androidx.test.runner) androidTestImplementation(libs.androidx.test.junit) + androidTestImplementation(libs.kotlinx.coroutines.reactive) androidTestImplementation(libs.mockito.core) androidTestImplementation(libs.mockito.dexmaker) kapt("com.google.dagger:dagger-android-processor:2.43.2") diff --git a/firebase-functions/src/androidTest/backend/functions/index.js b/firebase-functions/src/androidTest/backend/functions/index.js index fed5a371b89..b55625fbd32 100644 --- a/firebase-functions/src/androidTest/backend/functions/index.js +++ b/firebase-functions/src/androidTest/backend/functions/index.js @@ -122,3 +122,58 @@ exports.timeoutTest = functions.https.onRequest((request, response) => { // Wait for longer than 500ms. setTimeout(() => response.send({data: true}), 500); }); + +const data = ["hello", "world", "this", "is", "cool"]; + +/** + * Pauses the execution for a specified amount of time. + * @param {number} ms - The number of milliseconds to sleep. + * @return {Promise} A promise that resolves after the specified time. + */ +function sleep(ms) { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + +/** + * Generates chunks of text asynchronously, yielding one chunk at a time. + * @async + * @generator + * @yields {string} A chunk of text from the data array. + */ +async function* generateText() { + for (const chunk of data) { + yield chunk; + await sleep(1000); + } +} + +exports.genStream = functions.https.onCall(async (request, response) => { + if (response && response.acceptsStreaming) { + for await (const chunk of generateText()) { + console.log("got chunk", chunk); + response.write({chunk}); + } + } + return data.join(" "); +}); + +exports.genStreamError = functions.https.onCall(async (request, response) => { + if (response && response.acceptsStreaming) { + for await (const chunk of generateText()) { + console.log("got chunk", chunk); + response.write({chunk}); + } + throw new Error("BOOM"); + } +}); + +exports.genStreamNoReturn = functions.https.onCall( + async (request, response) => { + if (response && response.acceptsStreaming) { + for await (const chunk of generateText()) { + console.log("got chunk", chunk); + response.write({chunk}); + } + } + }, +); diff --git a/firebase-functions/src/androidTest/java/com/google/firebase/functions/StreamTests.kt b/firebase-functions/src/androidTest/java/com/google/firebase/functions/StreamTests.kt new file mode 100644 index 00000000000..0b9bb330b69 --- /dev/null +++ b/firebase-functions/src/androidTest/java/com/google/firebase/functions/StreamTests.kt @@ -0,0 +1,203 @@ +package com.google.firebase.functions.ktx + +import androidx.test.core.app.ApplicationProvider +import androidx.test.ext.junit.runners.AndroidJUnit4 +import com.google.common.truth.Truth.assertThat +import com.google.firebase.Firebase +import com.google.firebase.functions.FirebaseFunctions +import com.google.firebase.functions.StreamResponse +import com.google.firebase.functions.StreamResponse.Message +import com.google.firebase.functions.StreamResponse.Result +import com.google.firebase.functions.functions +import com.google.firebase.initialize +import java.util.concurrent.TimeUnit +import kotlinx.coroutines.reactive.asFlow +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import org.junit.After +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.reactivestreams.Subscriber +import org.reactivestreams.Subscription + +@RunWith(AndroidJUnit4::class) +class StreamTests { + + private lateinit var functions: FirebaseFunctions + var onNextList = mutableListOf() + private lateinit var subscriber: Subscriber + private var throwable: Throwable? = null + private var isComplete = false + + @Before + fun setup() { + Firebase.initialize(ApplicationProvider.getApplicationContext()) + functions = Firebase.functions + subscriber = + object : Subscriber { + override fun onSubscribe(subscription: Subscription?) { + subscription?.request(1) + } + + override fun onNext(streamResponse: StreamResponse) { + onNextList.add(streamResponse) + } + + override fun onError(t: Throwable?) { + throwable = t + } + + override fun onComplete() { + isComplete = true + } + } + } + + @After + fun clear() { + onNextList.clear() + throwable = null + isComplete = false + } + + @Test + fun genStream_withPublisher_receivesMessagesAndFinalResult() { + val input = mapOf("data" to "Why is the sky blue") + val function = functions.getHttpsCallable("genStream") + + function.stream(input).subscribe(subscriber) + + Thread.sleep(8000) + val messages = onNextList.filterIsInstance() + val results = onNextList.filterIsInstance() + assertThat(messages.map { it.data.data.toString() }) + .containsExactly( + "{chunk=hello}", + "{chunk=world}", + "{chunk=this}", + "{chunk=is}", + "{chunk=cool}" + ) + assertThat(results).hasSize(1) + assertThat(results.first().data.data.toString()).isEqualTo("hello world this is cool") + assertThat(throwable).isNull() + assertThat(isComplete).isTrue() + } + + @Test + fun genStream_withFlow_receivesMessagesAndFinalResult() = runBlocking { + val input = mapOf("data" to "Why is the sky blue") + val function = functions.getHttpsCallable("genStream") + + val flow = function.stream(input).asFlow() + val receivedResponses = mutableListOf() + try { + withTimeout(8000) { flow.collect { response -> receivedResponses.add(response) } } + isComplete = true + } catch (e: Throwable) { + throwable = e + } + + val messages = receivedResponses.filterIsInstance() + val results = receivedResponses.filterIsInstance() + assertThat(messages.map { it.data.data.toString() }) + .containsExactly( + "{chunk=hello}", + "{chunk=world}", + "{chunk=this}", + "{chunk=is}", + "{chunk=cool}" + ) + assertThat(results).hasSize(1) + assertThat(results.first().data.data.toString()).isEqualTo("hello world this is cool") + assertThat(throwable).isNull() + assertThat(isComplete).isTrue() + } + + @Test + fun genStreamError_receivesErrorAndStops() { + val input = mapOf("data" to "Why is the sky blue") + val function = + functions.getHttpsCallable("genStreamError").withTimeout(800, TimeUnit.MILLISECONDS) + + function.stream(input).subscribe(subscriber) + Thread.sleep(2000) + + val messages = onNextList.filterIsInstance() + val onNextStringList = messages.map { it.data.data.toString() } + assertThat(onNextStringList) + .containsExactly( + "{chunk=hello}", + ) + assertThat(throwable).isNotNull() + assertThat(isComplete).isFalse() + } + + @Test + fun genStreamNoReturn_receivesOnlyMessages() { + val input = mapOf("data" to "Why is the sky blue") + val function = functions.getHttpsCallable("genStreamNoReturn") + + function.stream(input).subscribe(subscriber) + Thread.sleep(8000) + + val messages = onNextList.filterIsInstance() + val results = onNextList.filterIsInstance() + + val onNextStringList = messages.map { it.data.data.toString() } + assertThat(onNextStringList) + .containsExactly( + "{chunk=hello}", + "{chunk=world}", + "{chunk=this}", + "{chunk=is}", + "{chunk=cool}" + ) + assertThat(results).isEmpty() + assertThat(throwable).isNull() + assertThat(isComplete).isFalse() + } + + @Test + fun genStream_cancelStream_receivesPartialMessagesAndError() { + val input = mapOf("data" to "Why is the sky blue") + val function = functions.getHttpsCallable("genStreamNoReturn") + val publisher = function.stream(input) + var subscription: Subscription? = null + val cancelableSubscriber = + object : Subscriber { + override fun onSubscribe(s: Subscription?) { + subscription = s + s?.request(1) + } + + override fun onNext(streamResponse: StreamResponse) { + onNextList.add(streamResponse) + } + + override fun onError(t: Throwable?) { + throwable = t + } + + override fun onComplete() { + isComplete = true + } + } + + publisher.subscribe(cancelableSubscriber) + Thread.sleep(500) + subscription?.cancel() + Thread.sleep(6000) + + val messages = onNextList.filterIsInstance() + val onNextStringList = messages.map { it.data.data.toString() } + assertThat(onNextStringList) + .containsExactly( + "{chunk=hello}", + ) + assertThat(throwable).isNotNull() + assertThat(requireNotNull(throwable).message).isEqualTo("Stream was canceled") + assertThat(isComplete).isFalse() + } +} diff --git a/firebase-functions/src/main/java/com/google/firebase/functions/FirebaseFunctions.kt b/firebase-functions/src/main/java/com/google/firebase/functions/FirebaseFunctions.kt index 824670c4346..48e5b9d7904 100644 --- a/firebase-functions/src/main/java/com/google/firebase/functions/FirebaseFunctions.kt +++ b/firebase-functions/src/main/java/com/google/firebase/functions/FirebaseFunctions.kt @@ -45,6 +45,7 @@ import okhttp3.RequestBody import okhttp3.Response import org.json.JSONException import org.json.JSONObject +import org.reactivestreams.Publisher /** FirebaseFunctions lets you call Cloud Functions for Firebase. */ public class FirebaseFunctions @@ -311,6 +312,31 @@ internal constructor( return tcs.task } + internal fun stream( + name: String, + data: Any?, + options: HttpsCallOptions + ): Publisher { + val url = getURL(name) + Preconditions.checkNotNull(url, "url cannot be null") + val task = + providerInstalled.task.continueWithTask(executor) { + contextProvider.getContext(options.limitedUseAppCheckTokens) + } + + return PublisherStream(url, data, options, client, serializer, task, executor) + } + + internal fun stream(url: URL, data: Any?, options: HttpsCallOptions): Publisher { + Preconditions.checkNotNull(url, "url cannot be null") + val task = + providerInstalled.task.continueWithTask(executor) { + contextProvider.getContext(options.limitedUseAppCheckTokens) + } + + return PublisherStream(url, data, options, client, this.serializer, task, executor) + } + public companion object { /** A task that will be resolved once ProviderInstaller has installed what it needs to. */ private val providerInstalled = TaskCompletionSource() diff --git a/firebase-functions/src/main/java/com/google/firebase/functions/HttpsCallableReference.kt b/firebase-functions/src/main/java/com/google/firebase/functions/HttpsCallableReference.kt index 88db9db4ee4..3c09aa6aa50 100644 --- a/firebase-functions/src/main/java/com/google/firebase/functions/HttpsCallableReference.kt +++ b/firebase-functions/src/main/java/com/google/firebase/functions/HttpsCallableReference.kt @@ -17,6 +17,7 @@ import androidx.annotation.VisibleForTesting import com.google.android.gms.tasks.Task import java.net.URL import java.util.concurrent.TimeUnit +import org.reactivestreams.Publisher /** A reference to a particular Callable HTTPS trigger in Cloud Functions. */ public class HttpsCallableReference { @@ -125,6 +126,74 @@ public class HttpsCallableReference { } } + /** + * Streams data to the specified HTTPS endpoint asynchronously. + * + * The data passed into the trigger can be any of the following types: + * + * * Any primitive type, including null, int, long, float, and boolean. + * * [String] + * * [List<?>][java.util.List], where the contained objects are also one of these + * types. + * * [Map<String, ?>>][java.util.Map], where the values are also one of these + * types. + * * [org.json.JSONArray] + * * [org.json.JSONObject] + * * [org.json.JSONObject.NULL] + * + * If the returned task fails, the exception will be one of the following types: + * + * * [java.io.IOException] + * - if the HTTPS request failed to connect. + * * [FirebaseFunctionsException] + * - if the request connected, but the function returned an error. + * + * The request to the Cloud Functions backend made by this method automatically includes a + * Firebase Instance ID token to identify the app instance. If a user is logged in with Firebase + * Auth, an auth token for the user will also be automatically included. + * + * Firebase Instance ID sends data to the Firebase backend periodically to collect information + * regarding the app instance. To stop this, see + * [com.google.firebase.iid.FirebaseInstanceId.deleteInstanceId]. It will resume with a new + * Instance ID the next time you call this method. + * + * @param data Parameters to pass to the endpoint. + * @return [Publisher] that will be completed when the streaming operation has finished. + * @see org.json.JSONArray + * @see org.json.JSONObject + * @see java.io.IOException + * @see FirebaseFunctionsException + */ + public fun stream(data: Any?): Publisher { + return if (name != null) { + functionsClient.stream(name, data, options) + } else { + functionsClient.stream(requireNotNull(url), data, options) + } + } + + /** + * Streams data to the specified HTTPS endpoint asynchronously without arguments. + * + * The request to the Cloud Functions backend made by this method automatically includes a + * Firebase Instance ID token to identify the app instance. If a user is logged in with Firebase + * Auth, an auth token for the user will also be automatically included. + * + * Firebase Instance ID sends data to the Firebase backend periodically to collect information + * regarding the app instance. To stop this, see + * [com.google.firebase.iid.FirebaseInstanceId.deleteInstanceId]. It will resume with a new + * Instance ID the next time you call this method. + * + * @return [Publisher] that will be completed when the streaming operation has finished. + */ + public fun stream(): Publisher { + return if (name != null) { + functionsClient.stream(name, null, options) + } else { + functionsClient.stream(requireNotNull(url), null, options) + } + } + /** * Changes the timeout for calls from this instance of Functions. The default is 60 seconds. * diff --git a/firebase-functions/src/main/java/com/google/firebase/functions/PublisherStream.kt b/firebase-functions/src/main/java/com/google/firebase/functions/PublisherStream.kt new file mode 100644 index 00000000000..3fa68c3f535 --- /dev/null +++ b/firebase-functions/src/main/java/com/google/firebase/functions/PublisherStream.kt @@ -0,0 +1,250 @@ +package com.google.firebase.functions + +import com.google.android.gms.tasks.Task +import java.io.BufferedReader +import java.io.IOException +import java.io.InputStream +import java.io.InputStreamReader +import java.io.InterruptedIOException +import java.net.URL +import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.Executor +import okhttp3.Call +import okhttp3.Callback +import okhttp3.MediaType +import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.RequestBody +import okhttp3.Response +import org.json.JSONObject +import org.reactivestreams.Publisher +import org.reactivestreams.Subscriber +import org.reactivestreams.Subscription + +internal class PublisherStream( + private val url: URL, + private val data: Any?, + private val options: HttpsCallOptions, + private val client: OkHttpClient, + private val serializer: Serializer, + private val contextTask: Task, + private val executor: Executor +) : Publisher { + + private val subscribers = ConcurrentLinkedQueue>() + private var activeCall: Call? = null + + override fun subscribe(subscriber: Subscriber) { + subscribers.add(subscriber) + subscriber.onSubscribe( + object : Subscription { + override fun request(n: Long) { + startStreaming() + } + + override fun cancel() { + cancelStream() + subscribers.remove(subscriber) + } + } + ) + } + + private fun startStreaming() { + contextTask.addOnCompleteListener(executor) { contextTask -> + if (!contextTask.isSuccessful) { + notifyError( + FirebaseFunctionsException( + "Error retrieving context", + FirebaseFunctionsException.Code.INTERNAL, + null, + contextTask.exception + ) + ) + return@addOnCompleteListener + } + + val context = contextTask.result + val callClient = options.apply(client) + val requestBody = + RequestBody.create( + MediaType.parse("application/json"), + JSONObject(mapOf("data" to serializer.encode(data))).toString() + ) + val requestBuilder = + Request.Builder().url(url).post(requestBody).header("Accept", "text/event-stream") + applyCommonConfiguration(requestBuilder, context) + val request = requestBuilder.build() + val call = callClient.newCall(request) + activeCall = call + + call.enqueue( + object : Callback { + override fun onFailure(call: Call, e: IOException) { + val code: FirebaseFunctionsException.Code = + if (e is InterruptedIOException) { + FirebaseFunctionsException.Code.DEADLINE_EXCEEDED + } else { + FirebaseFunctionsException.Code.INTERNAL + } + notifyError(FirebaseFunctionsException(code.name, code, null, e)) + } + + override fun onResponse(call: Call, response: Response) { + validateResponse(response) + val bodyStream = response.body()?.byteStream() + if (bodyStream != null) { + processSSEStream(bodyStream) + } else { + notifyError( + FirebaseFunctionsException( + "Response body is null", + FirebaseFunctionsException.Code.INTERNAL, + null + ) + ) + } + } + } + ) + } + } + + private fun cancelStream() { + activeCall?.cancel() + notifyError( + FirebaseFunctionsException( + "Stream was canceled", + FirebaseFunctionsException.Code.CANCELLED, + null + ) + ) + } + + private fun applyCommonConfiguration( + requestBuilder: Request.Builder, + context: HttpsCallableContext? + ) { + context?.authToken?.let { requestBuilder.header("Authorization", "Bearer $it") } + context?.instanceIdToken?.let { requestBuilder.header("Firebase-Instance-ID-Token", it) } + context?.appCheckToken?.let { requestBuilder.header("X-Firebase-AppCheck", it) } + } + + private fun processSSEStream(inputStream: InputStream) { + BufferedReader(InputStreamReader(inputStream)).use { reader -> + try { + reader.lineSequence().forEach { line -> + val dataChunk = + when { + line.startsWith("data:") -> line.removePrefix("data:") + line.startsWith("result:") -> line.removePrefix("result:") + else -> return@forEach + } + try { + val json = JSONObject(dataChunk) + when { + json.has("message") -> + serializer.decode(json.opt("message"))?.let { + notifyData(StreamResponse.Message(data = HttpsCallableResult(it))) + } + json.has("error") -> { + serializer.decode(json.opt("error"))?.let { + notifyError( + FirebaseFunctionsException( + it.toString(), + FirebaseFunctionsException.Code.INTERNAL, + it + ) + ) + } + } + json.has("result") -> { + serializer.decode(json.opt("result"))?.let { + notifyData(StreamResponse.Result(data = HttpsCallableResult(it))) + notifyComplete() + } + return + } + } + } catch (e: Throwable) { + notifyError( + FirebaseFunctionsException( + "Invalid JSON: $dataChunk", + FirebaseFunctionsException.Code.INTERNAL, + e + ) + ) + } + } + notifyError( + FirebaseFunctionsException( + "Stream ended unexpectedly without completion", + FirebaseFunctionsException.Code.INTERNAL, + null + ) + ) + } catch (e: Exception) { + notifyError( + FirebaseFunctionsException( + e.message ?: "Error reading stream", + FirebaseFunctionsException.Code.INTERNAL, + e + ) + ) + } + } + } + + private fun notifyData(data: StreamResponse?) { + for (subscriber in subscribers) { + subscriber.onNext(data) + } + } + + private fun notifyError(e: FirebaseFunctionsException) { + for (subscriber in subscribers) { + subscriber.onError(e) + } + subscribers.clear() + } + + private fun notifyComplete() { + for (subscriber in subscribers) { + subscriber.onComplete() + } + subscribers.clear() + } + + private fun validateResponse(response: Response) { + if (response.isSuccessful) return + + val htmlContentType = "text/html; charset=utf-8" + val trimMargin: String + if (response.code() == 404 && response.header("Content-Type") == htmlContentType) { + trimMargin = """URL not found. Raw response: ${response.body()?.string()}""".trimMargin() + throw FirebaseFunctionsException( + trimMargin, + FirebaseFunctionsException.Code.fromHttpStatus(response.code()), + null + ) + } + + val text = response.body()?.string() ?: "" + val error: Any? + try { + val json = JSONObject(text) + error = serializer.decode(json.opt("error")) + } catch (e: Throwable) { + throw FirebaseFunctionsException( + "${e.message} Unexpected Response:\n$text ", + FirebaseFunctionsException.Code.INTERNAL, + e + ) + } + throw FirebaseFunctionsException( + error.toString(), + FirebaseFunctionsException.Code.INTERNAL, + error + ) + } +} diff --git a/firebase-functions/src/main/java/com/google/firebase/functions/StreamResponse.kt b/firebase-functions/src/main/java/com/google/firebase/functions/StreamResponse.kt new file mode 100644 index 00000000000..475fcb748cf --- /dev/null +++ b/firebase-functions/src/main/java/com/google/firebase/functions/StreamResponse.kt @@ -0,0 +1,36 @@ +package com.google.firebase.functions + +/** + * Represents a response from a Server-Sent Event (SSE) stream. + * + * The SSE stream consists of two types of responses: + * - [Message]: Represents an intermediate event pushed from the server. + * - [Result]: Represents the final response that signifies the stream has ended. + */ +public abstract class StreamResponse private constructor(public val data: HttpsCallableResult) { + + /** + * An event message received during the stream. + * + * Messages are intermediate data chunks sent by the server before the final result. + * + * Example SSE format: + * ```json + * data: { "message": { "chunk": "foo" } } + * ``` + */ + public class Message(data: HttpsCallableResult) : StreamResponse(data) + + /** + * The final response that terminates the stream. + * + * This result is sent as the last message in the stream and indicates that no further messages + * will be received. + * + * Example SSE format: + * ```json + * data: { "result": { "text": "foo bar" } } + * ``` + */ + public class Result(data: HttpsCallableResult) : StreamResponse(data) +}