Skip to content

Commit

Permalink
feat: Add zio-http multipart body support (#3690)
Browse files Browse the repository at this point in the history
Co-authored-by: Adam Warski <[email protected]>
  • Loading branch information
seakayone and adamw authored Sep 17, 2024
1 parent b87d150 commit 9a3b92a
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -168,19 +168,15 @@ trait ZioHttpInterpreter[R] {
}
val statusCode = resp.code.code

ZIO.succeed(
Response(
status = Status.fromInt(statusCode),
headers = ZioHttpHeaders(allHeaders),
body = body
.map {
case ZioStreamHttpResponseBody(stream, Some(contentLength)) => Body.fromStream(stream, contentLength)
case ZioStreamHttpResponseBody(stream, None) => Body.fromStreamChunked(stream)
case ZioRawHttpResponseBody(chunk, _) => Body.fromChunk(chunk)
}
.getOrElse(Body.empty)
)
)
body
.map {
case ZioStreamHttpResponseBody(stream, Some(contentLength)) => ZIO.succeed(Body.fromStream(stream, contentLength))
case ZioStreamHttpResponseBody(stream, None) => ZIO.succeed(Body.fromStreamChunked(stream))
case ZioMultipartHttpResponseBody(formFields) => Body.fromMultipartFormUUID(Form(Chunk.fromIterable(formFields)))
case ZioRawHttpResponseBody(chunk, _) => ZIO.succeed(Body.fromChunk(chunk))
}
.getOrElse(ZIO.succeed(Body.empty))
.map(zioBody => Response(status = Status.fromInt(statusCode), headers = ZioHttpHeaders(allHeaders), body = zioBody))
}

private def sttpToZioHttpHeader(hl: (String, Seq[SttpHeader])): Seq[ZioHttpHeader] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,96 @@ package sttp.tapir.server.ziohttp

import sttp.capabilities
import sttp.capabilities.zio.ZioStreams
import sttp.model.Part
import sttp.model.Part.FileNameDispositionParam
import sttp.tapir.FileRange
import sttp.tapir.InputStreamRange
import sttp.tapir.RawBodyType
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.interpreter.{RawValue, RequestBody}
import sttp.tapir.{FileRange, InputStreamRange, RawBodyType}
import zio.http.Request
import zio.stream.{Stream, ZSink, ZStream}
import sttp.tapir.server.interpreter.RawValue
import sttp.tapir.server.interpreter.RequestBody
import zio.{RIO, Task, ZIO}
import zio.http.{FormField, Request, StreamingForm}
import zio.http.FormField.StreamingBinary
import zio.stream.ZSink
import zio.stream.ZStream

import java.io.ByteArrayInputStream
import java.nio.ByteBuffer

class ZioHttpRequestBody[R](serverOptions: ZioHttpServerOptions[R]) extends RequestBody[RIO[R, *], ZioStreams] {
override val streams: capabilities.Streams[ZioStreams] = ZioStreams

override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Task[RawValue[RAW]] = {
override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Task[RawValue[RAW]] =
toRaw(serverRequest, zStream(serverRequest), bodyType, maxBytes)

def asByteArray: Task[Array[Byte]] =
(toStream(serverRequest, maxBytes).asInstanceOf[ZStream[Any, Throwable, Byte]]).runCollect.map(_.toArray)
private def toRaw[RAW](
serverRequest: ServerRequest,
stream: ZStream[Any, Throwable, Byte],
bodyType: RawBodyType[RAW],
maxBytes: Option[Long]
): Task[RawValue[RAW]] = {
val limitedStream = limitedZStream(stream, maxBytes)
val asByteArray = limitedStream.runCollect.map(_.toArray)

bodyType match {
case RawBodyType.StringBody(defaultCharset) => asByteArray.map(new String(_, defaultCharset)).map(RawValue(_))
case RawBodyType.ByteArrayBody => asByteArray.map(RawValue(_))
case RawBodyType.ByteBufferBody => asByteArray.map(bytes => ByteBuffer.wrap(bytes)).map(RawValue(_))
case RawBodyType.InputStreamBody => asByteArray.map(new ByteArrayInputStream(_)).map(RawValue(_))
case RawBodyType.InputStreamRangeBody =>
asByteArray.map(bytes => new InputStreamRange(() => new ByteArrayInputStream(bytes))).map(RawValue(_))
asByteArray.map(bytes => InputStreamRange(() => new ByteArrayInputStream(bytes))).map(RawValue(_))
case RawBodyType.FileBody =>
for {
file <- serverOptions.createFile(serverRequest)
_ <- (toStream(serverRequest, maxBytes).asInstanceOf[ZStream[Any, Throwable, Byte]]).run(ZSink.fromFile(file)).map(_ => ())
_ <- limitedStream.run(ZSink.fromFile(file)).unit
} yield RawValue(FileRange(file), Seq(FileRange(file)))
case RawBodyType.MultipartBody(_, _) => ZIO.fail(new UnsupportedOperationException("Multipart is not supported"))
case m: RawBodyType.MultipartBody => handleMultipartBody(serverRequest, m, limitedStream)
}
}

override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = {
val inputStream = stream(serverRequest)
maxBytes.map(ZioStreams.limitBytes(inputStream, _)).getOrElse(inputStream).asInstanceOf[streams.BinaryStream]
private def handleMultipartBody[RAW](
serverRequest: ServerRequest,
bodyType: RawBodyType.MultipartBody,
limitedStream: ZStream[Any, Throwable, Byte]
): Task[RawValue[RAW]] = {
zRequest(serverRequest).body.contentType.flatMap(_.boundary) match {
case Some(boundary) =>
StreamingForm(limitedStream, boundary).fields
.flatMap(field => ZStream.fromIterable(bodyType.partType(field.name).map((field, _))))
.mapZIO { case (field, bodyType) => toRawPart(serverRequest, field, bodyType) }
.runCollect
.map(RawValue.fromParts(_).asInstanceOf[RawValue[RAW]])
case None =>
ZIO.fail(
new IllegalStateException("Cannot decode body as streaming multipart/form-data without a known boundary")
)
}
}

private def toRawPart[A](serverRequest: ServerRequest, field: FormField, bodyType: RawBodyType[A]): Task[Part[A]] = {
val fieldsStream = field match {
case StreamingBinary(_, _, _, _, s) => s
case _ => ZStream.fromIterableZIO(field.asChunk)
}
toRaw(serverRequest, fieldsStream, bodyType, None)
.map(raw =>
Part(
field.name,
raw.value,
otherDispositionParams = field.filename.map(name => Map(FileNameDispositionParam -> name)).getOrElse(Map.empty)
).contentType(field.contentType.fullType)
)
}

private def stream(serverRequest: ServerRequest): Stream[Throwable, Byte] =
zioHttpRequest(serverRequest).body.asStream
override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream =
limitedZStream(zStream(serverRequest), maxBytes).asInstanceOf[streams.BinaryStream]

private def zRequest(serverRequest: ServerRequest) = serverRequest.underlying.asInstanceOf[Request]

private def limitedZStream(stream: ZStream[Any, Throwable, Byte], maxBytes: Option[Long]) = {
maxBytes.map(ZioStreams.limitBytes(stream, _)).getOrElse(stream)
}

private def zioHttpRequest(serverRequest: ServerRequest) = serverRequest.underlying.asInstanceOf[Request]
private def zStream(serverRequest: ServerRequest) = zRequest(serverRequest).body.asStream
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sttp.tapir.server.ziohttp

import zio.stream.ZStream
import zio.Chunk
import zio.http.FormField

sealed trait ZioHttpResponseBody {
def contentLength: Option[Long]
Expand All @@ -10,3 +11,7 @@ sealed trait ZioHttpResponseBody {
case class ZioStreamHttpResponseBody(stream: ZStream[Any, Throwable, Byte], contentLength: Option[Long]) extends ZioHttpResponseBody

case class ZioRawHttpResponseBody(bytes: Chunk[Byte], contentLength: Option[Long]) extends ZioHttpResponseBody

case class ZioMultipartHttpResponseBody(formFields: List[FormField]) extends ZioHttpResponseBody {
override def contentLength: Option[Long] = None
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ package sttp.tapir.server.ziohttp

import sttp.capabilities.zio.ZioStreams
import sttp.model.HasHeaders
import sttp.model.Part
import sttp.tapir.server.interpreter.ToResponseBody
import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput}
import sttp.tapir.{CodecFormat, RawBodyType, RawPart, WebSocketBodyOutput}
import zio.Chunk
import zio.http.FormField
import zio.http.MediaType
import zio.stream.ZStream

import java.io.InputStream
import java.nio.ByteBuffer
import java.nio.charset.Charset

Expand Down Expand Up @@ -74,6 +78,59 @@ class ZioHttpToResponseBody extends ToResponseBody[ZioResponseBody, ZioStreams]
}
}
.getOrElse(ZioStreamHttpResponseBody(ZStream.fromPath(tapirFile.file.toPath), Some(tapirFile.file.length)))
case RawBodyType.MultipartBody(_, _) => throw new UnsupportedOperationException("Multipart is not supported")
case m @ RawBodyType.MultipartBody(_, _) =>
val formFields = (r: Seq[RawPart]).flatMap { part =>
m.partType(part.name).map { partType =>
toFormField(partType.asInstanceOf[RawBodyType[Any]], part)
}
}.toList
ZioMultipartHttpResponseBody(formFields)
}

private def toFormField[R](bodyType: RawBodyType[R], part: Part[R]): FormField = {
val mediaType: Option[MediaType] = part.contentType.flatMap(MediaType.forContentType)
bodyType match {
case RawBodyType.StringBody(_) =>
FormField.Text(part.name, part.body, mediaType.getOrElse(MediaType.text.plain), part.fileName)
case RawBodyType.ByteArrayBody =>
FormField.Binary(
part.name,
Chunk.fromArray(part.body),
mediaType.getOrElse(MediaType.application.`octet-stream`),
filename = part.fileName
)
case RawBodyType.ByteBufferBody =>
val array: Array[Byte] = new Array[Byte](part.body.remaining)
part.body.get(array)
FormField.Binary(
part.name,
Chunk.fromArray(array),
mediaType.getOrElse(MediaType.application.`octet-stream`),
filename = part.fileName
)
case RawBodyType.FileBody =>
FormField.streamingBinaryField(
part.name,
ZStream.fromFile(part.body.file).orDie,
mediaType.getOrElse(MediaType.application.`octet-stream`),
filename = part.fileName
)
case RawBodyType.InputStreamBody =>
FormField.streamingBinaryField(
part.name,
ZStream.fromInputStream(part.body).orDie,
mediaType.getOrElse(MediaType.application.`octet-stream`),
filename = part.fileName
)
case RawBodyType.InputStreamRangeBody =>
FormField.streamingBinaryField(
part.name,
ZStream.fromInputStream(part.body.inputStream()).orDie,
mediaType.getOrElse(MediaType.application.`octet-stream`),
filename = part.fileName
)
case _: RawBodyType.MultipartBody =>
throw new UnsupportedOperationException("Nested multipart messages are not supported.")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,10 @@ class ZioHttpServerTest extends TestSuite {
interpreter,
backend,
basic = false,
staticContent = true,
multipart = false,
file = true,
options = false
).tests() ++
new ServerMultipartTests(createServerTest, partOtherHeaderSupport = false).tests() ++
new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) ++
new ZioHttpCompositionTest(createServerTest).tests() ++
new ServerWebSocketTests(
Expand Down

0 comments on commit 9a3b92a

Please sign in to comment.