Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] netty-cats: support multipart requests #3933

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ class AkkaHttpServerTest extends TestSuite with EitherValues {
def drainAkka(stream: AkkaStreams.BinaryStream): Future[Unit] =
stream.runWith(Sink.ignore).map(_ => ())

new AllServerTests(createServerTest, interpreter, backend).tests() ++
new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() ++
new ServerMultipartTests(createServerTest, chunkingSupport = false)
.tests() ++ // chunking disabled, akka-http rejects content-length with transfer-encoding
new ServerStreamingTests(createServerTest).tests(AkkaStreams)(drainAkka) ++
new ServerWebSocketTests(
createServerTest,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ class ArmeriaCatsServerTest extends TestSuite {
def drainFs2(stream: Fs2Streams[IO]#BinaryStream): IO[Unit] =
stream.compile.drain.void

new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false).tests() ++
new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false, multipart = false)
.tests() ++
new ServerMultipartTests(createServerTest, chunkingSupport = false)
.tests() ++ // chunking disabled, Armeria rejects content-length with transfer-encoding
new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false, maxContentLength = false).tests() ++
new ServerStreamingTests(createServerTest).tests(Fs2Streams[IO])(drainFs2)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ class ArmeriaFutureServerTest extends TestSuite {
val interpreter = new ArmeriaTestFutureServerInterpreter()
val createServerTest = new DefaultCreateServerTest(backend, interpreter)

new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false).tests() ++
new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false, multipart = false)
.tests() ++
new ServerMultipartTests(createServerTest, chunkingSupport = false)
.tests() ++ // chunking disabled, Armeria rejects content-length with transfer-encoding
new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false, maxContentLength = false).tests() ++
new ServerStreamingTests(createServerTest, maxLengthSupported = false).tests(ArmeriaStreams)(_ => Future.unit)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ class ArmeriaZioServerTest extends TestSuite {
def drainZStream(zStream: ZioStreams.BinaryStream): Task[Unit] =
zStream.run(ZSink.drain)

new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false).tests() ++
new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false, multipart = false)
.tests() ++
new ServerMultipartTests(createServerTest, chunkingSupport = false)
.tests() ++ // chunking disabled, Armeria rejects content-length with transfer-encoding
new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false, maxContentLength = false).tests() ++
new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ class JdkHttpServerTest extends TestSuite with EitherValues {
val createServerTest = new DefaultCreateServerTest(backend, interpreter)

new ServerBasicTests(createServerTest, interpreter, invulnerableToUnsanitizedHeaders = false).tests() ++
new AllServerTests(createServerTest, interpreter, backend, basic = false).tests()
new ServerMultipartTests(createServerTest, chunkingSupport = false)
.tests() ++ // chunking disabled, backend rejects content-length with transfer-encoding
new AllServerTests(createServerTest, interpreter, backend, basic = false, multipart = false).tests()
})
}
}
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
package sttp.tapir.server.netty.cats.internal

import cats.effect.Async
import cats.effect.kernel.{Resource, Sync}
import cats.syntax.all._
import fs2.Chunk
import fs2.interop.reactivestreams.StreamSubscriber
import fs2.io.file.{Files, Path}
import io.netty.handler.codec.http.HttpContent
import io.netty.handler.codec.http.multipart.{DefaultHttpDataFactory, HttpData, HttpPostRequestDecoder}
import org.playframework.netty.http.StreamedHttpRequest
import org.reactivestreams.Publisher
import sttp.capabilities.StreamMaxLengthExceededException
import sttp.capabilities.fs2.Fs2Streams
import sttp.model.Part
import sttp.monad.MonadError
import sttp.tapir.TapirFile
import sttp.tapir.integ.cats.effect.CatsMonadError
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.interpreter.RawValue
import sttp.tapir.server.netty.internal.{NettyStreamingRequestBody, StreamCompatible}
import sttp.capabilities.WebSockets
import sttp.tapir.{RawBodyType, RawPart, TapirFile}

import java.io.File


private[cats] class NettyCatsRequestBody[F[_]: Async](
val createFile: ServerRequest => F[TapirFile],
Expand All @@ -21,6 +30,63 @@ private[cats] class NettyCatsRequestBody[F[_]: Async](

override implicit val monad: MonadError[F] = new CatsMonadError()

def publisherToMultipart(
nettyRequest: StreamedHttpRequest,
serverRequest: ServerRequest,
m: RawBodyType.MultipartBody,
maxBytes: Option[Long]
): F[RawValue[Seq[RawPart]]] = {
fs2.Stream
.resource(
Resource.make(Sync[F].delay(new HttpPostRequestDecoder(NettyCatsRequestBody.multiPartDataFactory, nettyRequest)))(d =>
Sync[F].blocking(d.destroy()) // after the stream finishes or fails, decoder data has to be cleaned up
)
)
.flatMap { decoder =>
fs2.Stream
.eval(StreamSubscriber[F, HttpContent](bufferSize = 1))
.flatMap(s => s.sub.stream(Sync[F].delay(nettyRequest.subscribe(s))))
.evalMapAccumulate({
(decoder, 0L)
})({ case ((decoder, processedBytesNum), httpContent) =>
monad
.blocking {
val newProcessedBytes = if (httpContent.content() != null) {
val processedBytesAndContentBytes = processedBytesNum + httpContent.content().readableBytes()
maxBytes.foreach { max =>
if (max < processedBytesAndContentBytes) {
throw new StreamMaxLengthExceededException(max)
}
}
processedBytesAndContentBytes
} else processedBytesNum

// this operation is the one that does potential I/O (writing files)
decoder.offer(httpContent)
val parts = Stream
.continually(if (decoder.hasNext) {
val next = decoder.next()
next
} else null)
.takeWhile(_ != null)
.toVector

(
(decoder, newProcessedBytes),
parts
)
}
})
.map(_._2)
.map(_.flatMap(p => m.partType(p.getName()).map((p, _)).toList))
.evalMap(_.traverse { case (data, partType) => toRawPart(serverRequest, data, partType).map(_.asInstanceOf[Part[Any]]) })
}
.compile
.toVector
.map(_.flatten)
.map(RawValue.fromParts(_))
}

override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): F[Array[Byte]] =
streamCompatible.fromPublisher(publisher, maxBytes).compile.to(Chunk).map(_.toArray[Byte])

Expand All @@ -32,4 +98,13 @@ private[cats] class NettyCatsRequestBody[F[_]: Async](
)
.compile
.drain

override def writeBytesToFile(bytes: Array[Byte], file: File): F[Unit] =
fs2.Stream.emits(bytes).through(Files.forAsync[F].writeAll(Path.fromNioPath(file.toPath))).compile.drain

}

private[cats] object NettyCatsRequestBody {
val multiPartDataFactory =
new DefaultHttpDataFactory() // writes to memory, then switches to disk if exceeds MINSIZE (16kB), check other constructors.
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ import sttp.tapir.tests.{Test, TestSuite}

import scala.concurrent.Future
import scala.concurrent.duration.FiniteDuration
import org.scalatest.matchers.should.Matchers

class NettyCatsServerTest extends TestSuite with EitherValues {
class NettyCatsServerTest extends TestSuite with EitherValues with Matchers {

override def tests: Resource[IO, List[Test]] =
backendResource.flatMap { backend =>
Expand Down Expand Up @@ -41,6 +42,12 @@ class NettyCatsServerTest extends TestSuite with EitherValues {
new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++
new NettyFs2StreamingCancellationTest(createServerTest).tests() ++
new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests() ++
new ServerMultipartTests(
createServerTest,
partContentTypeHeaderSupport = false,
partOtherHeaderSupport = false,
multipartResponsesSupport = false
).tests() ++
new ServerWebSocketTests(
createServerTest,
Fs2Streams[IO],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,27 @@ import org.playframework.netty.http.StreamedHttpRequest
import org.reactivestreams.Publisher
import sttp.capabilities
import sttp.monad.{FutureMonad, MonadError}
import sttp.tapir.TapirFile
import sttp.tapir.capabilities.NoStreams
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.interpreter.RawValue
import sttp.tapir.server.netty.internal.reactivestreams._
import sttp.tapir.{RawBodyType, RawPart, TapirFile}

import java.io.File
import scala.concurrent.{ExecutionContext, Future}

private[netty] class NettyFutureRequestBody(val createFile: ServerRequest => Future[TapirFile])(implicit ec: ExecutionContext)
extends NettyRequestBody[Future, NoStreams] {

override def publisherToMultipart(
nettyRequest: StreamedHttpRequest,
serverRequest: ServerRequest,
m: RawBodyType.MultipartBody,
maxBytes: Option[Long]
): Future[RawValue[Seq[RawPart]]] = Future.failed(new UnsupportedOperationException("Multipart requests not supported."))

override def writeBytesToFile(bytes: Array[Byte], file: File): Future[Unit] = Future.failed(new UnsupportedOperationException)

override val streams: capabilities.Streams[NoStreams] = NoStreams
override implicit val monad: MonadError[Future] = new FutureMonad()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile}
import java.io.InputStream
import java.nio.ByteBuffer

import scala.collection.JavaConverters._
import sttp.tapir.RawPart
import io.netty.handler.codec.http.multipart.InterfaceHttpData
import sttp.model.Part
import io.netty.handler.codec.http.multipart.HttpData
import io.netty.handler.codec.http.multipart.FileUpload
import java.io.ByteArrayInputStream
import java.io.File

/** Common logic for processing request body in all Netty backends. It requires particular backends to implement a few operations. */
private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody[F, S] {

Expand All @@ -37,6 +46,16 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
*/
def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): F[Array[Byte]]

/** Reads the reactive stream emitting HttpData into a vector of parts. Implementation-specific, as file manipulations and stream
* processing logic can be different for different backends.
*/
def publisherToMultipart(
nettyRequest: StreamedHttpRequest,
serverRequest: ServerRequest,
m: RawBodyType.MultipartBody,
maxBytes: Option[Long]
): F[RawValue[Seq[RawPart]]]

/** Backend-specific way to process all elements emitted by a Publisher[HttpContent] and write their bytes into a file.
*
* @param serverRequest
Expand All @@ -50,6 +69,8 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
*/
def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit]

def writeBytesToFile(bytes: Array[Byte], file: File): F[Unit]

override def toRaw[RAW](
serverRequest: ServerRequest,
bodyType: RawBodyType[RAW],
Expand All @@ -70,8 +91,8 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
file <- createFile(serverRequest)
_ <- writeToFile(serverRequest, file, maxBytes)
} yield RawValue(FileRange(file), Seq(FileRange(file)))
case _: RawBodyType.MultipartBody =>
monad.error(new UnsupportedOperationException)
case m: RawBodyType.MultipartBody =>
publisherToMultipart(serverRequest.underlying.asInstanceOf[StreamedHttpRequest], serverRequest, m, maxBytes)
}

private def readAllBytes(serverRequest: ServerRequest, maxBytes: Option[Long]): F[Array[Byte]] =
Expand All @@ -96,4 +117,72 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody
throw new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass.getName}")
}
}

protected def toRawPart[R](
serverRequest: ServerRequest,
data: InterfaceHttpData,
partType: RawBodyType[R]
): F[Part[R]] = {
val partName = data.getName()
data match {
case httpData: HttpData =>
// TODO filename* attribute is not used by netty. Non-ascii filenames like https://github.com/http4s/http4s/issues/5809 are unsupported.
toRawPartHttpData(partName, serverRequest, httpData, partType)
case unsupportedDataType =>
monad.error(new UnsupportedOperationException(s"Unsupported multipart data type: $unsupportedDataType in part $partName"))
}
}

private def toRawPartHttpData[R](
partName: String,
serverRequest: ServerRequest,
httpData: HttpData,
partType: RawBodyType[R]
): F[Part[R]] = {
val fileName = httpData match {
case fileUpload: FileUpload => Option(fileUpload.getFilename())
case _ => None
}
partType match {
case RawBodyType.StringBody(defaultCharset) =>
// TODO otherDispositionParams not supported. They are normally a part of the content-disposition part header, but this header is not directly accessible, they are extracted internally by the decoder.
val charset = if (httpData.getCharset() != null) httpData.getCharset() else defaultCharset
readHttpData(httpData, _.getString(charset)).map(body => Part(partName, body, fileName = fileName))
case RawBodyType.ByteArrayBody =>
readHttpData(httpData, _.get()).map(body => Part(partName, body, fileName = fileName))
case RawBodyType.ByteBufferBody =>
readHttpData(httpData, _.get()).map(body => Part(partName, ByteBuffer.wrap(body), fileName = fileName))
case RawBodyType.InputStreamBody =>
(if (httpData.isInMemory())
monad.unit(new ByteArrayInputStream(httpData.get()))
else {
monad.blocking(java.nio.file.Files.newInputStream(httpData.getFile().toPath()))
}).map(body => Part(partName, body, fileName = fileName))
case RawBodyType.InputStreamRangeBody =>
val body = () => {
if (httpData.isInMemory())
new ByteArrayInputStream(httpData.get())
else
java.nio.file.Files.newInputStream(httpData.getFile().toPath())
}
monad.unit(Part(partName, InputStreamRange(body), fileName = fileName))
case RawBodyType.FileBody =>
val fileF: F[File] =
if (httpData.isInMemory())
(for {
file <- createFile(serverRequest)
_ <- writeBytesToFile(httpData.get(), file)
} yield file)
else monad.unit(httpData.getFile())
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: Netty decoder creates the file (if it's > 16KB), so we can just get its handle here.

fileF.map(file => Part(partName, FileRange(file), fileName = fileName))
case _: RawBodyType.MultipartBody =>
monad.error(new UnsupportedOperationException(s"Nested multipart not supported, part name = $partName"))
}
}

private def readHttpData[T](httpData: HttpData, f: HttpData => T): F[T] =
if (httpData.isInMemory())
monad.unit(f(httpData))
else
monad.blocking(f(httpData))
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ import sttp.tapir.model.ServerRequest
import sttp.tapir.server.netty.internal.NettyRequestBody
import sttp.tapir.server.netty.internal.reactivestreams.{FileWriterSubscriber, SimpleSubscriber}
import sttp.tapir.server.netty.sync.*
import sttp.tapir.RawBodyType
import sttp.tapir.server.interpreter.RawValue
import sttp.tapir.RawPart
import java.io.File

private[sync] class NettySyncRequestBody(val createFile: ServerRequest => TapirFile) extends NettyRequestBody[Identity, OxStreams]:

Expand All @@ -20,6 +24,14 @@ private[sync] class NettySyncRequestBody(val createFile: ServerRequest => TapirF
override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): Array[Byte] =
SimpleSubscriber.processAllBlocking(publisher, contentLength, maxBytes)

override def publisherToMultipart(
nettyRequest: StreamedHttpRequest,
serverRequest: ServerRequest,
m: RawBodyType.MultipartBody,
maxBytes: Option[Long]
): RawValue[Seq[RawPart]] = throw new UnsupportedOperationException("Multipart requests not supported.")
override def writeBytesToFile(bytes: Array[Byte], file: File) = throw new UnsupportedOperationException()

override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Unit =
serverRequest.underlying match
case r: StreamedHttpRequest => FileWriterSubscriber.processAllBlocking(r, file.toPath, maxBytes)
Expand Down
Loading
Loading