diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala index fe9e503324..230328f725 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala @@ -2,12 +2,21 @@ package sttp.tapir.server.ziohttp import sttp.capabilities import sttp.capabilities.zio.ZioStreams +import sttp.model.Part +import sttp.model.Part.FileNameDispositionParam +import sttp.tapir.FileRange +import sttp.tapir.InputStreamRange +import sttp.tapir.RawBodyType import sttp.tapir.model.ServerRequest -import sttp.tapir.server.interpreter.{RawValue, RequestBody} -import sttp.tapir.{FileRange, InputStreamRange, RawBodyType} +import sttp.tapir.server.interpreter.RawValue +import sttp.tapir.server.interpreter.RequestBody +import zio.RIO +import zio.Task +import zio.http.FormField +import zio.http.FormField.StreamingBinary import zio.http.Request -import zio.stream.{Stream, ZSink, ZStream} -import zio.{RIO, Task, ZIO} +import zio.stream.ZSink +import zio.stream.ZStream import java.io.ByteArrayInputStream import java.nio.ByteBuffer @@ -15,10 +24,16 @@ import java.nio.ByteBuffer class ZioHttpRequestBody[R](serverOptions: ZioHttpServerOptions[R]) extends RequestBody[RIO[R, *], ZioStreams] { override val streams: capabilities.Streams[ZioStreams] = ZioStreams - override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Task[RawValue[RAW]] = { + override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Task[RawValue[RAW]] = + toRaw(serverRequest, zStream(serverRequest), bodyType, maxBytes) - def asByteArray: Task[Array[Byte]] = - (toStream(serverRequest, maxBytes).asInstanceOf[ZStream[Any, Throwable, Byte]]).runCollect.map(_.toArray) + private def toRaw[RAW]( + serverRequest: ServerRequest, + stream: ZStream[Any, Throwable, Byte], + bodyType: RawBodyType[RAW], + maxBytes: Option[Long] + ): Task[RawValue[RAW]] = { + def asByteArray: Task[Array[Byte]] = maxBytes.map(ZioStreams.limitBytes(stream, _)).getOrElse(stream).runCollect.map(_.toArray) bodyType match { case RawBodyType.StringBody(defaultCharset) => asByteArray.map(new String(_, defaultCharset)).map(RawValue(_)) @@ -26,23 +41,51 @@ class ZioHttpRequestBody[R](serverOptions: ZioHttpServerOptions[R]) extends Requ case RawBodyType.ByteBufferBody => asByteArray.map(bytes => ByteBuffer.wrap(bytes)).map(RawValue(_)) case RawBodyType.InputStreamBody => asByteArray.map(new ByteArrayInputStream(_)).map(RawValue(_)) case RawBodyType.InputStreamRangeBody => - asByteArray.map(bytes => new InputStreamRange(() => new ByteArrayInputStream(bytes))).map(RawValue(_)) + asByteArray.map(bytes => InputStreamRange(() => new ByteArrayInputStream(bytes))).map(RawValue(_)) case RawBodyType.FileBody => for { file <- serverOptions.createFile(serverRequest) - _ <- (toStream(serverRequest, maxBytes).asInstanceOf[ZStream[Any, Throwable, Byte]]).run(ZSink.fromFile(file)).map(_ => ()) + _ <- limitedZStream(serverRequest, maxBytes).run(ZSink.fromFile(file)).unit } yield RawValue(FileRange(file), Seq(FileRange(file))) - case RawBodyType.MultipartBody(_, _) => ZIO.fail(new UnsupportedOperationException("Multipart is not supported")) + case m: RawBodyType.MultipartBody => handleMultipartBody(serverRequest, m) } } - override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { - val inputStream = stream(serverRequest) - maxBytes.map(ZioStreams.limitBytes(inputStream, _)).getOrElse(inputStream).asInstanceOf[streams.BinaryStream] + private def handleMultipartBody[RAW](serverRequest: ServerRequest, bodyType: RawBodyType.MultipartBody): Task[RawValue[RAW]] = + zRequest(serverRequest).body.asMultipartFormStream + .flatMap(streamingForm => + streamingForm.fields + .flatMap(field => ZStream.fromIterable(bodyType.partType(field.name).map((field, _)))) + .mapZIO { case (field, bodyType) => toRawPart(serverRequest, field, bodyType) } + .runCollect + .map(RawValue.fromParts(_).asInstanceOf[RawValue[RAW]]) + ) + + private def toRawPart[A](serverRequest: ServerRequest, field: FormField, bodyType: RawBodyType[A]): Task[Part[A]] = { + val fieldsStream = field match { + case StreamingBinary(_, _, _, _, s) => s + case _ => ZStream.fromIterableZIO(field.asChunk) + } + toRaw(serverRequest, fieldsStream, bodyType, None) + .map(raw => + Part( + field.name, + raw.value, + otherDispositionParams = field.filename.map(name => Map(FileNameDispositionParam -> name)).getOrElse(Map.empty) + ).contentType(field.contentType.toString) + ) + } + + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = + limitedZStream(serverRequest, maxBytes).asInstanceOf[streams.BinaryStream] + + private def zRequest(serverRequest: ServerRequest) = serverRequest.underlying.asInstanceOf[Request] + + private def limitedZStream(serverRequest: ServerRequest, maxBytes: Option[Long]) = { + val stream = zStream(serverRequest) + maxBytes.map(ZioStreams.limitBytes(stream, _)).getOrElse(stream) } - private def stream(serverRequest: ServerRequest): Stream[Throwable, Byte] = - zioHttpRequest(serverRequest).body.asStream + private def zStream(serverRequest: ServerRequest) = zRequest(serverRequest).body.asStream - private def zioHttpRequest(serverRequest: ServerRequest) = serverRequest.underlying.asInstanceOf[Request] }