From ffd13e90c04fa9263d2147e081f72c2c259081c1 Mon Sep 17 00:00:00 2001 From: "Taro L. Saito" Date: Thu, 6 Feb 2025 21:17:28 -0800 Subject: [PATCH 1/6] support returning Rx[SeverSentEvent] --- .../wvlet/airframe/http/grpc/GrpcClient.scala | 12 ++-- .../airframe/http/grpc/GrpcClientCalls.scala | 6 +- .../grpc/internal/GrpcRequestHandler.scala | 16 ++--- .../http/netty/NettyRequestHandler.scala | 34 +++++++-- .../http/netty/NettyResponseHandler.scala | 7 +- .../wvlet/airframe/http/netty/SSETest.scala | 71 ++++++++++++++++++- .../router/HttpEndpointExecutionContext.scala | 5 +- .../wvlet/airframe/http/HttpMessage.scala | 3 +- .../wvlet/airframe/http/ServerSentEvent.scala | 11 ++- .../wvlet/airframe/rx/RxBlockingQueue.scala | 2 +- .../scala/wvlet/airframe/rx/RxQueue.scala | 2 +- .../scala/wvlet/airframe/rx/RxRunner.scala | 2 +- .../scala/wvlet/airframe/rx/RxSource.scala | 5 +- 13 files changed, 143 insertions(+), 33 deletions(-) diff --git a/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcClient.scala b/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcClient.scala index 828c714f13..7120e3b82c 100644 --- a/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcClient.scala +++ b/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcClient.scala @@ -199,26 +199,26 @@ object GrpcClient extends LogSupport { new BlockingRxObserver[A] { val toRx: RxBlockingQueue[A] = new RxBlockingQueue[A] override def onNext(v: A): Unit = { - toRx.add(OnNext(v)) + toRx.addEvent(OnNext(v)) } override def onError(t: Throwable): Unit = { - toRx.add(OnError(t)) + toRx.addEvent(OnError(t)) } override def onCompleted(): Unit = { - toRx.add(OnCompletion) + toRx.addEvent(OnCompletion) } } private class RxObserver[A] extends StreamObserver[A] { val toRx: RxBlockingQueue[A] = new RxBlockingQueue[A] override def onNext(v: A): Unit = { - toRx.add(OnNext(v)) + toRx.addEvent(OnNext(v)) } override def onError(t: Throwable): Unit = { - toRx.add(OnError(t)) + toRx.addEvent(OnError(t)) } override def onCompleted(): Unit = { - toRx.add(OnCompletion) + toRx.addEvent(OnCompletion) } } diff --git a/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcClientCalls.scala b/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcClientCalls.scala index 4a3cc2679b..2b69eb098d 100644 --- a/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcClientCalls.scala +++ b/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcClientCalls.scala @@ -34,13 +34,13 @@ object GrpcClientCalls extends LogSupport { new BlockingStreamObserver[A] { val toRx: RxBlockingQueue[A] = new RxBlockingQueue[A] override def onNext(v: Any): Unit = { - toRx.add(OnNext(v)) + toRx.addEvent(OnNext(v)) } override def onError(t: Throwable): Unit = { - toRx.add(OnError(t)) + toRx.addEvent(OnError(t)) } override def onCompleted(): Unit = { - toRx.add(OnCompletion) + toRx.addEvent(OnCompletion) } } diff --git a/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/internal/GrpcRequestHandler.scala b/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/internal/GrpcRequestHandler.scala index 9dbb148ad3..592a4e0199 100644 --- a/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/internal/GrpcRequestHandler.scala +++ b/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/internal/GrpcRequestHandler.scala @@ -195,21 +195,21 @@ class GrpcRequestHandler( invokeServerMethod Try(readStreamingInput(grpcContext, codec, value)) match { case Success(v) => - rx.add(OnNext(v)) + rx.addEvent(OnNext(v)) case Failure(e) => reportError(e) - rx.add(OnError(e)) + rx.addEvent(OnError(e)) } } override def onError(t: Throwable): Unit = { reportError(t) requestLogger.logError(t, grpcContext, rpcContext) - rx.add(OnError(t)) + rx.addEvent(OnError(t)) responseObserver.onError(GrpcException.wrap(t)) } override def onCompleted(): Unit = { invokeServerMethod - rx.add(OnCompletion) + rx.addEvent(OnCompletion) promise.future.onComplete { case Success(value) => requestLogger.logRPC(grpcContext, rpcContext) @@ -261,19 +261,19 @@ class GrpcRequestHandler( Try(readStreamingInput(grpcContext, codec, value)) match { case Success(v) => // Add a log for each client-side stream message - rx.add(OnNext(v)) + rx.addEvent(OnNext(v)) case Failure(e) => - rx.add(OnError(e)) + rx.addEvent(OnError(e)) } } override def onError(t: Throwable): Unit = { reportError(t) - rx.add(OnError(t)) + rx.addEvent(OnError(t)) responseObserver.onError(t) } override def onCompleted(): Unit = { invokeServerMethod - rx.add(OnCompletion) + rx.addEvent(OnCompletion) promise.future.onComplete { case Success(v) => v match { diff --git a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala index 5c67c2f677..47a253e4db 100644 --- a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala +++ b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala @@ -26,7 +26,8 @@ import wvlet.airframe.http.{ HttpStatus, RPCException, RPCStatus, - ServerAddress + ServerAddress, + ServerSentEvent } import wvlet.airframe.rx.{OnCompletion, OnError, OnNext, Rx, RxRunner} import wvlet.log.LogSupport @@ -103,8 +104,23 @@ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi RxRunner.run(rxResponse) { case OnNext(v) => - val nettyResponse = toNettyResponse(v.asInstanceOf[Response]) + val resp = v.asInstanceOf[Response] + val nettyResponse = toNettyResponse(resp) writeResponse(msg, ctx, nettyResponse) + + if (resp.isContentTypeEventStream) { + // Read SSE stream + val c = RxRunner.runContinuously(resp.events) { + case OnNext(e: ServerSentEvent) => + warn(e) + val event = e.toContent + val buf = Unpooled.wrappedBuffer(event.getBytes("UTF-8")) + ctx.writeAndFlush(new DefaultHttpContent(buf)) + case _ => + warn(s"completed") + ctx.channel().closeFuture().addListener(ChannelFutureListener.CLOSE) + } + } case OnError(ex) => // This path manages unhandled exceptions val resp = RPCStatus.INTERNAL_ERROR_I0.newException(ex.getMessage, ex).toResponse @@ -122,7 +138,11 @@ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi } private def writeResponse(req: HttpRequest, ctx: ChannelHandlerContext, resp: DefaultHttpResponse): Unit = { - val keepAlive = HttpStatus.ofCode(resp.status().code()).isSuccessful && HttpUtil.isKeepAlive(req) + val isEventStream = resp.headers().get(HttpHeader.ContentType).contains("text/event-stream") + + val keepAlive: Boolean = + HttpStatus.ofCode(resp.status().code()).isSuccessful && (HttpUtil.isKeepAlive(req) || isEventStream) + if (keepAlive) { if (!req.protocolVersion().isKeepAliveDefault) { resp.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE) @@ -139,8 +159,12 @@ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi } object NettyRequestHandler { - def toNettyResponse(response: Response): DefaultFullHttpResponse = { - val r = if (response.message.isEmpty) { + def toNettyResponse(response: Response): DefaultHttpResponse = { + val r = if (response.isContentTypeEventStream) { + val res = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.valueOf(response.statusCode)) + res.headers().set(HttpHeaderNames.CACHE_CONTROL, HttpHeaderValues.NO_CACHE) + res + } else if (response.message.isEmpty) { val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.valueOf(response.statusCode)) // Need to set the content length properly to return the response in Netty HttpUtil.setContentLength(res, 0) diff --git a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyResponseHandler.scala b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyResponseHandler.scala index 5de26c2357..3e96e6e4fd 100644 --- a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyResponseHandler.scala +++ b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyResponseHandler.scala @@ -15,10 +15,11 @@ package wvlet.airframe.http.netty import wvlet.airframe.codec.{JSONCodec, MessageCodec, MessageCodecFactory} import wvlet.airframe.http.HttpMessage.{Request, Response} -import wvlet.airframe.http.{Http, HttpStatus} +import wvlet.airframe.http.{Http, HttpStatus, ServerSentEvent} import wvlet.airframe.http.router.{ResponseHandler, Route} import wvlet.airframe.msgpack.spi.MsgPack import wvlet.airframe.surface.{Primitive, Surface} +import wvlet.airframe.rx.Rx import wvlet.log.LogSupport class NettyResponseHandler extends ResponseHandler[Request, Response] with LogSupport { @@ -36,6 +37,10 @@ class NettyResponseHandler extends ResponseHandler[Request, Response] with LogSu case s: String if !request.acceptsMsgPack => newResponse(route, request, responseSurface) .withContent(s) + case r: Rx[_] if responseSurface.typeArgs(0).rawType == classOf[ServerSentEvent] => + val resp = newResponse(route, request, responseSurface).withContentType("text/event-stream") + resp.events = r.asInstanceOf[Rx[ServerSentEvent]] + resp case _ => val rs = codecFactory.of(responseSurface) val msgpack: Array[Byte] = rs match { diff --git a/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/SSETest.scala b/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/SSETest.scala index 7f294c78cb..c9b536047a 100644 --- a/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/SSETest.scala +++ b/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/SSETest.scala @@ -13,10 +13,10 @@ */ package wvlet.airframe.http.netty -import wvlet.airframe.http.{Endpoint, Http, RxRouter, ServerSentEvent, ServerSentEventHandler} +import wvlet.airframe.http.{Endpoint, Http, HttpMethod, RxRouter, ServerSentEvent, ServerSentEventHandler} import wvlet.airframe.http.HttpMessage.Response import wvlet.airframe.http.client.AsyncClient -import wvlet.airframe.rx.Rx +import wvlet.airframe.rx.{Rx, RxBlockingQueue} import wvlet.airspec.AirSpec class SSEApi { @@ -46,6 +46,28 @@ class SSEApi { |data: need to retry |""".stripMargin) } + + @Endpoint(method = HttpMethod.POST, path = "/v1/sse-stream") + def sseStream(): Rx[ServerSentEvent] = { + val queue = new RxBlockingQueue[ServerSentEvent]() + new Thread(new Runnable { + override def run(): Unit = { + queue.add(ServerSentEvent(data = "hello stream")) + // Thread.sleep(100) + queue.add(ServerSentEvent(data = "another stream message\nwith two lines")) + // Thread.sleep(50) + queue.add(ServerSentEvent(event = Some("custom-event"), data = "hello custom event")) + Thread.sleep(20) + queue.add(ServerSentEvent(id = Some("123"), data = "hello again")) + Thread.sleep(10) + queue.add(ServerSentEvent(id = Some("1234"), event = Some("custom-event"), data = "hello again 2")) + Thread.sleep(30) + queue.add(ServerSentEvent(retry = Some(1000), data = "need to retry")) + queue.stop() + } + }).start() + queue + } } class SSETest extends AirSpec { @@ -96,4 +118,49 @@ class SSETest extends AirSpec { events shouldBe expected } } + + test("read sse-stream") { (client: AsyncClient) => + val buf = List.newBuilder[ServerSentEvent] + val completed = Rx.variable(false) + val rx = client.send( + Http + .POST("/v1/sse-stream") + .withEventHandler(new ServerSentEventHandler { + override def onError(e: Throwable): Unit = { + completed := true + } + override def onCompletion(): Unit = { + completed := true + } + override def onEvent(e: ServerSentEvent): Unit = { + info(e) + buf += e + } + }) + ) + + rx.join(completed) + .filter(_._2 == true) + .map(_._1) + .map { resp => + resp.statusCode shouldBe 200 + debug(resp) + + val events = buf.result() + val expected = List( + ServerSentEvent(data = "hello stream"), + ServerSentEvent(data = "another stream message\nwith two lines"), + ServerSentEvent(event = Some("custom-event"), data = "hello custom event"), + ServerSentEvent(id = Some("123"), data = "hello again"), + ServerSentEvent(id = Some("1234"), event = Some("custom-event"), data = "hello again 2"), + ServerSentEvent(retry = Some(1000), data = "need to retry") + ) + + trace(events.mkString("\n")) + // trace(expected.mkString("\n")) + events shouldBe expected + } + + } + } diff --git a/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/router/HttpEndpointExecutionContext.scala b/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/router/HttpEndpointExecutionContext.scala index da9bbd9b2d..2f48c915cd 100644 --- a/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/router/HttpEndpointExecutionContext.scala +++ b/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/router/HttpEndpointExecutionContext.scala @@ -16,7 +16,7 @@ package wvlet.airframe.http.router import java.lang.reflect.InvocationTargetException import wvlet.airframe.codec.{MessageCodec, MessageCodecFactory} import wvlet.airframe.control.ThreadUtil -import wvlet.airframe.http.{HttpBackend, HttpContext, HttpRequestAdapter} +import wvlet.airframe.http.{Http, HttpBackend, HttpContext, HttpMethod, HttpRequestAdapter, HttpStatus, ServerSentEvent} import wvlet.log.LogSupport import java.util.concurrent.Executors @@ -61,6 +61,9 @@ class HttpEndpointExecutionContext[Req: HttpRequestAdapter, Resp, F[_]]( case valueCls if backend.isRawResponseType(valueCls) => // Use Backend Future (e.g., Finagle Future or Rx) result.asInstanceOf[F[Resp]] + case valueCls if valueCls == classOf[ServerSentEvent] => + // Rx[ServerSentEvent] + backend.toFuture(responseHandler.toHttpResponse(route, request, route.returnTypeSurface, result)) case other => // If X is other type, convert X into an HttpResponse backend.mapF( diff --git a/airframe-http/src/main/scala/wvlet/airframe/http/HttpMessage.scala b/airframe-http/src/main/scala/wvlet/airframe/http/HttpMessage.scala index f4ece3d9ab..40f2af3829 100644 --- a/airframe-http/src/main/scala/wvlet/airframe/http/HttpMessage.scala +++ b/airframe-http/src/main/scala/wvlet/airframe/http/HttpMessage.scala @@ -285,7 +285,8 @@ object HttpMessage { case class Response( status: HttpStatus = HttpStatus.Ok_200, header: HttpMultiMap = HttpMultiMap.empty, - message: Message = EmptyMessage + message: Message = EmptyMessage, + private[http] var events: Rx[ServerSentEvent] = Rx.empty ) extends HttpMessage[Response] { override def toString: String = s"Response(${status},${header})" diff --git a/airframe-http/src/main/scala/wvlet/airframe/http/ServerSentEvent.scala b/airframe-http/src/main/scala/wvlet/airframe/http/ServerSentEvent.scala index fa3060689c..867deeb185 100644 --- a/airframe-http/src/main/scala/wvlet/airframe/http/ServerSentEvent.scala +++ b/airframe-http/src/main/scala/wvlet/airframe/http/ServerSentEvent.scala @@ -21,7 +21,16 @@ case class ServerSentEvent( retry: Option[Long] = None, // event data string. If multiple data entries are reported, concatenated with newline data: String -) +) { + def toContent: String = { + val b = Seq.newBuilder[String] + id.foreach(x => b += s"id: $x") + event.foreach(x => b += s"event: $x") + retry.foreach(x => b += s"retry: $x") + data.split("\n").foreach(x => b += s"data: $x") + b.result().mkString("\n") + } +} object ServerSentEventHandler { def empty: ServerSentEventHandler = new ServerSentEventHandler { diff --git a/airframe-rx/.jvm/src/main/scala/wvlet/airframe/rx/RxBlockingQueue.scala b/airframe-rx/.jvm/src/main/scala/wvlet/airframe/rx/RxBlockingQueue.scala index 009375785c..368c809cb1 100644 --- a/airframe-rx/.jvm/src/main/scala/wvlet/airframe/rx/RxBlockingQueue.scala +++ b/airframe-rx/.jvm/src/main/scala/wvlet/airframe/rx/RxBlockingQueue.scala @@ -22,7 +22,7 @@ class RxBlockingQueue[A] extends RxSource[A] { private val blockingQueue = new LinkedBlockingQueue[RxEvent]() - override def add(event: RxEvent): Unit = { + override def addEvent(event: RxEvent): Unit = { blockingQueue.add(event) } override def next: Rx[RxEvent] = { diff --git a/airframe-rx/src/main/scala/wvlet/airframe/rx/RxQueue.scala b/airframe-rx/src/main/scala/wvlet/airframe/rx/RxQueue.scala index bcfb192d9e..655f12b290 100644 --- a/airframe-rx/src/main/scala/wvlet/airframe/rx/RxQueue.scala +++ b/airframe-rx/src/main/scala/wvlet/airframe/rx/RxQueue.scala @@ -28,7 +28,7 @@ class RxQueue[A]() extends RxSource[A] with LogSupport { private var queue = scala.collection.immutable.Queue.empty[RxEvent] private var waiting: Option[Promise[RxEvent]] = None - override def add(event: RxEvent): Unit = { + override def addEvent(event: RxEvent): Unit = { synchronized { queue = queue.enqueue(event) waiting match { diff --git a/airframe-rx/src/main/scala/wvlet/airframe/rx/RxRunner.scala b/airframe-rx/src/main/scala/wvlet/airframe/rx/RxRunner.scala index 62e1b41b00..75431a77a4 100644 --- a/airframe-rx/src/main/scala/wvlet/airframe/rx/RxRunner.scala +++ b/airframe-rx/src/main/scala/wvlet/airframe/rx/RxRunner.scala @@ -576,7 +576,7 @@ class RxRunner( c1, Cancelable { () => toContinue = false - source.add(OnError(new InterruptedException("cancelled"))) + source.addEvent(OnError(new InterruptedException("cancelled"))) } ) } diff --git a/airframe-rx/src/main/scala/wvlet/airframe/rx/RxSource.scala b/airframe-rx/src/main/scala/wvlet/airframe/rx/RxSource.scala index 1f35c8f9ab..aeacace6f7 100644 --- a/airframe-rx/src/main/scala/wvlet/airframe/rx/RxSource.scala +++ b/airframe-rx/src/main/scala/wvlet/airframe/rx/RxSource.scala @@ -17,7 +17,8 @@ package wvlet.airframe.rx * Rx implementation where the data is provided from an external process. */ trait RxSource[A] extends Rx[A] { - def add(ev: RxEvent): Unit + def add(e: A): Unit = addEvent(OnNext(e)) + def addEvent(ev: RxEvent): Unit def next: Rx[RxEvent] - def stop(): Unit = add(OnCompletion) + def stop(): Unit = addEvent(OnCompletion) } From 825562bf1530fa5b5c777c3cff5fb6101bbc898d Mon Sep 17 00:00:00 2001 From: "Taro L. Saito" Date: Thu, 6 Feb 2025 22:11:38 -0800 Subject: [PATCH 2/6] fix tests --- .../http/netty/NettyRequestHandler.scala | 14 +-- .../wvlet/airframe/http/netty/SSETest.scala | 90 +++++++++---------- .../http/client/JavaHttpClientChannel.scala | 3 +- .../wvlet/airframe/http/ServerSentEvent.scala | 2 +- 4 files changed, 52 insertions(+), 57 deletions(-) diff --git a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala index 47a253e4db..a1c228a11e 100644 --- a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala +++ b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala @@ -108,17 +108,18 @@ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi val nettyResponse = toNettyResponse(resp) writeResponse(msg, ctx, nettyResponse) - if (resp.isContentTypeEventStream) { + if (resp.isContentTypeEventStream && resp.message.isEmpty) { // Read SSE stream val c = RxRunner.runContinuously(resp.events) { case OnNext(e: ServerSentEvent) => - warn(e) + info(e) val event = e.toContent - val buf = Unpooled.wrappedBuffer(event.getBytes("UTF-8")) + val buf = Unpooled.copiedBuffer(event.getBytes("UTF-8")) ctx.writeAndFlush(new DefaultHttpContent(buf)) case _ => warn(s"completed") - ctx.channel().closeFuture().addListener(ChannelFutureListener.CLOSE) + val f = ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT) + f.addListener(ChannelFutureListener.CLOSE) } } case OnError(ex) => @@ -160,9 +161,12 @@ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi object NettyRequestHandler { def toNettyResponse(response: Response): DefaultHttpResponse = { - val r = if (response.isContentTypeEventStream) { + val r = if (response.isContentTypeEventStream && response.message.isEmpty) { val res = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.valueOf(response.statusCode)) + res.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/event-stream") + res.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED) res.headers().set(HttpHeaderNames.CACHE_CONTROL, HttpHeaderValues.NO_CACHE) + res.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE) res } else if (response.message.isEmpty) { val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.valueOf(response.statusCode)) diff --git a/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/SSETest.scala b/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/SSETest.scala index c9b536047a..8fcd09cd40 100644 --- a/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/SSETest.scala +++ b/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/SSETest.scala @@ -79,88 +79,78 @@ class SSETest extends AirSpec { ) } - test("read sse events") { (client: AsyncClient) => - val buf = List.newBuilder[ServerSentEvent] - val completed = Rx.variable(false) + test("read sse-events") { (client: AsyncClient) => + val queue = new RxBlockingQueue[ServerSentEvent]() val rx = client.send( Http .GET("/v1/sse") .withEventHandler(new ServerSentEventHandler { override def onError(e: Throwable): Unit = { - completed := true + queue.stop() } override def onCompletion(): Unit = { - completed := true + queue.stop() } override def onEvent(e: ServerSentEvent): Unit = { - buf += e + queue.add(e) } }) ) - rx.join(completed) - .filter(_._2 == true) - .map(_._1) - .map { resp => - resp.statusCode shouldBe 200 - val events = buf.result() - val expected = List( - ServerSentEvent(data = "hello stream"), - ServerSentEvent(data = "another stream message\nwith two lines"), - ServerSentEvent(event = Some("custom-event"), data = "hello custom event"), - ServerSentEvent(id = Some("123"), data = "hello again"), - ServerSentEvent(id = Some("1234"), event = Some("custom-event"), data = "hello again 2"), - ServerSentEvent(retry = Some(1000), data = "need to retry") - ) + rx.map { resp => + resp.statusCode shouldBe 200 - trace(events.mkString("\n")) - trace(expected.mkString("\n")) - events shouldBe expected - } + val events = queue.toSeq.toList + val expected = List( + ServerSentEvent(data = "hello stream"), + ServerSentEvent(data = "another stream message\nwith two lines"), + ServerSentEvent(event = Some("custom-event"), data = "hello custom event"), + ServerSentEvent(id = Some("123"), data = "hello again"), + ServerSentEvent(id = Some("1234"), event = Some("custom-event"), data = "hello again 2"), + ServerSentEvent(retry = Some(1000), data = "need to retry") + ) + + trace(events.mkString("\n")) + events shouldBe expected + } } test("read sse-stream") { (client: AsyncClient) => - val buf = List.newBuilder[ServerSentEvent] - val completed = Rx.variable(false) + val queue = new RxBlockingQueue[ServerSentEvent]() val rx = client.send( Http .POST("/v1/sse-stream") .withEventHandler(new ServerSentEventHandler { override def onError(e: Throwable): Unit = { - completed := true + queue.stop() } override def onCompletion(): Unit = { - completed := true + queue.stop() } override def onEvent(e: ServerSentEvent): Unit = { - info(e) - buf += e + debug(e) + queue.add(e) } }) ) - rx.join(completed) - .filter(_._2 == true) - .map(_._1) - .map { resp => - resp.statusCode shouldBe 200 - debug(resp) - - val events = buf.result() - val expected = List( - ServerSentEvent(data = "hello stream"), - ServerSentEvent(data = "another stream message\nwith two lines"), - ServerSentEvent(event = Some("custom-event"), data = "hello custom event"), - ServerSentEvent(id = Some("123"), data = "hello again"), - ServerSentEvent(id = Some("1234"), event = Some("custom-event"), data = "hello again 2"), - ServerSentEvent(retry = Some(1000), data = "need to retry") - ) + rx.map { resp => + resp.statusCode shouldBe 200 - trace(events.mkString("\n")) - // trace(expected.mkString("\n")) - events shouldBe expected - } + val events = queue.toSeq.toList + val expected = List( + ServerSentEvent(data = "hello stream"), + ServerSentEvent(data = "another stream message\nwith two lines"), + ServerSentEvent(event = Some("custom-event"), data = "hello custom event"), + ServerSentEvent(id = Some("123"), data = "hello again"), + ServerSentEvent(id = Some("1234"), event = Some("custom-event"), data = "hello again 2"), + ServerSentEvent(retry = Some(1000), data = "need to retry") + ) + trace(events.mkString("\n")) + // trace(expected.mkString("\n")) + events shouldBe expected + } } } diff --git a/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/client/JavaHttpClientChannel.scala b/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/client/JavaHttpClientChannel.scala index 204b563ae1..e81871ad78 100644 --- a/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/client/JavaHttpClientChannel.scala +++ b/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/client/JavaHttpClientChannel.scala @@ -182,7 +182,8 @@ class JavaHttpClientChannel(val destination: ServerAddress, private[http] val co executor.execute(new Runnable { override def run(): Unit = { try { - withResource(new BufferedReader(new InputStreamReader(httpResponse.body()))) { reader => + val body = httpResponse.body() + withResource(new BufferedReader(new InputStreamReader(body))) { reader => var id: Option[String] = None var eventType: Option[String] = None var retry: Option[Long] = None diff --git a/airframe-http/src/main/scala/wvlet/airframe/http/ServerSentEvent.scala b/airframe-http/src/main/scala/wvlet/airframe/http/ServerSentEvent.scala index 867deeb185..b12f957745 100644 --- a/airframe-http/src/main/scala/wvlet/airframe/http/ServerSentEvent.scala +++ b/airframe-http/src/main/scala/wvlet/airframe/http/ServerSentEvent.scala @@ -28,7 +28,7 @@ case class ServerSentEvent( event.foreach(x => b += s"event: $x") retry.foreach(x => b += s"retry: $x") data.split("\n").foreach(x => b += s"data: $x") - b.result().mkString("\n") + s"${b.result().mkString("\n")}\n\n" } } From 432d1503f9b392e2f3900995d389422ca1050fb1 Mon Sep 17 00:00:00 2001 From: "Taro L. Saito" Date: Thu, 6 Feb 2025 22:18:13 -0800 Subject: [PATCH 3/6] Remove comments --- .../scala/wvlet/airframe/http/netty/NettyRequestHandler.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala index a1c228a11e..2fbb9241fc 100644 --- a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala +++ b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala @@ -112,12 +112,10 @@ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi // Read SSE stream val c = RxRunner.runContinuously(resp.events) { case OnNext(e: ServerSentEvent) => - info(e) val event = e.toContent val buf = Unpooled.copiedBuffer(event.getBytes("UTF-8")) ctx.writeAndFlush(new DefaultHttpContent(buf)) case _ => - warn(s"completed") val f = ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT) f.addListener(ChannelFutureListener.CLOSE) } From 4001d296274053d76c89ad95582c8ffbb7ee043c Mon Sep 17 00:00:00 2001 From: "Taro L. Saito" Date: Thu, 6 Feb 2025 22:28:37 -0800 Subject: [PATCH 4/6] Fix errors --- .../wvlet/airframe/http/netty/NettyRequestHandler.scala | 5 ++++- .../airframe/http/router/HttpEndpointExecutionContext.scala | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala index 2fbb9241fc..26936bec8a 100644 --- a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala +++ b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala @@ -137,7 +137,10 @@ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi } private def writeResponse(req: HttpRequest, ctx: ChannelHandlerContext, resp: DefaultHttpResponse): Unit = { - val isEventStream = resp.headers().get(HttpHeader.ContentType).contains("text/event-stream") + val isEventStream = + Option(resp.headers()) + .flatMap(h => Option(h.get(HttpHeader.ContentType))) + .exists(_.contains("text/event-stream")) val keepAlive: Boolean = HttpStatus.ofCode(resp.status().code()).isSuccessful && (HttpUtil.isKeepAlive(req) || isEventStream) diff --git a/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/router/HttpEndpointExecutionContext.scala b/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/router/HttpEndpointExecutionContext.scala index 2f48c915cd..b401c751a1 100644 --- a/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/router/HttpEndpointExecutionContext.scala +++ b/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/router/HttpEndpointExecutionContext.scala @@ -63,7 +63,7 @@ class HttpEndpointExecutionContext[Req: HttpRequestAdapter, Resp, F[_]]( result.asInstanceOf[F[Resp]] case valueCls if valueCls == classOf[ServerSentEvent] => // Rx[ServerSentEvent] - backend.toFuture(responseHandler.toHttpResponse(route, request, route.returnTypeSurface, result)) + backend.toFuture(responseHandler.toHttpResponse(route, request, route.returnTypeSurface, result)) case other => // If X is other type, convert X into an HttpResponse backend.mapF( From b2f0ef7d90355e0e6ddf41fb337f1213e9977548 Mon Sep 17 00:00:00 2001 From: "Taro L. Saito" Date: Thu, 6 Feb 2025 22:35:36 -0800 Subject: [PATCH 5/6] Fix native code --- .../src/main/scala-3/wvlet/airframe/rx/RxBlockingQueue.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airframe-rx/.native/src/main/scala-3/wvlet/airframe/rx/RxBlockingQueue.scala b/airframe-rx/.native/src/main/scala-3/wvlet/airframe/rx/RxBlockingQueue.scala index 86ec3eb887..43aecb7926 100644 --- a/airframe-rx/.native/src/main/scala-3/wvlet/airframe/rx/RxBlockingQueue.scala +++ b/airframe-rx/.native/src/main/scala-3/wvlet/airframe/rx/RxBlockingQueue.scala @@ -22,7 +22,7 @@ class RxBlockingQueue[A] extends RxSource[A]: private val blockingQueue = new LinkedBlockingQueue[RxEvent]() - override def add(event: RxEvent): Unit = + override def addEvent(event: RxEvent): Unit = blockingQueue.add(event) override def next: Rx[RxEvent] = Rx.const(blockingQueue.take()) From 983d0b8f1e8bea897348207dfda1f95af75ccbd9 Mon Sep 17 00:00:00 2001 From: "Taro L. Saito" Date: Thu, 6 Feb 2025 22:52:23 -0800 Subject: [PATCH 6/6] Revert name for backward compatibility --- .../wvlet/airframe/http/grpc/GrpcClient.scala | 12 ++++++------ .../airframe/http/grpc/GrpcClientCalls.scala | 6 +++--- .../http/grpc/internal/GrpcRequestHandler.scala | 16 ++++++++-------- .../wvlet/airframe/http/netty/SSETest.scala | 16 ++++++++-------- .../wvlet/airframe/rx/RxBlockingQueue.scala | 2 +- .../wvlet/airframe/rx/RxBlockingQueue.scala | 2 +- .../main/scala/wvlet/airframe/rx/RxQueue.scala | 2 +- .../main/scala/wvlet/airframe/rx/RxRunner.scala | 2 +- .../main/scala/wvlet/airframe/rx/RxSource.scala | 6 +++--- 9 files changed, 32 insertions(+), 32 deletions(-) diff --git a/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcClient.scala b/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcClient.scala index 7120e3b82c..828c714f13 100644 --- a/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcClient.scala +++ b/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcClient.scala @@ -199,26 +199,26 @@ object GrpcClient extends LogSupport { new BlockingRxObserver[A] { val toRx: RxBlockingQueue[A] = new RxBlockingQueue[A] override def onNext(v: A): Unit = { - toRx.addEvent(OnNext(v)) + toRx.add(OnNext(v)) } override def onError(t: Throwable): Unit = { - toRx.addEvent(OnError(t)) + toRx.add(OnError(t)) } override def onCompleted(): Unit = { - toRx.addEvent(OnCompletion) + toRx.add(OnCompletion) } } private class RxObserver[A] extends StreamObserver[A] { val toRx: RxBlockingQueue[A] = new RxBlockingQueue[A] override def onNext(v: A): Unit = { - toRx.addEvent(OnNext(v)) + toRx.add(OnNext(v)) } override def onError(t: Throwable): Unit = { - toRx.addEvent(OnError(t)) + toRx.add(OnError(t)) } override def onCompleted(): Unit = { - toRx.addEvent(OnCompletion) + toRx.add(OnCompletion) } } diff --git a/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcClientCalls.scala b/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcClientCalls.scala index 2b69eb098d..4a3cc2679b 100644 --- a/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcClientCalls.scala +++ b/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/GrpcClientCalls.scala @@ -34,13 +34,13 @@ object GrpcClientCalls extends LogSupport { new BlockingStreamObserver[A] { val toRx: RxBlockingQueue[A] = new RxBlockingQueue[A] override def onNext(v: Any): Unit = { - toRx.addEvent(OnNext(v)) + toRx.add(OnNext(v)) } override def onError(t: Throwable): Unit = { - toRx.addEvent(OnError(t)) + toRx.add(OnError(t)) } override def onCompleted(): Unit = { - toRx.addEvent(OnCompletion) + toRx.add(OnCompletion) } } diff --git a/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/internal/GrpcRequestHandler.scala b/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/internal/GrpcRequestHandler.scala index 592a4e0199..9dbb148ad3 100644 --- a/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/internal/GrpcRequestHandler.scala +++ b/airframe-http-grpc/src/main/scala/wvlet/airframe/http/grpc/internal/GrpcRequestHandler.scala @@ -195,21 +195,21 @@ class GrpcRequestHandler( invokeServerMethod Try(readStreamingInput(grpcContext, codec, value)) match { case Success(v) => - rx.addEvent(OnNext(v)) + rx.add(OnNext(v)) case Failure(e) => reportError(e) - rx.addEvent(OnError(e)) + rx.add(OnError(e)) } } override def onError(t: Throwable): Unit = { reportError(t) requestLogger.logError(t, grpcContext, rpcContext) - rx.addEvent(OnError(t)) + rx.add(OnError(t)) responseObserver.onError(GrpcException.wrap(t)) } override def onCompleted(): Unit = { invokeServerMethod - rx.addEvent(OnCompletion) + rx.add(OnCompletion) promise.future.onComplete { case Success(value) => requestLogger.logRPC(grpcContext, rpcContext) @@ -261,19 +261,19 @@ class GrpcRequestHandler( Try(readStreamingInput(grpcContext, codec, value)) match { case Success(v) => // Add a log for each client-side stream message - rx.addEvent(OnNext(v)) + rx.add(OnNext(v)) case Failure(e) => - rx.addEvent(OnError(e)) + rx.add(OnError(e)) } } override def onError(t: Throwable): Unit = { reportError(t) - rx.addEvent(OnError(t)) + rx.add(OnError(t)) responseObserver.onError(t) } override def onCompleted(): Unit = { invokeServerMethod - rx.addEvent(OnCompletion) + rx.add(OnCompletion) promise.future.onComplete { case Success(v) => v match { diff --git a/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/SSETest.scala b/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/SSETest.scala index 8fcd09cd40..7b2112c6d3 100644 --- a/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/SSETest.scala +++ b/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/SSETest.scala @@ -52,17 +52,17 @@ class SSEApi { val queue = new RxBlockingQueue[ServerSentEvent]() new Thread(new Runnable { override def run(): Unit = { - queue.add(ServerSentEvent(data = "hello stream")) + queue.put(ServerSentEvent(data = "hello stream")) // Thread.sleep(100) - queue.add(ServerSentEvent(data = "another stream message\nwith two lines")) + queue.put(ServerSentEvent(data = "another stream message\nwith two lines")) // Thread.sleep(50) - queue.add(ServerSentEvent(event = Some("custom-event"), data = "hello custom event")) + queue.put(ServerSentEvent(event = Some("custom-event"), data = "hello custom event")) Thread.sleep(20) - queue.add(ServerSentEvent(id = Some("123"), data = "hello again")) + queue.put(ServerSentEvent(id = Some("123"), data = "hello again")) Thread.sleep(10) - queue.add(ServerSentEvent(id = Some("1234"), event = Some("custom-event"), data = "hello again 2")) + queue.put(ServerSentEvent(id = Some("1234"), event = Some("custom-event"), data = "hello again 2")) Thread.sleep(30) - queue.add(ServerSentEvent(retry = Some(1000), data = "need to retry")) + queue.put(ServerSentEvent(retry = Some(1000), data = "need to retry")) queue.stop() } }).start() @@ -92,7 +92,7 @@ class SSETest extends AirSpec { queue.stop() } override def onEvent(e: ServerSentEvent): Unit = { - queue.add(e) + queue.put(e) } }) ) @@ -129,7 +129,7 @@ class SSETest extends AirSpec { } override def onEvent(e: ServerSentEvent): Unit = { debug(e) - queue.add(e) + queue.put(e) } }) ) diff --git a/airframe-rx/.jvm/src/main/scala/wvlet/airframe/rx/RxBlockingQueue.scala b/airframe-rx/.jvm/src/main/scala/wvlet/airframe/rx/RxBlockingQueue.scala index 368c809cb1..009375785c 100644 --- a/airframe-rx/.jvm/src/main/scala/wvlet/airframe/rx/RxBlockingQueue.scala +++ b/airframe-rx/.jvm/src/main/scala/wvlet/airframe/rx/RxBlockingQueue.scala @@ -22,7 +22,7 @@ class RxBlockingQueue[A] extends RxSource[A] { private val blockingQueue = new LinkedBlockingQueue[RxEvent]() - override def addEvent(event: RxEvent): Unit = { + override def add(event: RxEvent): Unit = { blockingQueue.add(event) } override def next: Rx[RxEvent] = { diff --git a/airframe-rx/.native/src/main/scala-3/wvlet/airframe/rx/RxBlockingQueue.scala b/airframe-rx/.native/src/main/scala-3/wvlet/airframe/rx/RxBlockingQueue.scala index 43aecb7926..86ec3eb887 100644 --- a/airframe-rx/.native/src/main/scala-3/wvlet/airframe/rx/RxBlockingQueue.scala +++ b/airframe-rx/.native/src/main/scala-3/wvlet/airframe/rx/RxBlockingQueue.scala @@ -22,7 +22,7 @@ class RxBlockingQueue[A] extends RxSource[A]: private val blockingQueue = new LinkedBlockingQueue[RxEvent]() - override def addEvent(event: RxEvent): Unit = + override def add(event: RxEvent): Unit = blockingQueue.add(event) override def next: Rx[RxEvent] = Rx.const(blockingQueue.take()) diff --git a/airframe-rx/src/main/scala/wvlet/airframe/rx/RxQueue.scala b/airframe-rx/src/main/scala/wvlet/airframe/rx/RxQueue.scala index 655f12b290..bcfb192d9e 100644 --- a/airframe-rx/src/main/scala/wvlet/airframe/rx/RxQueue.scala +++ b/airframe-rx/src/main/scala/wvlet/airframe/rx/RxQueue.scala @@ -28,7 +28,7 @@ class RxQueue[A]() extends RxSource[A] with LogSupport { private var queue = scala.collection.immutable.Queue.empty[RxEvent] private var waiting: Option[Promise[RxEvent]] = None - override def addEvent(event: RxEvent): Unit = { + override def add(event: RxEvent): Unit = { synchronized { queue = queue.enqueue(event) waiting match { diff --git a/airframe-rx/src/main/scala/wvlet/airframe/rx/RxRunner.scala b/airframe-rx/src/main/scala/wvlet/airframe/rx/RxRunner.scala index 75431a77a4..62e1b41b00 100644 --- a/airframe-rx/src/main/scala/wvlet/airframe/rx/RxRunner.scala +++ b/airframe-rx/src/main/scala/wvlet/airframe/rx/RxRunner.scala @@ -576,7 +576,7 @@ class RxRunner( c1, Cancelable { () => toContinue = false - source.addEvent(OnError(new InterruptedException("cancelled"))) + source.add(OnError(new InterruptedException("cancelled"))) } ) } diff --git a/airframe-rx/src/main/scala/wvlet/airframe/rx/RxSource.scala b/airframe-rx/src/main/scala/wvlet/airframe/rx/RxSource.scala index aeacace6f7..c6cef0ddc4 100644 --- a/airframe-rx/src/main/scala/wvlet/airframe/rx/RxSource.scala +++ b/airframe-rx/src/main/scala/wvlet/airframe/rx/RxSource.scala @@ -17,8 +17,8 @@ package wvlet.airframe.rx * Rx implementation where the data is provided from an external process. */ trait RxSource[A] extends Rx[A] { - def add(e: A): Unit = addEvent(OnNext(e)) - def addEvent(ev: RxEvent): Unit + def put(e: A): Unit = add(OnNext(e)) + def add(ev: RxEvent): Unit def next: Rx[RxEvent] - def stop(): Unit = addEvent(OnCompletion) + def stop(): Unit = add(OnCompletion) }