Skip to content

Commit

Permalink
feat: Add multipart body support with zio-http
Browse files Browse the repository at this point in the history
  • Loading branch information
seakayone committed Apr 16, 2024
1 parent 01fb95f commit 2090a02
Showing 1 changed file with 59 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,90 @@ package sttp.tapir.server.ziohttp

import sttp.capabilities
import sttp.capabilities.zio.ZioStreams
import sttp.model.Part
import sttp.model.Part.FileNameDispositionParam
import sttp.tapir.FileRange
import sttp.tapir.InputStreamRange
import sttp.tapir.RawBodyType
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.interpreter.{RawValue, RequestBody}
import sttp.tapir.{FileRange, InputStreamRange, RawBodyType}
import sttp.tapir.server.interpreter.RawValue
import sttp.tapir.server.interpreter.RequestBody
import zio.RIO
import zio.Task
import zio.http.FormField
import zio.http.FormField.StreamingBinary
import zio.http.Request
import zio.stream.{Stream, ZSink, ZStream}
import zio.{RIO, Task, ZIO}
import zio.stream.ZSink
import zio.stream.ZStream

import java.io.ByteArrayInputStream
import java.nio.ByteBuffer

class ZioHttpRequestBody[R](serverOptions: ZioHttpServerOptions[R]) extends RequestBody[RIO[R, *], ZioStreams] {
override val streams: capabilities.Streams[ZioStreams] = ZioStreams

override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Task[RawValue[RAW]] = {
override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Task[RawValue[RAW]] =
toRaw(serverRequest, zStream(serverRequest), bodyType, maxBytes)

def asByteArray: Task[Array[Byte]] =
(toStream(serverRequest, maxBytes).asInstanceOf[ZStream[Any, Throwable, Byte]]).runCollect.map(_.toArray)
private def toRaw[RAW](
serverRequest: ServerRequest,
stream: ZStream[Any, Throwable, Byte],
bodyType: RawBodyType[RAW],
maxBytes: Option[Long]
): Task[RawValue[RAW]] = {
def asByteArray: Task[Array[Byte]] = maxBytes.map(ZioStreams.limitBytes(stream, _)).getOrElse(stream).runCollect.map(_.toArray)

bodyType match {
case RawBodyType.StringBody(defaultCharset) => asByteArray.map(new String(_, defaultCharset)).map(RawValue(_))
case RawBodyType.ByteArrayBody => asByteArray.map(RawValue(_))
case RawBodyType.ByteBufferBody => asByteArray.map(bytes => ByteBuffer.wrap(bytes)).map(RawValue(_))
case RawBodyType.InputStreamBody => asByteArray.map(new ByteArrayInputStream(_)).map(RawValue(_))
case RawBodyType.InputStreamRangeBody =>
asByteArray.map(bytes => new InputStreamRange(() => new ByteArrayInputStream(bytes))).map(RawValue(_))
asByteArray.map(bytes => InputStreamRange(() => new ByteArrayInputStream(bytes))).map(RawValue(_))
case RawBodyType.FileBody =>
for {
file <- serverOptions.createFile(serverRequest)
_ <- (toStream(serverRequest, maxBytes).asInstanceOf[ZStream[Any, Throwable, Byte]]).run(ZSink.fromFile(file)).map(_ => ())
_ <- limitedZStream(serverRequest, maxBytes).run(ZSink.fromFile(file)).unit
} yield RawValue(FileRange(file), Seq(FileRange(file)))
case RawBodyType.MultipartBody(_, _) => ZIO.fail(new UnsupportedOperationException("Multipart is not supported"))
case m: RawBodyType.MultipartBody => handleMultipartBody(serverRequest, m)
}
}

override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = {
val inputStream = stream(serverRequest)
maxBytes.map(ZioStreams.limitBytes(inputStream, _)).getOrElse(inputStream).asInstanceOf[streams.BinaryStream]
private def handleMultipartBody[RAW](serverRequest: ServerRequest, bodyType: RawBodyType.MultipartBody): Task[RawValue[RAW]] =
zRequest(serverRequest).body.asMultipartFormStream
.flatMap(streamingForm =>
streamingForm.fields
.flatMap(field => ZStream.fromIterable(bodyType.partType(field.name).map((field, _))))
.mapZIO { case (field, bodyType) => toRawPart(serverRequest, field, bodyType) }
.runCollect
.map(RawValue.fromParts(_).asInstanceOf[RawValue[RAW]])
)

private def toRawPart[A](serverRequest: ServerRequest, field: FormField, bodyType: RawBodyType[A]): Task[Part[A]] = {
val fieldsStream = field match {
case StreamingBinary(_, _, _, _, s) => s
case _ => ZStream.fromIterableZIO(field.asChunk)
}
toRaw(serverRequest, fieldsStream, bodyType, None)
.map(raw =>
Part(
field.name,
raw.value,
otherDispositionParams = field.filename.map(name => Map(FileNameDispositionParam -> name)).getOrElse(Map.empty)
).contentType(field.contentType.toString)
)
}

override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream =
limitedZStream(serverRequest, maxBytes).asInstanceOf[streams.BinaryStream]

private def zRequest(serverRequest: ServerRequest) = serverRequest.underlying.asInstanceOf[Request]

private def limitedZStream(serverRequest: ServerRequest, maxBytes: Option[Long]) = {
val stream = zStream(serverRequest)
maxBytes.map(ZioStreams.limitBytes(stream, _)).getOrElse(stream)
}

private def stream(serverRequest: ServerRequest): Stream[Throwable, Byte] =
zioHttpRequest(serverRequest).body.asStream
private def zStream(serverRequest: ServerRequest) = zRequest(serverRequest).body.asStream

private def zioHttpRequest(serverRequest: ServerRequest) = serverRequest.underlying.asInstanceOf[Request]
}

0 comments on commit 2090a02

Please sign in to comment.