diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/NettyResponse.scala b/zio-http/jvm/src/main/scala/zio/http/netty/NettyResponse.scala index 04455169a0..9e68e07a1b 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/NettyResponse.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/NettyResponse.scala @@ -55,6 +55,7 @@ object NettyResponse { onComplete.unsafe.done(Exit.succeed(ChannelState.forStatus(status))) Response(status, headers, Body.empty) } else { + val contentType = headers.get(Header.ContentType) val responseHandler = new ClientResponseStreamHandler(onComplete, keepAlive, status) ctx .pipeline() @@ -64,7 +65,11 @@ object NettyResponse { responseHandler, ): Unit - val data = NettyBody.fromAsync(callback => responseHandler.connect(callback), knownContentLength) + val data = NettyBody.fromAsync( + callback => responseHandler.connect(callback), + knownContentLength, + contentType.map(_.renderedValue), + ) Response(status, headers, data) } } diff --git a/zio-http/jvm/src/test/scala/zio/http/netty/NettyStreamBodySpec.scala b/zio-http/jvm/src/test/scala/zio/http/netty/NettyStreamBodySpec.scala index ad4acef187..5b0e097866 100644 --- a/zio-http/jvm/src/test/scala/zio/http/netty/NettyStreamBodySpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/netty/NettyStreamBodySpec.scala @@ -4,12 +4,14 @@ import zio._ import zio.test.TestAspect.withLiveClock import zio.test.{Spec, TestEnvironment, assertTrue} -import zio.stream.{ZStream, ZStreamAspect} +import zio.stream.{ZPipeline, ZStream, ZStreamAspect} import zio.http.ZClient.Config import zio.http._ import zio.http.internal.HttpRunnableSpec +import zio.http.multipart.mixed.MultipartMixed import zio.http.netty.NettyConfig.LeakDetectionLevel +import zio.http.netty.NettyStreamBodySpec.app object NettyStreamBodySpec extends HttpRunnableSpec { @@ -101,6 +103,82 @@ object NettyStreamBodySpec extends HttpRunnableSpec { ) } }, + test("properly decodes body's boundary") { + def trackablePart(content: String): ZIO[Any, Nothing, (MultipartMixed.Part, Promise[Nothing, Boolean])] = { + zio.Promise.make[Nothing, Boolean].map { p => + MultipartMixed.Part( + Headers(Header.ContentType(MediaType.text.`plain`)), + ZStream(content) + .via(ZPipeline.utf8Encode) + .ensuring(p.succeed(true)), + ) -> + p + } + } + def trackableMultipartMixed( + b: Boundary, + )(partsContents: String*): ZIO[Any, Nothing, (MultipartMixed, Seq[Promise[Nothing, Boolean]])] = { + ZIO + .foreach(partsContents)(trackablePart) + .map { tps => + val (parts, promisises) = tps.unzip + val mpm = MultipartMixed.fromParts(ZStream.fromIterable(parts), b, 1) + (mpm, promisises) + } + } + + def serve(resp: Response): ZIO[Any, Throwable, RuntimeFlags] = { + val app = Routes(Method.GET / "it" -> handler(resp)) + for { + portPromise <- Promise.make[Throwable, Int] + _ <- Server + .install(app) + .intoPromise(portPromise) + .zipRight(ZIO.never) + .provide( + ZLayer.succeed(NettyConfig.defaultWithFastShutdown.leakDetection(LeakDetectionLevel.PARANOID)), + ZLayer.succeed(Server.Config.default.onAnyOpenPort), + Server.customized, + ) + .fork + port <- portPromise.await + } yield port + } + + for { + mpmAndPromises <- trackableMultipartMixed(Boundary("this_is_a_boundary"))( + "this is the boring part 1", + "and this is the boring part two", + ) + (mpm, promises) = mpmAndPromises + resp = Response(body = + Body.fromStreamChunked(mpm.source).contentType(MediaType.multipart.`mixed`, mpm.boundary), + ) + .addHeader(Header.ContentType(MediaType.multipart.`mixed`, Some(mpm.boundary))) + port <- serve(resp) + client <- ZIO.service[Client] + req = Request.get(s"http://localhost:$port/it") + actualResp <- client(req) + actualMpm <- actualResp.body.asMultipartMixed + partsResults <- actualMpm.parts.zipWithIndex.mapZIO { case (part, idx) => + val pr = promises(idx.toInt) + // todo: due to server side buffering can't really expect the promises to be uncompleted BEFORE pulling on the client side + part.toBody.asString <*> + pr.isDone + }.runCollect + } yield { + zio.test.assertTrue { + actualResp.headers(Header.ContentType) == resp.headers(Header.ContentType) && + actualResp.body.boundary == Some(mpm.boundary) && + actualMpm.boundary == mpm.boundary && + partsResults == Chunk( + // todo: due to server side buffering can't really expect the promises to be uncompleted BEFORE pulling on the client side + ("this is the boring part 1", true), + ("and this is the boring part two", true), + ) + } + } + }, ).provide( singleConnectionClient, Scope.default,