[WIP] netty-cats: support multipart requests #3933

Draft · wants to merge 10 commits into base: master
sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala
@@ -1,18 +1,26 @@
package sttp.tapir.server.netty.cats.internal

import cats.effect.Async
import cats.effect.kernel.{Resource, Sync}
import cats.syntax.all._
import fs2.Chunk
import fs2.interop.reactivestreams.StreamSubscriber
import fs2.io.file.{Files, Path}
import io.netty.handler.codec.http.HttpContent
import io.netty.handler.codec.http.multipart.{DefaultHttpDataFactory, HttpData, HttpPostRequestDecoder}
import org.playframework.netty.http.StreamedHttpRequest
import org.reactivestreams.Publisher
import sttp.capabilities.StreamMaxLengthExceededException
import sttp.capabilities.fs2.Fs2Streams
import sttp.model.Part
import sttp.monad.MonadError
import sttp.tapir.TapirFile
import sttp.tapir.integ.cats.effect.CatsMonadError
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.interpreter.RawValue
import sttp.tapir.server.netty.internal.{NettyStreamingRequestBody, StreamCompatible}
import sttp.capabilities.WebSockets
import sttp.tapir.{RawBodyType, RawPart, TapirFile}

import java.io.File

private[cats] class NettyCatsRequestBody[F[_]: Async](
val createFile: ServerRequest => F[TapirFile],
@@ -21,6 +29,61 @@ private[cats] class NettyCatsRequestBody[F[_]: Async](

override implicit val monad: MonadError[F] = new CatsMonadError()

def publisherToMultipart(
nettyRequest: StreamedHttpRequest,
serverRequest: ServerRequest,
m: RawBodyType.MultipartBody,
maxBytes: Option[Long]
): F[RawValue[Seq[RawPart]]] = {
fs2.Stream
.resource(
Resource.make(Sync[F].delay(new HttpPostRequestDecoder(NettyCatsRequestBody.multiPartDataFactory, nettyRequest)))(d =>
Sync[F].blocking(d.destroy()) // after the stream finishes or fails, decoder data has to be cleaned up
)
)
.flatMap { decoder =>
fs2.Stream
.eval(StreamSubscriber[F, HttpContent](bufferSize = 1))
.flatMap(s => s.sub.stream(Sync[F].delay(nettyRequest.subscribe(s))))
.evalMapAccumulate({
(decoder, 0L)
})({ case ((decoder, processedBytesNum), httpContent) =>
monad
.blocking {
// this operation is the one that does potential I/O (writing files)
// TODO not thread-safe? (visibility of internal state changes?)
Review comment (Member, PR author):
question: This is the part that concerns me. The decoder is stateful: its internal state changes with every call to offer. Each time a new HttpContent flows through the stream, the monad.blocking call grabs a thread from the blocking pool and calls offer there. There are no race conditions, but don't the decoder's internals risk visibility issues under these circumstances?

Reply (Member):
This is fine as long as there's some memory barrier along the way (see e.g. https://stackoverflow.com/questions/12438464/volatile-variables-and-other-variables). There should be a couple of those here; I suspect that returning a thread to, or obtaining one from, the blocking pool itself imposes at least one such barrier.
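// (Illustrative aside, not part of this diff.) The memory-barrier argument above is the standard
// JMM "piggybacking" pattern: plain writes made before a volatile write (or a lock release / thread
// hand-off) become visible to a thread that performs the matching volatile read (or lock acquire)
// before its own plain reads. A minimal sketch, assuming nothing beyond the Java memory model:
//
//   class Decoderish { var state = 0 }                  // plain, non-volatile field, like the decoder's internals
//   val d = new Decoderish
//   @volatile var published = false
//   // thread A: d.state = 42; published = true         // plain write, then volatile write
//   // thread B: if (published) assert(d.state == 42)   // volatile read first, so A's write is visible
//
// The assumption in this thread is that handing work to the blocking pool and back creates an
// equivalent happens-before edge, so each offer/hasNext call sees the previous call's state changes.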

decoder.offer(httpContent)
var processedBytesAndContentBytes = processedBytesNum

val parts = Stream
.continually(if (decoder.hasNext) {
val next = decoder.next()
processedBytesAndContentBytes = processedBytesAndContentBytes + next.asInstanceOf[HttpData].length()
maxBytes.foreach { max =>
if (max < processedBytesAndContentBytes) {
throw new StreamMaxLengthExceededException(max)
}
}
next
} else null)
.takeWhile(_ != null)
.toVector
(
(decoder, processedBytesAndContentBytes),
parts
)
}
})
.map(_._2)
.map(_.flatMap(p => m.partType(p.getName()).map((p, _)).toList))
.evalMap(_.traverse { case (data, partType) => toRawPart(serverRequest, data, partType).map(_.asInstanceOf[Part[Any]]) })
}
.compile
.toVector
.map(_.flatten)
.map(RawValue.fromParts(_))
}

override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): F[Array[Byte]] =
streamCompatible.fromPublisher(publisher, maxBytes).compile.to(Chunk).map(_.toArray[Byte])

@@ -32,4 +95,13 @@ private[cats] class NettyCatsRequestBody[F[_]: Async](
)
.compile
.drain

override def writeBytesToFile(bytes: Array[Byte], file: File): F[Unit] =
fs2.Stream.emits(bytes).through(Files.forAsync[F].writeAll(Path.fromNioPath(file.toPath))).compile.drain

}

private[cats] object NettyCatsRequestBody {
val multiPartDataFactory =
new DefaultHttpDataFactory() // keeps part data in memory, switching to disk once it exceeds MINSIZE (16 kB); see the other constructors for different policies.
}
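For context on the factory comment above: DefaultHttpDataFactory ships several constructors that control where part data is buffered. A minimal sketch of the alternatives (illustrative only; the PR uses the no-arg constructor, and only DefaultHttpDataFactory.MINSIZE, 16 kB, comes from Netty itself):

import io.netty.handler.codec.http.multipart.DefaultHttpDataFactory

// Keep every part in memory, regardless of size.
val inMemoryOnly = new DefaultHttpDataFactory(false)

// Always write part contents to temporary files on disk.
val alwaysOnDisk = new DefaultHttpDataFactory(true)

// Buffer in memory up to a threshold, then spill to disk; the no-arg constructor
// used above behaves like this with DefaultHttpDataFactory.MINSIZE (16 kB).
val mixed = new DefaultHttpDataFactory(DefaultHttpDataFactory.MINSIZE)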
NettyCatsServerTest.scala
@@ -12,8 +12,9 @@ import sttp.tapir.tests.{Test, TestSuite}

import scala.concurrent.Future
import scala.concurrent.duration.FiniteDuration
import org.scalatest.matchers.should.Matchers

class NettyCatsServerTest extends TestSuite with EitherValues {
class NettyCatsServerTest extends TestSuite with EitherValues with Matchers {

override def tests: Resource[IO, List[Test]] =
backendResource.flatMap { backend =>
@@ -41,6 +42,12 @@ class NettyCatsServerTest extends TestSuite with EitherValues {
new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++
new NettyFs2StreamingCancellationTest(createServerTest).tests() ++
new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests() ++
new ServerMultipartTests(
createServerTest,
partContentTypeHeaderSupport = false,
partOtherHeaderSupport = false,
multipartResponsesSupport = false
).tests() ++
new ServerWebSocketTests(
createServerTest,
Fs2Streams[IO],
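The flags above switch off the test groups the backend does not yet support (per-part headers and multipart responses). For orientation, a minimal multipart endpoint of the kind these tests exercise might look as follows (a hypothetical sketch, not part of this PR; the path and output are made up):

import sttp.tapir._

// Hypothetical endpoint: accepts a raw multipart body and returns a plain-text response.
// Wiring it into the netty-cats interpreter would follow the usual NettyCatsServer flow.
val upload = endpoint.post
  .in("upload")
  .in(multipartBody) // decoded into a sequence of raw parts by the new request-body code
  .out(stringBody)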
NettyFutureRequestBody.scala
@@ -5,16 +5,27 @@ import org.playframework.netty.http.StreamedHttpRequest
import org.reactivestreams.Publisher
import sttp.capabilities
import sttp.monad.{FutureMonad, MonadError}
import sttp.tapir.TapirFile
import sttp.tapir.capabilities.NoStreams
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.interpreter.RawValue
import sttp.tapir.server.netty.internal.reactivestreams._
import sttp.tapir.{RawBodyType, RawPart, TapirFile}

import java.io.File
import scala.concurrent.{ExecutionContext, Future}

private[netty] class NettyFutureRequestBody(val createFile: ServerRequest => Future[TapirFile])(implicit ec: ExecutionContext)
extends NettyRequestBody[Future, NoStreams] {

override def publisherToMultipart(
nettyRequest: StreamedHttpRequest,
serverRequest: ServerRequest,
m: RawBodyType.MultipartBody,
maxBytes: Option[Long]
): Future[RawValue[Seq[RawPart]]] = Future.failed(new UnsupportedOperationException("Multipart requests not supported."))

override def writeBytesToFile(bytes: Array[Byte], file: File): Future[Unit] = Future.failed(new UnsupportedOperationException)

override val streams: capabilities.Streams[NoStreams] = NoStreams
override implicit val monad: MonadError[Future] = new FutureMonad()

sttp/tapir/server/netty/internal/NettyRequestBody.scala
@@ -16,6 +16,15 @@ import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile}
import java.io.InputStream
import java.nio.ByteBuffer

import scala.collection.JavaConverters._
import sttp.tapir.RawPart
import io.netty.handler.codec.http.multipart.InterfaceHttpData
import sttp.model.Part
import io.netty.handler.codec.http.multipart.HttpData
import io.netty.handler.codec.http.multipart.FileUpload
import java.io.ByteArrayInputStream
import java.io.File

/** Common logic for processing request body in all Netty backends. It requires particular backends to implement a few operations. */
private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody[F, S] {

@@ -37,6 +46,16 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
*/
def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): F[Array[Byte]]

/** Reads the reactive stream emitting HttpData into a vector of parts. Implementation-specific, as file manipulations and stream
* processing logic can be different for different backends.
*/
def publisherToMultipart(
nettyRequest: StreamedHttpRequest,
serverRequest: ServerRequest,
m: RawBodyType.MultipartBody,
maxBytes: Option[Long]
): F[RawValue[Seq[RawPart]]]

/** Backend-specific way to process all elements emitted by a Publisher[HttpContent] and write their bytes into a file.
*
* @param serverRequest
@@ -50,6 +69,8 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
*/
def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit]

def writeBytesToFile(bytes: Array[Byte], file: File): F[Unit]

override def toRaw[RAW](
serverRequest: ServerRequest,
bodyType: RawBodyType[RAW],
@@ -70,8 +91,8 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
file <- createFile(serverRequest)
_ <- writeToFile(serverRequest, file, maxBytes)
} yield RawValue(FileRange(file), Seq(FileRange(file)))
case _: RawBodyType.MultipartBody =>
monad.error(new UnsupportedOperationException)
case m: RawBodyType.MultipartBody =>
publisherToMultipart(serverRequest.underlying.asInstanceOf[StreamedHttpRequest], serverRequest, m, maxBytes)
}

private def readAllBytes(serverRequest: ServerRequest, maxBytes: Option[Long]): F[Array[Byte]] =
@@ -96,4 +117,72 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
throw new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass.getName}")
}
}

protected def toRawPart[R](
serverRequest: ServerRequest,
data: InterfaceHttpData,
partType: RawBodyType[R]
): F[Part[R]] = {
val partName = data.getName()
data match {
case httpData: HttpData =>
// TODO the filename* attribute is not used by Netty, so non-ASCII filenames (cf. https://github.com/http4s/http4s/issues/5809) are unsupported.
toRawPartHttpData(partName, serverRequest, httpData, partType)
case unsupportedDataType =>
monad.error(new UnsupportedOperationException(s"Unsupported multipart data type: $unsupportedDataType in part $partName"))
}
}

private def toRawPartHttpData[R](
partName: String,
serverRequest: ServerRequest,
httpData: HttpData,
partType: RawBodyType[R]
): F[Part[R]] = {
val fileName = httpData match {
case fileUpload: FileUpload => Option(fileUpload.getFilename())
case _ => None
}
partType match {
case RawBodyType.StringBody(defaultCharset) =>
// TODO otherDispositionParams are not supported. They are normally part of the part's content-disposition header, but that header is not directly accessible here; its parameters are extracted internally by the decoder.
val charset = if (httpData.getCharset() != null) httpData.getCharset() else defaultCharset
readHttpData(httpData, _.getString(charset)).map(body => Part(partName, body, fileName = fileName))
case RawBodyType.ByteArrayBody =>
readHttpData(httpData, _.get()).map(body => Part(partName, body, fileName = fileName))
case RawBodyType.ByteBufferBody =>
readHttpData(httpData, _.get()).map(body => Part(partName, ByteBuffer.wrap(body), fileName = fileName))
case RawBodyType.InputStreamBody =>
(if (httpData.isInMemory())
monad.unit(new ByteArrayInputStream(httpData.get()))
else {
monad.blocking(java.nio.file.Files.newInputStream(httpData.getFile().toPath()))
}).map(body => Part(partName, body, fileName = fileName))
case RawBodyType.InputStreamRangeBody =>
val body = () => {
if (httpData.isInMemory())
new ByteArrayInputStream(httpData.get())
else
java.nio.file.Files.newInputStream(httpData.getFile().toPath())
}
monad.unit(Part(partName, InputStreamRange(body), fileName = fileName))
case RawBodyType.FileBody =>
val fileF: F[File] =
if (httpData.isInMemory())
(for {
file <- createFile(serverRequest)
_ <- writeBytesToFile(httpData.get(), file)
} yield file)
else monad.unit(httpData.getFile())
Review comment (Member, PR author):
note: the Netty decoder creates the file itself (when the part exceeds 16 kB), so we can just take its handle here.

fileF.map(file => Part(partName, FileRange(file), fileName = fileName))
case _: RawBodyType.MultipartBody =>
monad.error(new UnsupportedOperationException(s"Nested multipart not supported, part name = $partName"))
}
}

private def readHttpData[T](httpData: HttpData, f: HttpData => T): F[T] =
if (httpData.isInMemory())
monad.unit(f(httpData))
else
monad.blocking(f(httpData))
}
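To make the part-handling paths above concrete, here is a hypothetical client call (not part of this PR; it uses sttp client, and the URL, part names, and file are made up) that sends one small text part and one file part; the file part is the kind the decoder may spill to disk once it crosses the 16 kB threshold:

import sttp.client3._

// Hypothetical request: "meta" arrives as an in-memory attribute, while "file" may be
// backed by a temporary file created by the decoder if its content exceeds 16 kB.
val request = basicRequest
  .post(uri"http://localhost:8080/upload")
  .multipartBody(
    multipart("meta", """{"description": "quarterly report"}"""),
    multipartFile("file", new java.io.File("report.pdf")).fileName("report.pdf")
  )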
NettySyncRequestBody.scala
@@ -11,6 +11,10 @@ import sttp.tapir.model.ServerRequest
import sttp.tapir.server.netty.internal.NettyRequestBody
import sttp.tapir.server.netty.internal.reactivestreams.{FileWriterSubscriber, SimpleSubscriber}
import sttp.tapir.server.netty.sync.*
import sttp.tapir.RawBodyType
import sttp.tapir.server.interpreter.RawValue
import sttp.tapir.RawPart
import java.io.File

private[sync] class NettySyncRequestBody(val createFile: ServerRequest => TapirFile) extends NettyRequestBody[Identity, OxStreams]:

@@ -20,6 +24,14 @@ private[sync] class NettySyncRequestBody(val createFile: ServerRequest => TapirF
override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): Array[Byte] =
SimpleSubscriber.processAllBlocking(publisher, contentLength, maxBytes)

override def publisherToMultipart(
nettyRequest: StreamedHttpRequest,
serverRequest: ServerRequest,
m: RawBodyType.MultipartBody,
maxBytes: Option[Long]
): RawValue[Seq[RawPart]] = throw new UnsupportedOperationException("Multipart requests not supported.")
override def writeBytesToFile(bytes: Array[Byte], file: File) = throw new UnsupportedOperationException()

override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Unit =
serverRequest.underlying match
case r: StreamedHttpRequest => FileWriterSubscriber.processAllBlocking(r, file.toPath, maxBytes)
sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala
@@ -1,15 +1,19 @@
package sttp.tapir.server.netty.zio.internal

import io.netty.handler.codec.http.HttpContent
import org.playframework.netty.http.StreamedHttpRequest
import org.reactivestreams.Publisher
import sttp.capabilities.zio.ZioStreams
import sttp.monad.MonadError
import sttp.tapir.TapirFile
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.interpreter.RawValue
import sttp.tapir.server.netty.internal.{NettyStreamingRequestBody, StreamCompatible}
import sttp.tapir.ztapir.RIOMonadError
import zio.RIO
import sttp.tapir.{RawBodyType, RawPart, TapirFile}
import zio.stream._
import zio.{RIO, ZIO}

import java.io.File

private[zio] class NettyZioRequestBody[Env](
val createFile: ServerRequest => RIO[Env, TapirFile],
@@ -19,6 +23,14 @@ private[zio] class NettyZioRequestBody[Env](
override val streams: ZioStreams = ZioStreams
override implicit val monad: MonadError[RIO[Env, *]] = new RIOMonadError[Env]

override def publisherToMultipart(
nettyRequest: StreamedHttpRequest,
serverRequest: ServerRequest,
m: RawBodyType.MultipartBody,
maxBytes: Option[Long]
): RIO[Env, RawValue[Seq[RawPart]]] = ZIO.die(new UnsupportedOperationException("Multipart requests not supported."))

override def writeBytesToFile(bytes: Array[Byte], file: File): RIO[Env, Unit] = ZIO.die(new UnsupportedOperationException)
override def publisherToBytes(
publisher: Publisher[HttpContent],
contentLength: Option[Long],