Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

http (feature): Support SSE endpoint with Rx[ServerSentEvent] return type #3818

Merged
merged 6 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ import wvlet.airframe.http.{
HttpStatus,
RPCException,
RPCStatus,
ServerAddress
ServerAddress,
ServerSentEvent
}
import wvlet.airframe.rx.{OnCompletion, OnError, OnNext, Rx, RxRunner}
import wvlet.log.LogSupport
Expand Down Expand Up @@ -103,8 +104,22 @@ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi

RxRunner.run(rxResponse) {
case OnNext(v) =>
val nettyResponse = toNettyResponse(v.asInstanceOf[Response])
val resp = v.asInstanceOf[Response]
val nettyResponse = toNettyResponse(resp)
writeResponse(msg, ctx, nettyResponse)

if (resp.isContentTypeEventStream && resp.message.isEmpty) {
// Read SSE stream
val c = RxRunner.runContinuously(resp.events) {
case OnNext(e: ServerSentEvent) =>
val event = e.toContent
val buf = Unpooled.copiedBuffer(event.getBytes("UTF-8"))
ctx.writeAndFlush(new DefaultHttpContent(buf))
case _ =>
val f = ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT)
f.addListener(ChannelFutureListener.CLOSE)
}
}
case OnError(ex) =>
// This path manages unhandled exceptions
val resp = RPCStatus.INTERNAL_ERROR_I0.newException(ex.getMessage, ex).toResponse
Expand All @@ -122,7 +137,14 @@ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi
}

private def writeResponse(req: HttpRequest, ctx: ChannelHandlerContext, resp: DefaultHttpResponse): Unit = {
val keepAlive = HttpStatus.ofCode(resp.status().code()).isSuccessful && HttpUtil.isKeepAlive(req)
val isEventStream =
Option(resp.headers())
.flatMap(h => Option(h.get(HttpHeader.ContentType)))
.exists(_.contains("text/event-stream"))

val keepAlive: Boolean =
HttpStatus.ofCode(resp.status().code()).isSuccessful && (HttpUtil.isKeepAlive(req) || isEventStream)

if (keepAlive) {
if (!req.protocolVersion().isKeepAliveDefault) {
resp.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE)
Expand All @@ -139,8 +161,15 @@ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi
}

object NettyRequestHandler {
def toNettyResponse(response: Response): DefaultFullHttpResponse = {
val r = if (response.message.isEmpty) {
def toNettyResponse(response: Response): DefaultHttpResponse = {
val r = if (response.isContentTypeEventStream && response.message.isEmpty) {
val res = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.valueOf(response.statusCode))
res.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/event-stream")
res.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED)
res.headers().set(HttpHeaderNames.CACHE_CONTROL, HttpHeaderValues.NO_CACHE)
res.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE)
res
} else if (response.message.isEmpty) {
val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.valueOf(response.statusCode))
// Need to set the content length properly to return the response in Netty
HttpUtil.setContentLength(res, 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ package wvlet.airframe.http.netty

import wvlet.airframe.codec.{JSONCodec, MessageCodec, MessageCodecFactory}
import wvlet.airframe.http.HttpMessage.{Request, Response}
import wvlet.airframe.http.{Http, HttpStatus}
import wvlet.airframe.http.{Http, HttpStatus, ServerSentEvent}
import wvlet.airframe.http.router.{ResponseHandler, Route}
import wvlet.airframe.msgpack.spi.MsgPack
import wvlet.airframe.surface.{Primitive, Surface}
import wvlet.airframe.rx.Rx
import wvlet.log.LogSupport

class NettyResponseHandler extends ResponseHandler[Request, Response] with LogSupport {
Expand All @@ -36,6 +37,10 @@ class NettyResponseHandler extends ResponseHandler[Request, Response] with LogSu
case s: String if !request.acceptsMsgPack =>
newResponse(route, request, responseSurface)
.withContent(s)
case r: Rx[_] if responseSurface.typeArgs(0).rawType == classOf[ServerSentEvent] =>
val resp = newResponse(route, request, responseSurface).withContentType("text/event-stream")
resp.events = r.asInstanceOf[Rx[ServerSentEvent]]
resp
case _ =>
val rs = codecFactory.of(responseSurface)
val msgpack: Array[Byte] = rs match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
*/
package wvlet.airframe.http.netty

import wvlet.airframe.http.{Endpoint, Http, RxRouter, ServerSentEvent, ServerSentEventHandler}
import wvlet.airframe.http.{Endpoint, Http, HttpMethod, RxRouter, ServerSentEvent, ServerSentEventHandler}
import wvlet.airframe.http.HttpMessage.Response
import wvlet.airframe.http.client.AsyncClient
import wvlet.airframe.rx.Rx
import wvlet.airframe.rx.{Rx, RxBlockingQueue}
import wvlet.airspec.AirSpec

class SSEApi {
Expand Down Expand Up @@ -46,6 +46,28 @@ class SSEApi {
|data: need to retry
|""".stripMargin)
}

@Endpoint(method = HttpMethod.POST, path = "/v1/sse-stream")
def sseStream(): Rx[ServerSentEvent] = {
val queue = new RxBlockingQueue[ServerSentEvent]()
new Thread(new Runnable {
override def run(): Unit = {
queue.put(ServerSentEvent(data = "hello stream"))
// Thread.sleep(100)
queue.put(ServerSentEvent(data = "another stream message\nwith two lines"))
// Thread.sleep(50)
queue.put(ServerSentEvent(event = Some("custom-event"), data = "hello custom event"))
Thread.sleep(20)
queue.put(ServerSentEvent(id = Some("123"), data = "hello again"))
Thread.sleep(10)
queue.put(ServerSentEvent(id = Some("1234"), event = Some("custom-event"), data = "hello again 2"))
Thread.sleep(30)
queue.put(ServerSentEvent(retry = Some(1000), data = "need to retry"))
queue.stop()
}
}).start()
queue
}
}

class SSETest extends AirSpec {
Expand All @@ -57,43 +79,78 @@ class SSETest extends AirSpec {
)
}

test("read sse events") { (client: AsyncClient) =>
val buf = List.newBuilder[ServerSentEvent]
val completed = Rx.variable(false)
test("read sse-events") { (client: AsyncClient) =>
val queue = new RxBlockingQueue[ServerSentEvent]()
val rx = client.send(
Http
.GET("/v1/sse")
.withEventHandler(new ServerSentEventHandler {
override def onError(e: Throwable): Unit = {
completed := true
queue.stop()
}
override def onCompletion(): Unit = {
completed := true
queue.stop()
}
override def onEvent(e: ServerSentEvent): Unit = {
buf += e
queue.put(e)
}
})
)
rx.join(completed)
.filter(_._2 == true)
.map(_._1)
.map { resp =>
resp.statusCode shouldBe 200

val events = buf.result()
val expected = List(
ServerSentEvent(data = "hello stream"),
ServerSentEvent(data = "another stream message\nwith two lines"),
ServerSentEvent(event = Some("custom-event"), data = "hello custom event"),
ServerSentEvent(id = Some("123"), data = "hello again"),
ServerSentEvent(id = Some("1234"), event = Some("custom-event"), data = "hello again 2"),
ServerSentEvent(retry = Some(1000), data = "need to retry")
)
rx.map { resp =>
resp.statusCode shouldBe 200

trace(events.mkString("\n"))
trace(expected.mkString("\n"))
events shouldBe expected
}
val events = queue.toSeq.toList
val expected = List(
ServerSentEvent(data = "hello stream"),
ServerSentEvent(data = "another stream message\nwith two lines"),
ServerSentEvent(event = Some("custom-event"), data = "hello custom event"),
ServerSentEvent(id = Some("123"), data = "hello again"),
ServerSentEvent(id = Some("1234"), event = Some("custom-event"), data = "hello again 2"),
ServerSentEvent(retry = Some(1000), data = "need to retry")
)

trace(events.mkString("\n"))
events shouldBe expected
}
}

test("read sse-stream") { (client: AsyncClient) =>
val queue = new RxBlockingQueue[ServerSentEvent]()
val rx = client.send(
Http
.POST("/v1/sse-stream")
.withEventHandler(new ServerSentEventHandler {
override def onError(e: Throwable): Unit = {
queue.stop()
}
override def onCompletion(): Unit = {
queue.stop()
}
override def onEvent(e: ServerSentEvent): Unit = {
debug(e)
queue.put(e)
}
})
)

rx.map { resp =>
resp.statusCode shouldBe 200

val events = queue.toSeq.toList
val expected = List(
ServerSentEvent(data = "hello stream"),
ServerSentEvent(data = "another stream message\nwith two lines"),
ServerSentEvent(event = Some("custom-event"), data = "hello custom event"),
ServerSentEvent(id = Some("123"), data = "hello again"),
ServerSentEvent(id = Some("1234"), event = Some("custom-event"), data = "hello again 2"),
ServerSentEvent(retry = Some(1000), data = "need to retry")
)

trace(events.mkString("\n"))
// trace(expected.mkString("\n"))
events shouldBe expected
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ class JavaHttpClientChannel(val destination: ServerAddress, private[http] val co
executor.execute(new Runnable {
override def run(): Unit = {
try {
withResource(new BufferedReader(new InputStreamReader(httpResponse.body()))) { reader =>
val body = httpResponse.body()
withResource(new BufferedReader(new InputStreamReader(body))) { reader =>
var id: Option[String] = None
var eventType: Option[String] = None
var retry: Option[Long] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ package wvlet.airframe.http.router
import java.lang.reflect.InvocationTargetException
import wvlet.airframe.codec.{MessageCodec, MessageCodecFactory}
import wvlet.airframe.control.ThreadUtil
import wvlet.airframe.http.{HttpBackend, HttpContext, HttpRequestAdapter}
import wvlet.airframe.http.{Http, HttpBackend, HttpContext, HttpMethod, HttpRequestAdapter, HttpStatus, ServerSentEvent}
import wvlet.log.LogSupport

import java.util.concurrent.Executors
Expand Down Expand Up @@ -61,6 +61,9 @@ class HttpEndpointExecutionContext[Req: HttpRequestAdapter, Resp, F[_]](
case valueCls if backend.isRawResponseType(valueCls) =>
// Use Backend Future (e.g., Finagle Future or Rx)
result.asInstanceOf[F[Resp]]
case valueCls if valueCls == classOf[ServerSentEvent] =>
// Rx[ServerSentEvent]
backend.toFuture(responseHandler.toHttpResponse(route, request, route.returnTypeSurface, result))
case other =>
// If X is other type, convert X into an HttpResponse
backend.mapF(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ object HttpMessage {
case class Response(
status: HttpStatus = HttpStatus.Ok_200,
header: HttpMultiMap = HttpMultiMap.empty,
message: Message = EmptyMessage
message: Message = EmptyMessage,
private[http] var events: Rx[ServerSentEvent] = Rx.empty
) extends HttpMessage[Response] {
override def toString: String = s"Response(${status},${header})"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,16 @@ case class ServerSentEvent(
retry: Option[Long] = None,
// event data string. If multiple data entries are reported, concatenated with newline
data: String
)
) {
def toContent: String = {
val b = Seq.newBuilder[String]
id.foreach(x => b += s"id: $x")
event.foreach(x => b += s"event: $x")
retry.foreach(x => b += s"retry: $x")
data.split("\n").foreach(x => b += s"data: $x")
s"${b.result().mkString("\n")}\n\n"
}
}

object ServerSentEventHandler {
def empty: ServerSentEventHandler = new ServerSentEventHandler {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package wvlet.airframe.rx
* Rx implementation where the data is provided from an external process.
*/
trait RxSource[A] extends Rx[A] {
def put(e: A): Unit = add(OnNext(e))
def add(ev: RxEvent): Unit
def next: Rx[RxEvent]
def stop(): Unit = add(OnCompletion)
Expand Down
Loading