diff --git a/akka-http-backend/src/main/scala/sttp/client4/akkahttp/AkkaCompressor.scala b/akka-http-backend/src/main/scala/sttp/client4/akkahttp/AkkaCompressor.scala new file mode 100644 index 0000000000..f00accde65 --- /dev/null +++ b/akka-http-backend/src/main/scala/sttp/client4/akkahttp/AkkaCompressor.scala @@ -0,0 +1,32 @@ +package sttp.client4.akkahttp + +import akka.util.ByteString +import akka.stream.scaladsl.Compression +import sttp.capabilities.akka.AkkaStreams +import akka.stream.scaladsl.Source +import sttp.client4._ +import sttp.client4.compression.DeflateDefaultCompressor +import sttp.client4.compression.GZipDefaultCompressor +import sttp.client4.compression.Compressor +import akka.stream.scaladsl.StreamConverters +import akka.stream.scaladsl.FileIO + +trait AkkaCompressor extends Compressor[AkkaStreams] { + override abstract def apply[R2 <: AkkaStreams](body: GenericRequestBody[R2]): GenericRequestBody[AkkaStreams] = + body match { + case InputStreamBody(b, _) => StreamBody(AkkaStreams)(compressStream(StreamConverters.fromInputStream(() => b))) + case StreamBody(b) => StreamBody(AkkaStreams)(compressStream(b.asInstanceOf[Source[ByteString, Any]])) + case FileBody(f, _) => StreamBody(AkkaStreams)(compressStream(FileIO.fromPath(f.toPath))) + case _ => super.apply(body) + } + + def compressStream(stream: Source[ByteString, Any]): Source[ByteString, Any] +} + +object GZipAkkaCompressor extends GZipDefaultCompressor[AkkaStreams] with AkkaCompressor { + def compressStream(stream: Source[ByteString, Any]): Source[ByteString, Any] = stream.via(Compression.gzip) +} + +object DeflateAkkaCompressor extends DeflateDefaultCompressor[AkkaStreams] with AkkaCompressor { + def compressStream(stream: Source[ByteString, Any]): Source[ByteString, Any] = stream.via(Compression.deflate) +} diff --git a/akka-http-backend/src/main/scala/sttp/client4/akkahttp/AkkaHttpBackend.scala b/akka-http-backend/src/main/scala/sttp/client4/akkahttp/AkkaHttpBackend.scala index 20af82596c..fb109bcc7d 100644 --- a/akka-http-backend/src/main/scala/sttp/client4/akkahttp/AkkaHttpBackend.scala +++ b/akka-http-backend/src/main/scala/sttp/client4/akkahttp/AkkaHttpBackend.scala @@ -1,10 +1,8 @@ package sttp.client4.akkahttp -import java.io.UnsupportedEncodingException import akka.{Done, NotUsed} import akka.actor.{ActorSystem, CoordinatedShutdown} import akka.event.LoggingAdapter -import akka.http.scaladsl.coding.Coders import akka.http.scaladsl.model.headers.{BasicHttpCredentials, HttpEncoding, HttpEncodings} import akka.http.scaladsl.model.ws.{InvalidUpgradeResponse, Message, ValidUpgrade, WebSocketRequest} import akka.http.scaladsl.model.{StatusCode => _, _} @@ -14,13 +12,13 @@ import akka.stream.Materializer import akka.stream.scaladsl.{Flow, Sink} import sttp.capabilities.akka.AkkaStreams import sttp.capabilities.{Effect, WebSockets} -import sttp.client4 -import sttp.client4.akkahttp.AkkaHttpBackend.EncodingHandler import sttp.client4.testing.WebSocketStreamBackendStub import sttp.client4._ import sttp.client4.wrappers.FollowRedirectsBackend import sttp.model.{ResponseMetadata, StatusCode} import sttp.monad.{FutureMonad, MonadError} +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Decompressor import scala.concurrent.{ExecutionContext, Future, Promise} @@ -34,7 +32,7 @@ class AkkaHttpBackend private ( customizeRequest: HttpRequest => HttpRequest, customizeWebsocketRequest: WebSocketRequest => WebSocketRequest, customizeResponse: (HttpRequest, HttpResponse) => HttpResponse, - customEncodingHandler: EncodingHandler + compressionHandlers: CompressionHandlers[AkkaStreams, HttpResponse] ) extends WebSocketStreamBackend[Future, AkkaStreams] { type R = AkkaStreams with WebSockets with Effect[Future] @@ -52,7 +50,7 @@ class AkkaHttpBackend private ( private def sendRegular[T](r: GenericRequest[T, R]): Future[Response[T]] = Future - .fromTry(ToAkka.request(r).flatMap(BodyToAkka(r, r.body, _))) + .fromTry(ToAkka.request(r).flatMap(BodyToAkka(r, _, compressionHandlers.compressors))) .map(customizeRequest) .flatMap(request => http @@ -133,23 +131,21 @@ class AkkaHttpBackend private ( val body = bodyFromAkka( r.response, responseMetadata, - wsFlow.map(Right(_)).getOrElse(Left(decodeAkkaResponse(hr, r.autoDecompressionDisabled))) + wsFlow.map(Right(_)).getOrElse(Left(decodeAkkaResponse(hr, r.autoDecompressionEnabled))) ) - body.map(client4.Response(_, code, statusText, headers, Nil, r.onlyMetadata)) + body.map(sttp.client4.Response(_, code, statusText, headers, Nil, r.onlyMetadata)) } // http://doc.akka.io/docs/akka-http/10.0.7/scala/http/common/de-coding.html - private def decodeAkkaResponse(response: HttpResponse, disableAutoDecompression: Boolean): HttpResponse = - if (!response.status.allowsEntity() || disableAutoDecompression) response - else customEncodingHandler.orElse(EncodingHandler(standardEncoding)).apply(response -> response.encoding) - - private def standardEncoding: (HttpResponse, HttpEncoding) => HttpResponse = { - case (body, HttpEncodings.gzip) => Coders.Gzip.decodeMessage(body) - case (body, HttpEncodings.deflate) => Coders.Deflate.decodeMessage(body) - case (body, HttpEncodings.identity) => Coders.NoCoding.decodeMessage(body) - case (_, ce) => throw new UnsupportedEncodingException(s"Unsupported encoding: $ce") - } + private def decodeAkkaResponse(response: HttpResponse, autoDecompressionEnabled: Boolean): HttpResponse = + if (!response.status.allowsEntity() || !autoDecompressionEnabled) response + else + response.encoding match { + case HttpEncodings.identity => response + case encoding: HttpEncoding => + Decompressor.decompressIfPossible(response, encoding.value, compressionHandlers.decompressors) + } private def adjustExceptions[T](request: GenericRequest[_, _])(t: => Future[T]): Future[T] = SttpClientException.adjustExceptions(monad)(t)(FromAkka.exception(request, _)) @@ -165,12 +161,11 @@ class AkkaHttpBackend private ( } object AkkaHttpBackend { - type EncodingHandler = PartialFunction[(HttpResponse, HttpEncoding), HttpResponse] - object EncodingHandler { - def apply(f: (HttpResponse, HttpEncoding) => HttpResponse): EncodingHandler = { case (body, encoding) => - f(body, encoding) - } - } + val DefaultCompressionHandlers: CompressionHandlers[AkkaStreams, HttpResponse] = + CompressionHandlers( + List(GZipAkkaCompressor, DeflateAkkaCompressor), + List(GZipAkkaDecompressor, DeflateAkkaDecompressor) + ) private def make( actorSystem: ActorSystem, @@ -182,7 +177,7 @@ object AkkaHttpBackend { customizeRequest: HttpRequest => HttpRequest, customizeWebsocketRequest: WebSocketRequest => WebSocketRequest = identity, customizeResponse: (HttpRequest, HttpResponse) => HttpResponse = (_, r) => r, - customEncodingHandler: EncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[AkkaStreams, HttpResponse] ): WebSocketStreamBackend[Future, AkkaStreams] = FollowRedirectsBackend( new AkkaHttpBackend( @@ -195,7 +190,7 @@ object AkkaHttpBackend { customizeRequest, customizeWebsocketRequest, customizeResponse, - customEncodingHandler + compressionHandlers ) ) @@ -211,7 +206,7 @@ object AkkaHttpBackend { customizeRequest: HttpRequest => HttpRequest = identity, customizeWebsocketRequest: WebSocketRequest => WebSocketRequest = identity, customizeResponse: (HttpRequest, HttpResponse) => HttpResponse = (_, r) => r, - customEncodingHandler: EncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[AkkaStreams, HttpResponse] = DefaultCompressionHandlers )(implicit ec: Option[ExecutionContext] = None ): WebSocketStreamBackend[Future, AkkaStreams] = { @@ -227,7 +222,7 @@ object AkkaHttpBackend { customizeRequest, customizeWebsocketRequest, customizeResponse, - customEncodingHandler + compressionHandlers ) } @@ -246,7 +241,7 @@ object AkkaHttpBackend { customizeRequest: HttpRequest => HttpRequest = identity, customizeWebsocketRequest: WebSocketRequest => WebSocketRequest = identity, customizeResponse: (HttpRequest, HttpResponse) => HttpResponse = (_, r) => r, - customEncodingHandler: EncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[AkkaStreams, HttpResponse] = DefaultCompressionHandlers )(implicit ec: Option[ExecutionContext] = None ): WebSocketStreamBackend[Future, AkkaStreams] = @@ -258,7 +253,7 @@ object AkkaHttpBackend { customizeRequest, customizeWebsocketRequest, customizeResponse, - customEncodingHandler + compressionHandlers ) /** @param actorSystem @@ -275,7 +270,7 @@ object AkkaHttpBackend { customizeRequest: HttpRequest => HttpRequest = identity, customizeWebsocketRequest: WebSocketRequest => WebSocketRequest = identity, customizeResponse: (HttpRequest, HttpResponse) => HttpResponse = (_, r) => r, - customEncodingHandler: EncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[AkkaStreams, HttpResponse] = DefaultCompressionHandlers )(implicit ec: Option[ExecutionContext] = None ): WebSocketStreamBackend[Future, AkkaStreams] = @@ -289,7 +284,7 @@ object AkkaHttpBackend { customizeRequest, customizeWebsocketRequest, customizeResponse, - customEncodingHandler + compressionHandlers ) /** Create a stub backend for testing, which uses the [[Future]] response wrapper, and doesn't support streaming. diff --git a/akka-http-backend/src/main/scala/sttp/client4/akkahttp/BodyToAkka.scala b/akka-http-backend/src/main/scala/sttp/client4/akkahttp/BodyToAkka.scala index e3e1ae499c..2a0edef50c 100644 --- a/akka-http-backend/src/main/scala/sttp/client4/akkahttp/BodyToAkka.scala +++ b/akka-http-backend/src/main/scala/sttp/client4/akkahttp/BodyToAkka.scala @@ -12,18 +12,18 @@ import akka.http.scaladsl.model.{ import akka.stream.scaladsl.{Source, StreamConverters} import akka.util.ByteString import sttp.capabilities.akka.AkkaStreams -import sttp.client4.internal.throwNestedMultipartNotAllowed import sttp.client4._ -import sttp.model.{HeaderNames, Part} +import sttp.model.Part import scala.collection.immutable.Seq import scala.util.{Failure, Success, Try} +import sttp.client4.compression.Compressor private[akkahttp] object BodyToAkka { def apply[R]( r: GenericRequest[_, R], - body: GenericRequestBody[R], - ar: HttpRequest + ar: HttpRequest, + compressors: List[Compressor[R]] ): Try[HttpRequest] = { def ctWithCharset(ct: ContentType, charset: String) = HttpCharsets @@ -31,7 +31,7 @@ private[akkahttp] object BodyToAkka { .map(hc => ContentType.apply(ct.mediaType, () => hc)) .getOrElse(ct) - def contentLength = r.headers.find(_.is(HeaderNames.ContentLength)).flatMap(h => Try(h.value.toLong).toOption) + val (body, contentLength) = Compressor.compressIfNeeded(r, compressors) def toBodyPart(mp: Part[BodyPart[_]]): Try[AkkaMultipart.FormData.BodyPart] = { def streamPartEntity(contentType: ContentType, s: AkkaStreams.BinaryStream) = diff --git a/akka-http-backend/src/main/scala/sttp/client4/akkahttp/akkaDecompressors.scala b/akka-http-backend/src/main/scala/sttp/client4/akkahttp/akkaDecompressors.scala new file mode 100644 index 0000000000..a88eb8b13d --- /dev/null +++ b/akka-http-backend/src/main/scala/sttp/client4/akkahttp/akkaDecompressors.scala @@ -0,0 +1,16 @@ +package sttp.client4.akkahttp + +import sttp.client4.compression.Decompressor +import sttp.model.Encodings +import akka.http.scaladsl.coding.Coders +import akka.http.scaladsl.model.HttpResponse + +object GZipAkkaDecompressor extends Decompressor[HttpResponse] { + override val encoding: String = Encodings.Gzip + override def apply(body: HttpResponse): HttpResponse = Coders.Gzip.decodeMessage(body) +} + +object DeflateAkkaDecompressor extends Decompressor[HttpResponse] { + override val encoding: String = Encodings.Deflate + override def apply(body: HttpResponse): HttpResponse = Coders.Deflate.decodeMessage(body) +} diff --git a/armeria-backend/fs2/src/main/scala/sttp/client4/armeria/fs2/ArmeriaFs2Backend.scala b/armeria-backend/fs2/src/main/scala/sttp/client4/armeria/fs2/ArmeriaFs2Backend.scala index 2a793702ee..3796966974 100644 --- a/armeria-backend/fs2/src/main/scala/sttp/client4/armeria/fs2/ArmeriaFs2Backend.scala +++ b/armeria-backend/fs2/src/main/scala/sttp/client4/armeria/fs2/ArmeriaFs2Backend.scala @@ -12,9 +12,11 @@ import sttp.capabilities.fs2.Fs2Streams import sttp.client4.armeria.ArmeriaWebClient.newClient import sttp.client4.armeria.{AbstractArmeriaBackend, BodyFromStreamMessage} import sttp.client4.impl.cats.CatsMonadAsyncError -import sttp.client4.wrappers.FollowRedirectsBackend import sttp.client4.{wrappers, BackendOptions, StreamBackend} import sttp.monad.MonadAsyncError +import sttp.client4.compression.Compressor +import sttp.client4.impl.fs2.DeflateFs2Compressor +import sttp.client4.impl.fs2.GZipFs2Compressor private final class ArmeriaFs2Backend[F[_]: Async](client: WebClient, closeFactory: Boolean, dispatcher: Dispatcher[F]) extends AbstractArmeriaBackend[F, Fs2Streams[F]](client, closeFactory, new CatsMonadAsyncError) { @@ -41,6 +43,9 @@ private final class ArmeriaFs2Backend[F[_]: Async](client: WebClient, closeFacto }, dispatcher ) + + override protected def compressors: List[Compressor[R]] = + List(new GZipFs2Compressor[F, R](), new DeflateFs2Compressor[F, R]()) } object ArmeriaFs2Backend { diff --git a/armeria-backend/src/main/scala/sttp/client4/armeria/AbstractArmeriaBackend.scala b/armeria-backend/src/main/scala/sttp/client4/armeria/AbstractArmeriaBackend.scala index 5e500fdf17..815c601c36 100644 --- a/armeria-backend/src/main/scala/sttp/client4/armeria/AbstractArmeriaBackend.scala +++ b/armeria-backend/src/main/scala/sttp/client4/armeria/AbstractArmeriaBackend.scala @@ -39,6 +39,7 @@ import sttp.client4.internal.toByteArray import sttp.model._ import sttp.monad.syntax._ import sttp.monad.{Canceler, MonadAsyncError} +import sttp.client4.compression.Compressor abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]]( client: WebClient = WebClient.of(), @@ -54,6 +55,8 @@ abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]]( protected def streamToPublisher(stream: streams.BinaryStream): Publisher[HttpData] + protected def compressors: List[Compressor[R]] = Compressor.default[R] + override def send[T](request: GenericRequest[T, R]): F[Response[T]] = monad.suspend(adjustExceptions(request)(execute(request))) @@ -87,7 +90,7 @@ abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]]( } finally captor.close() } - private def requestToArmeria(request: GenericRequest[_, Nothing]): WebClientRequestPreparation = { + private def requestToArmeria(request: GenericRequest[_, R]): WebClientRequestPreparation = { val requestPreparation = client .prepare() .disablePathParams() @@ -102,19 +105,24 @@ abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]]( requestPreparation.responseTimeoutMillis(Long.MaxValue) } + val (body, contentLength) = Compressor.compressIfNeeded(request, compressors) + var customContentType: Option[ArmeriaMediaType] = None request.headers.foreach { header => - if (header.name.equalsIgnoreCase(HeaderNames.ContentType)) { + if (header.is(HeaderNames.ContentType)) { // A Content-Type will be set with the body content customContentType = Some(ArmeriaMediaType.parse(header.value)) - } else { - requestPreparation.header(header.name, header.value) + } else if (!header.is(HeaderNames.ContentLength)) { + val _ = requestPreparation.header(header.name, header.value) } } + contentLength.foreach { cl => + requestPreparation.header(HeaderNames.ContentLength, cl.toString) + } val contentType = customContentType.getOrElse(ArmeriaMediaType.parse(request.body.defaultContentType.toString())) - request.body match { + body match { case NoBody => requestPreparation case StringBody(s, encoding, _) => val charset = diff --git a/armeria-backend/src/main/scala/sttp/client4/armeria/future/ArmeriaFutureBackend.scala b/armeria-backend/src/main/scala/sttp/client4/armeria/future/ArmeriaFutureBackend.scala index bba0051620..3abfd526a6 100644 --- a/armeria-backend/src/main/scala/sttp/client4/armeria/future/ArmeriaFutureBackend.scala +++ b/armeria-backend/src/main/scala/sttp/client4/armeria/future/ArmeriaFutureBackend.scala @@ -8,7 +8,7 @@ import sttp.client4.armeria.ArmeriaWebClient.newClient import sttp.client4.armeria.{AbstractArmeriaBackend, BodyFromStreamMessage} import sttp.client4.internal.NoStreams import sttp.client4.wrappers.FollowRedirectsBackend -import sttp.client4.{wrappers, Backend, BackendOptions} +import sttp.client4.{Backend, BackendOptions} import sttp.monad.{FutureMonad, MonadAsyncError} import scala.concurrent.ExecutionContext.Implicits.global @@ -50,5 +50,5 @@ object ArmeriaFutureBackend { apply(newClient(), closeFactory = false) private def apply(client: WebClient, closeFactory: Boolean): Backend[Future] = - wrappers.FollowRedirectsBackend(new ArmeriaFutureBackend(client, closeFactory)) + FollowRedirectsBackend(new ArmeriaFutureBackend(client, closeFactory)) } diff --git a/armeria-backend/zio/src/main/scala/sttp/client4/armeria/zio/ArmeriaZioBackend.scala b/armeria-backend/zio/src/main/scala/sttp/client4/armeria/zio/ArmeriaZioBackend.scala index 8d863b7875..ea4806f2c5 100644 --- a/armeria-backend/zio/src/main/scala/sttp/client4/armeria/zio/ArmeriaZioBackend.scala +++ b/armeria-backend/zio/src/main/scala/sttp/client4/armeria/zio/ArmeriaZioBackend.scala @@ -13,10 +13,12 @@ import sttp.capabilities.zio.ZioStreams import sttp.client4.armeria.ArmeriaWebClient.newClient import sttp.client4.armeria.{AbstractArmeriaBackend, BodyFromStreamMessage} import sttp.client4.impl.zio.RIOMonadAsyncError -import sttp.client4.wrappers.FollowRedirectsBackend import sttp.client4.{wrappers, BackendOptions, StreamBackend} import sttp.monad.MonadAsyncError import zio.stream.Stream +import sttp.client4.compression.Compressor +import sttp.client4.impl.zio.GZipZioCompressor +import sttp.client4.impl.zio.DeflateZioCompressor private final class ArmeriaZioBackend(runtime: Runtime[Any], client: WebClient, closeFactory: Boolean) extends AbstractArmeriaBackend[Task, ZioStreams](client, closeFactory, new RIOMonadAsyncError[Any]) { @@ -40,6 +42,8 @@ private final class ArmeriaZioBackend(runtime: Runtime[Any], client: WebClient, .run(stream.mapChunks(c => Chunk.single(HttpData.wrap(c.toArray))).toPublisher) .getOrThrowFiberFailure() } + + override protected def compressors: List[Compressor[R]] = List(GZipZioCompressor, DeflateZioCompressor) } object ArmeriaZioBackend { diff --git a/build.sbt b/build.sbt index d12dd37f83..20d08fd8cf 100644 --- a/build.sbt +++ b/build.sbt @@ -153,6 +153,7 @@ val pekkoStreams = "org.apache.pekko" %% "pekko-stream" % pekkoStreamVersion val scalaTest = libraryDependencies ++= Seq("freespec", "funsuite", "flatspec", "wordspec", "shouldmatchers").map(m => "org.scalatest" %%% s"scalatest-$m" % "3.2.19" % Test ) +val scalaTestPlusScalaCheck = libraryDependencies += "org.scalatestplus" %% "scalacheck-1-18" % "3.2.19.0" % Test val zio1Version = "1.0.18" val zio2Version = "2.1.14" @@ -318,7 +319,8 @@ lazy val core = (projectMatrix in file("core")) "com.softwaremill.sttp.shared" %%% "core" % sttpSharedVersion, "com.softwaremill.sttp.shared" %%% "ws" % sttpSharedVersion ), - scalaTest + scalaTest, + scalaTestPlusScalaCheck ) .settings(testServerSettings) .jvmPlatform( diff --git a/core/src/main/scala/sttp/client4/RequestOptions.scala b/core/src/main/scala/sttp/client4/RequestOptions.scala index 78e3b46769..b5031aa50d 100644 --- a/core/src/main/scala/sttp/client4/RequestOptions.scala +++ b/core/src/main/scala/sttp/client4/RequestOptions.scala @@ -4,13 +4,26 @@ import scala.concurrent.duration.Duration import sttp.model.HttpVersion import sttp.client4.logging.LoggingOptions -/** Options for a [[Request]]. The defaults can be found on [[emptyRequest]]. */ +/** Options for a [[Request]]. The defaults can be found on [[emptyRequest]]. + * + * @param decompressResponseBody + * Should the response body be decompressed, if a `Content-Encoding` header is present. By default, backends support + * [[sttp.model.Encodings.Gzip]] and [[sttp.model.Encodings.Deflate]] encodings, but others might available as well; + * refer to the backend documentation for details. If an encoding is not supported, an exception is thrown / a failed + * effet returned, when sending the request. + * @param compressRequestBody + * Should the request body be compressed, and if so, with which encoding. By default, backends support + * [[sttp.model.Encodings.Gzip]] and [[sttp.model.Encodings.Deflate]] encodings, but others might available as well; + * refer to the backend documentation for details. If an encoding is not supported, an exception is thrown / a failed + * effet returned, when sending the request. + */ case class RequestOptions( followRedirects: Boolean, readTimeout: Duration, maxRedirects: Int, redirectToGet: Boolean, - disableAutoDecompression: Boolean, + decompressResponseBody: Boolean, + compressRequestBody: Option[String], httpVersion: Option[HttpVersion], loggingOptions: LoggingOptions ) diff --git a/core/src/main/scala/sttp/client4/SttpApi.scala b/core/src/main/scala/sttp/client4/SttpApi.scala index 368242e0a3..b6a467cc89 100644 --- a/core/src/main/scala/sttp/client4/SttpApi.scala +++ b/core/src/main/scala/sttp/client4/SttpApi.scala @@ -31,7 +31,8 @@ trait SttpApi extends SttpExtensions with UriInterpolator { DefaultReadTimeout, FollowRedirectsBackend.MaxRedirects, redirectToGet = false, - disableAutoDecompression = false, + decompressResponseBody = true, + compressRequestBody = None, httpVersion = None, loggingOptions = LoggingOptions() ), diff --git a/core/src/main/scala/sttp/client4/compression/CompressionHandlers.scala b/core/src/main/scala/sttp/client4/compression/CompressionHandlers.scala new file mode 100644 index 0000000000..fd365c1164 --- /dev/null +++ b/core/src/main/scala/sttp/client4/compression/CompressionHandlers.scala @@ -0,0 +1,24 @@ +package sttp.client4.compression + +/** Defines the compressors that might be used to compress request bodies, and decompressors that might be used to + * decompress response bodies. + * + * @tparam R + * The capabilities that the bodyies (both to compress & compressed) might use (e.g. streams). + * @tparam B + * The type of the raw body (as used by the backend) that is decompressed. + * @see + * [[sttp.client4.RequestOptions.decompressResponseBody]] + * @see + * [[sttp.client4.RequestOptions.compressRequestBody]] + */ +case class CompressionHandlers[-R, B]( + compressors: List[Compressor[R]], + decompressors: List[Decompressor[B]] +) { + def addCompressor[R2 <: R](compressors: Compressor[R2]): CompressionHandlers[R2, B] = + copy(compressors = this.compressors :+ compressors) + + def addDecompressor(decompressors: Decompressor[B]): CompressionHandlers[R, B] = + copy(decompressors = this.decompressors :+ decompressors) +} diff --git a/core/src/main/scala/sttp/client4/compression/Compressor.scala b/core/src/main/scala/sttp/client4/compression/Compressor.scala new file mode 100644 index 0000000000..a5e5b59c87 --- /dev/null +++ b/core/src/main/scala/sttp/client4/compression/Compressor.scala @@ -0,0 +1,64 @@ +package sttp.client4.compression + +import sttp.client4._ +import java.nio.ByteBuffer + +/** Allows compressing bodies, using the supported encoding. The compressed bodies might use `R` capabilities (e.g. + * streaming). + */ +trait Compressor[-R] { + def encoding: String + def apply[R2 <: R](body: GenericRequestBody[R2]): GenericRequestBody[R] +} + +object Compressor extends CompressorExtensions { + + /** Compress the request body if needed, using the given compressors. + * @return + * The optionally compressed body (if requested via request options), or the original body; and the content lenght, + * if known, of the uncompressed/compressed body. + */ + def compressIfNeeded[T, R]( + request: GenericRequest[T, R], + compressors: List[Compressor[R]] + ): (GenericRequestBody[R], Option[Long]) = + request.options.compressRequestBody match { + case Some(encoding) => + val compressedBody = compressors.find(_.encoding.equalsIgnoreCase(encoding)) match { + case Some(compressor) => compressor(request.body) + case None => throw new IllegalArgumentException(s"Unsupported encoding: $encoding") + } + + val contentLength = calculateContentLength(compressedBody) + (compressedBody, contentLength) + + case None => (request.body, request.contentLength) + } + + private def calculateContentLength[R](body: GenericRequestBody[R]): Option[Long] = body match { + case NoBody => None + case StringBody(b, e, _) => Some(b.getBytes(e).length.toLong) + case ByteArrayBody(b, _) => Some(b.length.toLong) + case ByteBufferBody(b, _) => None + case InputStreamBody(b, _) => None + case FileBody(f, _) => Some(f.length()) + case StreamBody(_) => None + case MultipartStreamBody(parts) => None + case BasicMultipartBody(parts) => None + } + + private[compression] def compressingMultipartBodiesNotSupported: Nothing = + throw new IllegalArgumentException("Multipart bodies cannot be compressed") + + private[compression] def streamsNotSupported: Nothing = + throw new IllegalArgumentException("Streams are not supported") + + private[compression] def byteBufferToArray(inputBuffer: ByteBuffer): Array[Byte] = + if (inputBuffer.hasArray()) { + inputBuffer.array() + } else { + val inputBytes = new Array[Byte](inputBuffer.remaining()) + inputBuffer.get(inputBytes) + inputBytes + } +} diff --git a/core/src/main/scala/sttp/client4/compression/Decompressor.scala b/core/src/main/scala/sttp/client4/compression/Decompressor.scala new file mode 100644 index 0000000000..0a08844bbb --- /dev/null +++ b/core/src/main/scala/sttp/client4/compression/Decompressor.scala @@ -0,0 +1,17 @@ +package sttp.client4.compression + +import java.io.UnsupportedEncodingException + +/** Allows decompressing bodies, using the supported encoding. */ +trait Decompressor[B] { + def encoding: String + def apply(body: B): B +} + +object Decompressor extends DecompressorExtensions { + def decompressIfPossible[B](b: B, encoding: String, decompressors: List[Decompressor[B]]): B = + decompressors.find(_.encoding.equalsIgnoreCase(encoding)) match { + case Some(decompressor) => decompressor(b) + case None => throw new UnsupportedEncodingException(s"Unsupported encoding: $encoding") + } +} diff --git a/core/src/main/scala/sttp/client4/internal/BodyFromResponseAs.scala b/core/src/main/scala/sttp/client4/internal/BodyFromResponseAs.scala index 8b570a8cc6..2c908ec697 100644 --- a/core/src/main/scala/sttp/client4/internal/BodyFromResponseAs.scala +++ b/core/src/main/scala/sttp/client4/internal/BodyFromResponseAs.scala @@ -26,7 +26,7 @@ abstract class BodyFromResponseAs[F[_], RegularResponse, WSResponse, Stream](imp m.eval(g(result, meta)).map((_, replayableBody)) } - case (rfm: ResponseAsFromMetadata[T, _], _) => doApply(rfm(meta), meta, response) + case (rfm: ResponseAsFromMetadata[T, _] @unchecked, _) => doApply(rfm(meta), meta, response) case (ResponseAsBoth(l, r), _) => doApply(l, meta, response).flatMap { diff --git a/core/src/main/scala/sttp/client4/request.scala b/core/src/main/scala/sttp/client4/request.scala index df7730ca43..472fa048b7 100644 --- a/core/src/main/scala/sttp/client4/request.scala +++ b/core/src/main/scala/sttp/client4/request.scala @@ -25,7 +25,7 @@ import sttp.attributes.AttributeMap * ability to send and receive streaming bodies) or [[sttp.capabilities.WebSockets]] (the ability to handle websocket * requests). */ -trait GenericRequest[+T, -R] extends RequestBuilder[GenericRequest[T, R]] with RequestMetadata { +sealed trait GenericRequest[+T, -R] extends RequestBuilder[GenericRequest[T, R]] with RequestMetadata { def body: GenericRequestBody[R] def response: ResponseAsDelegate[T, R] diff --git a/core/src/main/scala/sttp/client4/requestBuilder.scala b/core/src/main/scala/sttp/client4/requestBuilder.scala index 753a227e10..90196f0e31 100644 --- a/core/src/main/scala/sttp/client4/requestBuilder.scala +++ b/core/src/main/scala/sttp/client4/requestBuilder.scala @@ -301,15 +301,36 @@ trait PartialRequestBuilder[+PR <: PartialRequestBuilder[PR, R], +R] */ def redirectToGet(r: Boolean): PR = withOptions(options.copy(redirectToGet = r)) - /** Disables auto-decompression of response bodies which are received with supported `Content-Encoding headers. */ - def disableAutoDecompression: PR = withOptions(options.copy(disableAutoDecompression = true)) + /** Disables auto-decompression of response bodies which are received with supported `Content-Encoding` headers. + * + * @see + * [[RequestOptions.decompressResponseBody]] + */ + def disableAutoDecompression: PR = withOptions(options.copy(decompressResponseBody = false)) - /** True iff auto-decompression is disabled. + /** True iff auto-decompression is enabled (which is the default). * * @see * disableAutoDecompression + * @see + * [[RequestOptions.decompressResponseBody]] + */ + def autoDecompressionEnabled: Boolean = options.decompressResponseBody + + /** Compress the request body with the given encoding. + * + * The backend must support the encoding, otherwise an exception is thrown / a failed effect is returned. All + * backends support the [[sttp.model.Encodings.Gzip]] and [[sttp.model.Encodings.Deflate]] encodings. + * + * Note that the server might not support compressed bodies. By default request bodies are not compressed. + * + * @see + * [[sttp.model.Encodings]] + * @see + * [[RequestOptions.compressRequestBody]] */ - def autoDecompressionDisabled: Boolean = options.disableAutoDecompression + def compressBody(encoding: String): PR = + withOptions(options.copy(compressRequestBody = Some(encoding))).header(HeaderNames.ContentEncoding, encoding) /** Set the HTTP version with which this request should be sent. Supported only in a few backends. */ def httpVersion(version: HttpVersion): PR = withOptions(options.copy(httpVersion = Some(version))) diff --git a/core/src/main/scala/sttp/client4/testing/package.scala b/core/src/main/scala/sttp/client4/testing/package.scala index bf727143cf..c562714f01 100644 --- a/core/src/main/scala/sttp/client4/testing/package.scala +++ b/core/src/main/scala/sttp/client4/testing/package.scala @@ -15,7 +15,7 @@ package object testing { case ByteArrayBody(b, _) => new String(b) case ByteBufferBody(b, _) => new String(b.array()) case InputStreamBody(b, _) => new String(toByteArray(b)) - case FileBody(f, _) => f.readAsString + case FileBody(f, _) => f.readAsString() case StreamBody(_) => throw new IllegalArgumentException("The body of this request is a stream, cannot convert to String") case _: MultipartBody[_] => @@ -32,7 +32,7 @@ package object testing { case ByteArrayBody(b, _) => b case ByteBufferBody(b, _) => b.array() case InputStreamBody(b, _) => toByteArray(b) - case FileBody(f, _) => f.readAsByteArray + case FileBody(f, _) => f.readAsByteArray() case StreamBody(_) => throw new IllegalArgumentException("The body of this request is a stream, cannot convert to String") case _: MultipartBody[_] => diff --git a/core/src/main/scalajs/sttp/client4/compression/CompressorExtensions.scala b/core/src/main/scalajs/sttp/client4/compression/CompressorExtensions.scala new file mode 100644 index 0000000000..9831ced732 --- /dev/null +++ b/core/src/main/scalajs/sttp/client4/compression/CompressorExtensions.scala @@ -0,0 +1,5 @@ +package sttp.client4.compression + +trait CompressorExtensions { + def default[R]: List[Compressor[R]] = Nil +} diff --git a/core/src/main/scalajs/sttp/client4/compression/DecompressorExtensions.scala b/core/src/main/scalajs/sttp/client4/compression/DecompressorExtensions.scala new file mode 100644 index 0000000000..386c2203f1 --- /dev/null +++ b/core/src/main/scalajs/sttp/client4/compression/DecompressorExtensions.scala @@ -0,0 +1,7 @@ +package sttp.client4.compression + +import java.io.InputStream + +trait DecompressorExtensions { + def defaultInputStream: List[Decompressor[InputStream]] = Nil +} diff --git a/core/src/main/scalajs/sttp/client4/internal/SttpFileExtensions.scala b/core/src/main/scalajs/sttp/client4/internal/SttpFileExtensions.scala index 55230f161f..285e9dd9b5 100644 --- a/core/src/main/scalajs/sttp/client4/internal/SttpFileExtensions.scala +++ b/core/src/main/scalajs/sttp/client4/internal/SttpFileExtensions.scala @@ -2,13 +2,18 @@ package sttp.client4.internal import org.scalajs.dom.File +import java.io.FileInputStream +import java.io.InputStream + // wrap a DomFile trait SttpFileExtensions { self: SttpFile => def toDomFile: File = underlying.asInstanceOf[File] - def readAsString: String = throw new UnsupportedOperationException() - def readAsByteArray: Array[Byte] = throw new UnsupportedOperationException() + def readAsString(): String = throw new UnsupportedOperationException() + def readAsByteArray(): Array[Byte] = throw new UnsupportedOperationException() + def openStream(): InputStream = throw new UnsupportedOperationException() + def length(): Long = throw new UnsupportedOperationException() } trait SttpFileCompanionExtensions { diff --git a/core/src/main/scalajvm/sttp/client4/DefaultFutureBackend.scala b/core/src/main/scalajvm/sttp/client4/DefaultFutureBackend.scala index 967b78996d..a41a65778b 100644 --- a/core/src/main/scalajvm/sttp/client4/DefaultFutureBackend.scala +++ b/core/src/main/scalajvm/sttp/client4/DefaultFutureBackend.scala @@ -13,7 +13,7 @@ object DefaultFutureBackend { def apply( options: BackendOptions = BackendOptions.Default )(implicit ec: ExecutionContext = ExecutionContext.global): WebSocketBackend[Future] = - HttpClientFutureBackend(options, identity, PartialFunction.empty) + HttpClientFutureBackend(options) /** Create a stub backend for testing, which uses [[Future]] to represent side effects, and doesn't support streaming. * diff --git a/core/src/main/scalajvm/sttp/client4/DefaultSyncBackend.scala b/core/src/main/scalajvm/sttp/client4/DefaultSyncBackend.scala index 4e3df79b13..3c6ef568b2 100644 --- a/core/src/main/scalajvm/sttp/client4/DefaultSyncBackend.scala +++ b/core/src/main/scalajvm/sttp/client4/DefaultSyncBackend.scala @@ -8,8 +8,7 @@ object DefaultSyncBackend { /** Creates a default synchronous backend with the given `options`, which is currently based on * [[HttpClientSyncBackend]]. */ - def apply(options: BackendOptions = BackendOptions.Default): WebSocketSyncBackend = - HttpClientSyncBackend(options, identity, PartialFunction.empty) + def apply(options: BackendOptions = BackendOptions.Default): WebSocketSyncBackend = HttpClientSyncBackend(options) /** Create a stub backend for testing. See [[WebSocketSyncBackendStub]] for details on how to configure stub * responses. diff --git a/core/src/main/scalajvm/sttp/client4/compression/CompressorExtensions.scala b/core/src/main/scalajvm/sttp/client4/compression/CompressorExtensions.scala new file mode 100644 index 0000000000..a78a366806 --- /dev/null +++ b/core/src/main/scalajvm/sttp/client4/compression/CompressorExtensions.scala @@ -0,0 +1,5 @@ +package sttp.client4.compression + +trait CompressorExtensions { + def default[R]: List[Compressor[R]] = List(new GZipDefaultCompressor[R](), new DeflateDefaultCompressor[R]()) +} diff --git a/core/src/main/scalajvm/sttp/client4/compression/DecompressorExtensions.scala b/core/src/main/scalajvm/sttp/client4/compression/DecompressorExtensions.scala new file mode 100644 index 0000000000..4bde0977e2 --- /dev/null +++ b/core/src/main/scalajvm/sttp/client4/compression/DecompressorExtensions.scala @@ -0,0 +1,8 @@ +package sttp.client4.compression + +import java.io.InputStream + +trait DecompressorExtensions { + def defaultInputStream: List[Decompressor[InputStream]] = + List(GZipInputStreamDecompressor, DeflateInputStreamDecompressor) +} diff --git a/core/src/main/scalajvm/sttp/client4/compression/GZIPCompressingInputStream.scala b/core/src/main/scalajvm/sttp/client4/compression/GZIPCompressingInputStream.scala new file mode 100644 index 0000000000..10ef537f91 --- /dev/null +++ b/core/src/main/scalajvm/sttp/client4/compression/GZIPCompressingInputStream.scala @@ -0,0 +1,139 @@ +package sttp.client4.compression + +import java.io.{ByteArrayInputStream, IOException, InputStream} +import java.util.zip.{CRC32, Deflater} + +// based on: +// https://github.com/http4k/http4k/blob/master/core/core/src/main/kotlin/org/http4k/filter/Gzip.kt#L124 +// https://stackoverflow.com/questions/11036280/compress-an-inputstream-with-gzip +class GZIPCompressingInputStream( + source: InputStream, + compressionLevel: Int = java.util.zip.Deflater.DEFAULT_COMPRESSION +) extends InputStream { + + private object State extends Enumeration { + type State = Value + val HEADER, DATA, FINALISE, TRAILER, DONE = Value + } + + import State._ + + private val GZIP_MAGIC = 0x8b1f + private val HEADER_DATA: Array[Byte] = Array( + GZIP_MAGIC.toByte, + (GZIP_MAGIC >> 8).toByte, + Deflater.DEFLATED.toByte, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ) + private val INITIAL_BUFFER_SIZE = 8192 + + private val deflater = new Deflater(Deflater.DEFLATED, true) + deflater.setLevel(compressionLevel) + + private val crc = new CRC32() + private var trailer: ByteArrayInputStream = _ + private val header = new ByteArrayInputStream(HEADER_DATA) + + private var deflationBuffer: Array[Byte] = new Array[Byte](INITIAL_BUFFER_SIZE) + private var stage: State = HEADER + + override def read(): Int = { + val readBytes = new Array[Byte](1) + var bytesRead = 0 + while (bytesRead == 0) + bytesRead = read(readBytes, 0, 1) + if (bytesRead != -1) readBytes(0) & 0xff else -1 + } + + @throws[IOException] + override def read(readBuffer: Array[Byte], readOffset: Int, readLength: Int): Int = stage match { + case HEADER => + val bytesRead = header.read(readBuffer, readOffset, readLength) + if (header.available() == 0) stage = DATA + bytesRead + + case DATA => + if (!deflater.needsInput) { + deflatePendingInput(readBuffer, readOffset, readLength) + } else { + if (deflationBuffer.length < readLength) { + deflationBuffer = new Array[Byte](readLength) + } + + val bytesRead = source.read(deflationBuffer, 0, readLength) + if (bytesRead <= 0) { + stage = FINALISE + deflater.finish() + 0 + } else { + crc.update(deflationBuffer, 0, bytesRead) + deflater.setInput(deflationBuffer, 0, bytesRead) + deflatePendingInput(readBuffer, readOffset, readLength) + } + } + + case FINALISE => + if (deflater.finished()) { + stage = TRAILER + val crcValue = crc.getValue.toInt + val totalIn = deflater.getTotalIn + trailer = createTrailer(crcValue, totalIn) + 0 + } else { + deflater.deflate(readBuffer, readOffset, readLength, Deflater.FULL_FLUSH) + } + + case TRAILER => + val bytesRead = trailer.read(readBuffer, readOffset, readLength) + if (trailer.available() == 0) stage = DONE + bytesRead + + case DONE => -1 + + case _ => throw new IllegalArgumentException(s"Invalid state: $stage") + } + + private def deflatePendingInput(readBuffer: Array[Byte], readOffset: Int, readLength: Int): Int = { + var bytesCompressed = 0 + while (!deflater.needsInput && readLength - bytesCompressed > 0) + bytesCompressed += deflater.deflate( + readBuffer, + readOffset + bytesCompressed, + readLength - bytesCompressed, + Deflater.FULL_FLUSH + ) + bytesCompressed + } + + private def createTrailer(crcValue: Int, totalIn: Int): ByteArrayInputStream = + new ByteArrayInputStream( + Array( + (crcValue >> 0).toByte, + (crcValue >> 8).toByte, + (crcValue >> 16).toByte, + (crcValue >> 24).toByte, + (totalIn >> 0).toByte, + (totalIn >> 8).toByte, + (totalIn >> 16).toByte, + (totalIn >> 24).toByte + ) + ) + + override def available(): Int = if (stage == DONE) 0 else 1 + + @throws[IOException] + override def close(): Unit = { + source.close() + deflater.end() + if (trailer != null) trailer.close() + header.close() + } + + crc.reset() +} diff --git a/core/src/main/scalajvm/sttp/client4/compression/defaultCompressors.scala b/core/src/main/scalajvm/sttp/client4/compression/defaultCompressors.scala new file mode 100644 index 0000000000..dce011dc29 --- /dev/null +++ b/core/src/main/scalajvm/sttp/client4/compression/defaultCompressors.scala @@ -0,0 +1,78 @@ +package sttp.client4.compression + +import sttp.client4._ +import sttp.model.Encodings + +import Compressor._ +import java.util.zip.Deflater +import java.util.zip.DeflaterInputStream +import java.io.ByteArrayOutputStream + +class GZipDefaultCompressor[R] extends Compressor[R] { + val encoding: String = Encodings.Gzip + + def apply[R2 <: R](body: GenericRequestBody[R2]): GenericRequestBody[R] = + body match { + case NoBody => NoBody + case StringBody(s, encoding, defaultContentType) => + ByteArrayBody(byteArray(s.getBytes(encoding)), defaultContentType) + case ByteArrayBody(b, defaultContentType) => ByteArrayBody(byteArray(b), defaultContentType) + case ByteBufferBody(b, defaultContentType) => + ByteArrayBody(byteArray(byteBufferToArray(b)), defaultContentType) + case InputStreamBody(b, defaultContentType) => + InputStreamBody(new GZIPCompressingInputStream(b), defaultContentType) + case StreamBody(b) => streamsNotSupported + case FileBody(f, defaultContentType) => + InputStreamBody(new GZIPCompressingInputStream(f.openStream()), defaultContentType) + case MultipartStreamBody(parts) => compressingMultipartBodiesNotSupported + case BasicMultipartBody(parts) => compressingMultipartBodiesNotSupported + } + + private def byteArray(bytes: Array[Byte]): Array[Byte] = { + val bos = new java.io.ByteArrayOutputStream() + val gzip = new java.util.zip.GZIPOutputStream(bos) + gzip.write(bytes) + gzip.close() + bos.toByteArray() + } +} + +class DeflateDefaultCompressor[R] extends Compressor[R] { + val encoding: String = Encodings.Deflate + + def apply[R2 <: R](body: GenericRequestBody[R2]): GenericRequestBody[R] = + body match { + case NoBody => NoBody + case StringBody(s, encoding, defaultContentType) => + ByteArrayBody(byteArray(s.getBytes(encoding)), defaultContentType) + case ByteArrayBody(b, defaultContentType) => ByteArrayBody(byteArray(b), defaultContentType) + case ByteBufferBody(b, defaultContentType) => + ByteArrayBody(byteArray(byteBufferToArray(b)), defaultContentType) + case InputStreamBody(b, defaultContentType) => + InputStreamBody(new DeflaterInputStream(b), defaultContentType) + case StreamBody(b) => streamsNotSupported + case FileBody(f, defaultContentType) => + InputStreamBody(new DeflaterInputStream(f.openStream()), defaultContentType) + case MultipartStreamBody(parts) => compressingMultipartBodiesNotSupported + case BasicMultipartBody(parts) => compressingMultipartBodiesNotSupported + } + + private def byteArray(bytes: Array[Byte]): Array[Byte] = { + val deflater = new Deflater() + try { + deflater.setInput(bytes) + deflater.finish() + val byteArrayOutputStream = new ByteArrayOutputStream() + val readBuffer = new Array[Byte](1024) + + while (!deflater.finished()) { + val readCount = deflater.deflate(readBuffer) + if (readCount > 0) { + byteArrayOutputStream.write(readBuffer, 0, readCount) + } + } + + byteArrayOutputStream.toByteArray + } finally deflater.end() + } +} diff --git a/core/src/main/scalajvm/sttp/client4/compression/defaultDecompressors.scala b/core/src/main/scalajvm/sttp/client4/compression/defaultDecompressors.scala new file mode 100644 index 0000000000..198ff6ee9e --- /dev/null +++ b/core/src/main/scalajvm/sttp/client4/compression/defaultDecompressors.scala @@ -0,0 +1,16 @@ +package sttp.client4.compression + +import sttp.model.Encodings +import java.io.InputStream +import java.util.zip.GZIPInputStream +import java.util.zip.InflaterInputStream + +object GZipInputStreamDecompressor extends Decompressor[InputStream] { + override val encoding: String = Encodings.Gzip + override def apply(body: InputStream): InputStream = new GZIPInputStream(body) +} + +object DeflateInputStreamDecompressor extends Decompressor[InputStream] { + override val encoding: String = Encodings.Deflate + override def apply(body: InputStream): InputStream = new InflaterInputStream(body) +} diff --git a/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientAsyncBackend.scala b/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientAsyncBackend.scala index 1b17debc7e..53e85210be 100644 --- a/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientAsyncBackend.scala +++ b/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientAsyncBackend.scala @@ -1,11 +1,11 @@ package sttp.client4.httpclient import sttp.capabilities.{Streams, WebSockets} -import sttp.client4.httpclient.HttpClientBackend.EncodingHandler import sttp.client4.internal.SttpToJavaConverters.{toJavaBiConsumer, toJavaFunction} import sttp.client4.internal.httpclient.{AddToQueueListener, DelegatingWebSocketListener, Sequencer, WebSocketImpl} import sttp.client4.internal.ws.{SimpleQueue, WebSocketEvent} import sttp.client4.{GenericRequest, Response, WebSocketBackend} +import sttp.client4.compression.CompressionHandlers import sttp.model.StatusCode import sttp.monad.syntax._ import sttp.monad.{Canceler, MonadAsyncError} @@ -30,8 +30,8 @@ abstract class HttpClientAsyncBackend[F[_], S <: Streams[S], BH, B]( override implicit val monad: MonadAsyncError[F], closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: EncodingHandler[B] -) extends HttpClientBackend[F, S, S with WebSockets, B](client, closeClient, customEncodingHandler) + compressionHandlers: CompressionHandlers[S, B] +) extends HttpClientBackend[F, S, S with WebSockets, B](client, closeClient, compressionHandlers) with WebSocketBackend[F] { protected def createSimpleQueue[T]: F[SimpleQueue[F, T]] protected def createSequencer: F[Sequencer[F]] diff --git a/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientBackend.scala b/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientBackend.scala index f6caf5fae6..0549f63cbd 100644 --- a/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientBackend.scala +++ b/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientBackend.scala @@ -1,33 +1,39 @@ package sttp.client4.httpclient -import sttp.capabilities.{Effect, Streams} +import sttp.capabilities.Effect +import sttp.capabilities.Streams +import sttp.client4.Backend +import sttp.client4.BackendOptions import sttp.client4.BackendOptions.Proxy -import sttp.client4.httpclient.HttpClientBackend.EncodingHandler +import sttp.client4.GenericBackend +import sttp.client4.GenericRequest +import sttp.client4.MultipartBody +import sttp.client4.Response +import sttp.client4.SttpClientException import sttp.client4.internal.SttpToJavaConverters.toJavaFunction -import sttp.client4.internal.httpclient.{BodyFromHttpClient, BodyToHttpClient, Sequencer} -import sttp.client4.internal.ws.SimpleQueue -import sttp.client4.{ - Backend, - BackendOptions, - GenericBackend, - GenericRequest, - MultipartBody, - Response, - SttpClientException -} -import sttp.model.HttpVersion.{HTTP_1_1, HTTP_2} +import sttp.client4.internal.httpclient.BodyFromHttpClient +import sttp.client4.internal.httpclient.BodyToHttpClient import sttp.model._ +import sttp.model.HttpVersion.HTTP_1_1 +import sttp.model.HttpVersion.HTTP_2 import sttp.monad.MonadError import sttp.monad.syntax._ import sttp.ws.WebSocket +import java.net.Authenticator import java.net.Authenticator.RequestorType -import java.net.http.{HttpClient, HttpRequest, HttpResponse, WebSocket => JWebSocket} -import java.net.{Authenticator, PasswordAuthentication} +import java.net.PasswordAuthentication +import java.net.http.HttpClient +import java.net.http.HttpRequest +import java.net.http.HttpResponse +import java.net.http.{WebSocket => JWebSocket} import java.time.{Duration => JDuration} -import java.util.concurrent.{Executor, ThreadPoolExecutor} +import java.util.concurrent.Executor +import java.util.concurrent.ThreadPoolExecutor import java.util.function import scala.collection.JavaConverters._ +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Decompressor /** @param closeClient * If the executor underlying the client is a [[ThreadPoolExecutor]], should it be shutdown on [[close]]. @@ -35,7 +41,7 @@ import scala.collection.JavaConverters._ abstract class HttpClientBackend[F[_], S <: Streams[S], P, B]( client: HttpClient, closeClient: Boolean, - customEncodingHandler: EncodingHandler[B] + compressionHandlers: CompressionHandlers[P, B] ) extends GenericBackend[F, P] with Backend[F] { val streams: Streams[S] @@ -56,7 +62,7 @@ abstract class HttpClientBackend[F[_], S <: Streams[S], P, B]( SttpClientException.defaultExceptionToSttpClientException(request, _) ) - protected def bodyToHttpClient: BodyToHttpClient[F, S] + protected def bodyToHttpClient: BodyToHttpClient[F, S, R] protected def bodyFromHttpClient: BodyFromHttpClient[F, S, B] private[client4] def convertRequest[T](request: GenericRequest[T, R]): F[HttpRequest] = @@ -117,8 +123,8 @@ abstract class HttpClientBackend[F[_], S <: Streams[S], P, B]( resBody.left .map { is => encoding - .filterNot(e => code.equals(StatusCode.NoContent) || request.autoDecompressionDisabled || e.isEmpty) - .map(e => customEncodingHandler.applyOrElse((is, e), standardEncoding.tupled)) + .filterNot(e => code.equals(StatusCode.NoContent) || !request.autoDecompressionEnabled || e.isEmpty) + .map(e => Decompressor.decompressIfPossible(is, e, compressionHandlers.decompressors)) .getOrElse(is) } } else { @@ -128,8 +134,6 @@ abstract class HttpClientBackend[F[_], S <: Streams[S], P, B]( monad.map(body)(Response(_, code, "", headers, Nil, request.onlyMetadata)) } - protected def standardEncoding: (B, String) => B - protected def prepareWebSocketBuilder[T]( request: GenericRequest[T, R], client: HttpClient @@ -166,8 +170,8 @@ abstract class HttpClientBackend[F[_], S <: Streams[S], P, B]( } override def close(): F[Unit] = if (closeClient) { - monad.eval( - client + monad.eval { + val _ = client .executor() .map[Unit](new function.Function[Executor, Unit] { override def apply(t: Executor): Unit = t match { @@ -175,16 +179,13 @@ abstract class HttpClientBackend[F[_], S <: Streams[S], P, B]( case _ => () } }) - ) + } } else { monad.unit(()) } } object HttpClientBackend { - - type EncodingHandler[B] = PartialFunction[(B, String), B] - private class ProxyAuthenticator(auth: BackendOptions.ProxyAuth) extends Authenticator { override def getPasswordAuthentication: PasswordAuthentication = if (getRequestorType == RequestorType.PROXY) { diff --git a/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientFutureBackend.scala b/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientFutureBackend.scala index fcf538c074..045303636b 100644 --- a/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientFutureBackend.scala +++ b/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientFutureBackend.scala @@ -1,7 +1,5 @@ package sttp.client4.httpclient -import sttp.client4.httpclient.HttpClientBackend.EncodingHandler -import sttp.client4.httpclient.HttpClientFutureBackend.InputStreamEncodingHandler import sttp.client4.internal.httpclient._ import sttp.client4.internal.ws.{FutureSimpleQueue, SimpleQueue} import sttp.client4.internal.{emptyInputStream, NoStreams} @@ -10,34 +8,37 @@ import sttp.client4.{wrappers, BackendOptions, WebSocketBackend} import sttp.monad.{FutureMonad, MonadError} import sttp.ws.{WebSocket, WebSocketFrame} -import java.io.{InputStream, UnsupportedEncodingException} +import java.io.InputStream import java.net.http.HttpRequest.BodyPublisher import java.net.http.HttpResponse.BodyHandlers import java.net.http.{HttpClient, HttpRequest, HttpResponse} import java.util.concurrent.Executor -import java.util.zip.{GZIPInputStream, InflaterInputStream} import scala.concurrent.{ExecutionContext, Future} +import sttp.client4.compression.Compressor +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Decompressor class HttpClientFutureBackend private ( client: HttpClient, closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: InputStreamEncodingHandler + compressionHandlers: CompressionHandlers[Any, InputStream] )(implicit ec: ExecutionContext) extends HttpClientAsyncBackend[Future, Nothing, InputStream, InputStream]( client, new FutureMonad, closeClient, customizeRequest, - customEncodingHandler + compressionHandlers ) { override val streams: NoStreams = NoStreams - override protected val bodyToHttpClient: BodyToHttpClient[Future, Nothing] = new BodyToHttpClient[Future, Nothing] { + override protected val bodyToHttpClient = new BodyToHttpClient[Future, Nothing, R] { override val streams: NoStreams = NoStreams override implicit val monad: MonadError[Future] = new FutureMonad override def streamToPublisher(stream: Nothing): Future[BodyPublisher] = stream // nothing is everything + override def compressors: List[Compressor[Nothing]] = compressionHandlers.compressors } override protected val bodyFromHttpClient: BodyFromHttpClient[Future, Nothing, InputStream] = @@ -57,12 +58,6 @@ class HttpClientFutureBackend private ( override protected def createSequencer: Future[Sequencer[Future]] = Future.successful(new FutureSequencer) - override protected def standardEncoding: (InputStream, String) => InputStream = { - case (body, "gzip") => new GZIPInputStream(body) - case (body, "deflate") => new InflaterInputStream(body) - case (_, ce) => throw new UnsupportedEncodingException(s"Unsupported encoding: $ce") - } - override protected def createBodyHandler: HttpResponse.BodyHandler[InputStream] = BodyHandlers.ofInputStream() override protected def bodyHandlerBodyToBody(p: InputStream): InputStream = p @@ -71,42 +66,43 @@ class HttpClientFutureBackend private ( } object HttpClientFutureBackend { - type InputStreamEncodingHandler = EncodingHandler[InputStream] + val DefaultCompressionHandlers: CompressionHandlers[Any, InputStream] = + CompressionHandlers(Compressor.default[Any], Decompressor.defaultInputStream) private def apply( client: HttpClient, closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: InputStreamEncodingHandler + compressionHandlers: CompressionHandlers[Any, InputStream] )(implicit ec: ExecutionContext): WebSocketBackend[Future] = wrappers.FollowRedirectsBackend( - new HttpClientFutureBackend(client, closeClient, customizeRequest, customEncodingHandler) + new HttpClientFutureBackend(client, closeClient, customizeRequest, compressionHandlers) ) def apply( options: BackendOptions = BackendOptions.Default, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: InputStreamEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[Any, InputStream] = DefaultCompressionHandlers )(implicit ec: ExecutionContext = ExecutionContext.global): WebSocketBackend[Future] = { val executor = Some(ec).collect { case executor: Executor => executor } HttpClientFutureBackend( HttpClientBackend.defaultClient(options, executor), closeClient = executor.isEmpty, customizeRequest, - customEncodingHandler + compressionHandlers ) } def usingClient( client: HttpClient, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: InputStreamEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[Any, InputStream] = DefaultCompressionHandlers )(implicit ec: ExecutionContext = ExecutionContext.global): WebSocketBackend[Future] = HttpClientFutureBackend( client, closeClient = false, customizeRequest, - customEncodingHandler + compressionHandlers ) /** Create a stub backend for testing, which uses [[Future]] to represent side effects, and doesn't support streaming. diff --git a/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientSyncBackend.scala b/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientSyncBackend.scala index f5e8661138..bf41787ec9 100644 --- a/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientSyncBackend.scala +++ b/core/src/main/scalajvm/sttp/client4/httpclient/HttpClientSyncBackend.scala @@ -1,8 +1,6 @@ package sttp.client4.httpclient import sttp.capabilities.WebSockets -import sttp.client4.httpclient.HttpClientBackend.EncodingHandler -import sttp.client4.httpclient.HttpClientSyncBackend.SyncEncodingHandler import sttp.client4.internal.httpclient._ import sttp.client4.internal.ws.{SimpleQueue, SyncQueue, WebSocketEvent} import sttp.client4.internal.{emptyInputStream, NoStreams} @@ -13,23 +11,25 @@ import sttp.monad.{IdentityMonad, MonadError} import sttp.shared.Identity import sttp.ws.{WebSocket, WebSocketFrame} -import java.io.{InputStream, UnsupportedEncodingException} +import java.io.InputStream import java.net.http.HttpRequest.BodyPublisher import java.net.http.HttpResponse.BodyHandlers import java.net.http.{HttpClient, HttpRequest, WebSocketHandshakeException} import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.{ArrayBlockingQueue, CompletionException} -import java.util.zip.{GZIPInputStream, InflaterInputStream} +import sttp.client4.compression.Compressor +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Decompressor class HttpClientSyncBackend private ( client: HttpClient, closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: SyncEncodingHandler + compression: CompressionHandlers[Any, InputStream] ) extends HttpClientBackend[Identity, Nothing, WebSockets, InputStream]( client, closeClient, - customEncodingHandler + compression ) with WebSocketSyncBackend { @@ -63,13 +63,13 @@ class HttpClientSyncBackend private ( val isOpen: AtomicBoolean = new AtomicBoolean(false) val responseCell = new ArrayBlockingQueue[Either[Throwable, () => Response[T]]](1) - def fillCellError(t: Throwable): Unit = responseCell.add(Left(t)): Unit - def fillCell(wr: () => Response[T]): Unit = responseCell.add(Right(wr)): Unit + def fillCellError(t: Throwable): Unit = { val _ = responseCell.add(Left(t)) } + def fillCell(wr: () => Response[T]): Unit = { val _ = responseCell.add(Right(wr)) } val listener = new DelegatingWebSocketListener( new AddToQueueListener(queue, isOpen), ws => { - val webSocket = new WebSocketImpl[Identity](ws, queue, isOpen, sequencer, monad, _.get(): Unit) + val webSocket = new WebSocketImpl[Identity](ws, queue, isOpen, sequencer, monad, cf => { val _ = cf.get() }) val baseResponse = Response((), StatusCode.SwitchingProtocols, "", Nil, Nil, request.onlyMetadata) val body = () => bodyFromHttpClient(Right(webSocket), request.response, baseResponse) fillCell(() => baseResponse.copy(body = body())) @@ -82,12 +82,12 @@ class HttpClientSyncBackend private ( responseCell.take().fold(throw _, f => f()) } - override protected val bodyToHttpClient: BodyToHttpClient[Identity, Nothing] = - new BodyToHttpClient[Identity, Nothing] { - override val streams: NoStreams = NoStreams - override implicit val monad: MonadError[Identity] = IdentityMonad - override def streamToPublisher(stream: Nothing): Identity[BodyPublisher] = stream // nothing is everything - } + override protected val bodyToHttpClient = new BodyToHttpClient[Identity, Nothing, R] { + override val streams: NoStreams = NoStreams + override implicit val monad: MonadError[Identity] = IdentityMonad + override def streamToPublisher(stream: Nothing): Identity[BodyPublisher] = stream // nothing is everything + override def compressors: List[Compressor[R]] = compression.compressors + } override protected val bodyFromHttpClient: BodyFromHttpClient[Identity, Nothing, InputStream] = new InputStreamBodyFromHttpClient[Identity, Nothing] { @@ -100,49 +100,44 @@ class HttpClientSyncBackend private ( pipe: streams.Pipe[WebSocketFrame.Data[_], WebSocketFrame] ): Identity[Unit] = pipe } - - override protected def standardEncoding: (InputStream, String) => InputStream = { - case (body, "gzip") => new GZIPInputStream(body) - case (body, "deflate") => new InflaterInputStream(body) - case (_, ce) => throw new UnsupportedEncodingException(s"Unsupported encoding: $ce") - } } object HttpClientSyncBackend { - type SyncEncodingHandler = EncodingHandler[InputStream] + val DefaultCompressionHandlers: CompressionHandlers[Any, InputStream] = + CompressionHandlers(Compressor.default[Any], Decompressor.defaultInputStream) private def apply( client: HttpClient, closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: SyncEncodingHandler + compressionHandlers: CompressionHandlers[Any, InputStream] ): WebSocketSyncBackend = wrappers.FollowRedirectsBackend( - new HttpClientSyncBackend(client, closeClient, customizeRequest, customEncodingHandler) + new HttpClientSyncBackend(client, closeClient, customizeRequest, compressionHandlers) ) def apply( options: BackendOptions = BackendOptions.Default, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: SyncEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[Any, InputStream] = DefaultCompressionHandlers ): WebSocketSyncBackend = HttpClientSyncBackend( HttpClientBackend.defaultClient(options, None), closeClient = true, customizeRequest, - customEncodingHandler + compressionHandlers ) def usingClient( client: HttpClient, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: SyncEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[Any, InputStream] = DefaultCompressionHandlers ): WebSocketSyncBackend = HttpClientSyncBackend( client, closeClient = false, customizeRequest, - customEncodingHandler + compressionHandlers ) /** Create a stub backend for testing. See [[WebSocketBackendStub]] for details on how to configure stub responses. */ diff --git a/core/src/main/scalajvm/sttp/client4/httpurlconnection/HttpURLConnectionBackend.scala b/core/src/main/scalajvm/sttp/client4/httpurlconnection/HttpURLConnectionBackend.scala index 23958292fa..a74c6a563a 100644 --- a/core/src/main/scalajvm/sttp/client4/httpurlconnection/HttpURLConnectionBackend.scala +++ b/core/src/main/scalajvm/sttp/client4/httpurlconnection/HttpURLConnectionBackend.scala @@ -1,8 +1,8 @@ package sttp.client4.httpurlconnection import sttp.capabilities.Effect -import sttp.client4.httpurlconnection.HttpURLConnectionBackend.EncodingHandler import sttp.client4.internal._ +import sttp.client4.compression.Compressor import sttp.client4.testing.SyncBackendStub import sttp.client4.ws.{GotAWebSocketException, NotAWebSocketException} import sttp.client4.{ @@ -37,21 +37,28 @@ import java.util.concurrent.ThreadLocalRandom import java.util.zip.{GZIPInputStream, InflaterInputStream} import scala.collection.JavaConverters._ import scala.concurrent.duration.Duration +import sttp.client4.GenericRequestBody +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Decompressor class HttpURLConnectionBackend private ( opts: BackendOptions, customizeConnection: HttpURLConnection => Unit, createURL: String => URL, openConnection: (URL, Option[java.net.Proxy]) => URLConnection, - customEncodingHandler: EncodingHandler + compressionHandlers: CompressionHandlers[Any, InputStream] ) extends SyncBackend { type R = Any with Effect[Identity] override def send[T](r: GenericRequest[T, R]): Response[T] = adjustExceptions(r) { + val (body, contentLength) = Compressor.compressIfNeeded(r, compressionHandlers.compressors) + val c = openConnection(r.uri) c.setRequestMethod(r.method.method) - r.headers.foreach(h => c.setRequestProperty(h.name, h.value)) + // content-length might have changed due to compression + r.headers.foreach(h => if (!h.is(HeaderNames.ContentLength)) c.setRequestProperty(h.name, h.value)) + contentLength.foreach(cl => c.setRequestProperty(HeaderNames.ContentLength, cl.toString)) c.setDoInput(true) c.setReadTimeout(timeout(r.options.readTimeout)) c.setConnectTimeout(timeout(opts.connectionTimeout)) @@ -66,7 +73,7 @@ class HttpURLConnectionBackend private ( // we need to take care to: // (1) only call getOutputStream after the headers are set // (2) call it ony once - writeBody(r, c).foreach { os => + writeBody(body, r, c).foreach { os => os.flush() os.close() } @@ -104,8 +111,12 @@ class HttpURLConnectionBackend private ( conn.asInstanceOf[HttpURLConnection] } - private def writeBody(r: GenericRequest[_, R], c: HttpURLConnection): Option[OutputStream] = - r.body match { + private def writeBody( + body: GenericRequestBody[R], + r: GenericRequest[_, R], + c: HttpURLConnection + ): Option[OutputStream] = + body match { case NoBody => // skip None @@ -141,13 +152,13 @@ class HttpURLConnectionBackend private ( case ByteBufferBody(b, _) => val channel = Channels.newChannel(os) - channel.write(b) + val _ = channel.write(b) case InputStreamBody(b, _) => transfer(b, os) case FileBody(f, _) => - Files.copy(f.toPath, os) + val _ = Files.copy(f.toPath, os) } private val BoundaryChars = @@ -250,7 +261,7 @@ class HttpURLConnectionBackend private ( val code = StatusCode(c.getResponseCode) val wrappedIs = - if (c.getRequestMethod != "HEAD" && !code.equals(StatusCode.NoContent) && !request.autoDecompressionDisabled) { + if (c.getRequestMethod != "HEAD" && !code.equals(StatusCode.NoContent) && request.autoDecompressionEnabled) { wrapInput(contentEncoding, handleNullInput(is)) } else handleNullInput(is) val responseMetadata = ResponseMetadata(code, c.getResponseMessage, headers) @@ -295,12 +306,8 @@ class HttpURLConnectionBackend private ( private def wrapInput(contentEncoding: Option[String], is: InputStream): InputStream = contentEncoding.map(_.toLowerCase) match { - case None => is - case Some("gzip") => new GZIPInputStream(is) - case Some("deflate") => new InflaterInputStream(is) - case Some(ce) if customEncodingHandler.isDefinedAt((is, ce)) => customEncodingHandler(is -> ce) - case Some(ce) => - throw new UnsupportedEncodingException(s"Unsupported encoding: $ce") + case None => is + case Some(encoding) => Decompressor.decompressIfPossible(is, encoding, compressionHandlers.decompressors) } private def adjustExceptions[T](request: GenericRequest[_, R])(t: => T): T = @@ -312,8 +319,8 @@ class HttpURLConnectionBackend private ( } object HttpURLConnectionBackend { - - type EncodingHandler = PartialFunction[(InputStream, String), InputStream] + val DefaultCompressionHandlers: CompressionHandlers[Any, InputStream] = + CompressionHandlers(Compressor.default[Any], Decompressor.defaultInputStream) private[client4] val defaultOpenConnection: (URL, Option[java.net.Proxy]) => URLConnection = { case (url, None) => url.openConnection() @@ -328,10 +335,10 @@ object HttpURLConnectionBackend { case (url, None) => url.openConnection() case (url, Some(proxy)) => url.openConnection(proxy) }, - customEncodingHandler: EncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[Any, InputStream] = DefaultCompressionHandlers ): SyncBackend = wrappers.FollowRedirectsBackend( - new HttpURLConnectionBackend(options, customizeConnection, createURL, openConnection, customEncodingHandler) + new HttpURLConnectionBackend(options, customizeConnection, createURL, openConnection, compressionHandlers) ) /** Create a stub backend for testing. See [[SyncBackendStub]] for details on how to configure stub responses. diff --git a/core/src/main/scalajvm/sttp/client4/internal/SttpFileExtensions.scala b/core/src/main/scalajvm/sttp/client4/internal/SttpFileExtensions.scala index 09af97cd42..bf1e433cbb 100644 --- a/core/src/main/scalajvm/sttp/client4/internal/SttpFileExtensions.scala +++ b/core/src/main/scalajvm/sttp/client4/internal/SttpFileExtensions.scala @@ -4,6 +4,8 @@ import java.nio.file.Files import java.nio.file.Path import scala.io.Source +import java.io.FileInputStream +import java.io.InputStream // wrap a Path trait SttpFileExtensions { self: SttpFile => @@ -11,13 +13,14 @@ trait SttpFileExtensions { self: SttpFile => def toPath: Path = underlying.asInstanceOf[Path] def toFile: java.io.File = toPath.toFile - def readAsString: String = { + def readAsString(): String = { val s = Source.fromFile(toFile, "UTF-8"); try s.getLines().mkString("\n") finally s.close() } - - def readAsByteArray: Array[Byte] = Files.readAllBytes(toPath) + def readAsByteArray(): Array[Byte] = Files.readAllBytes(toPath) + def openStream(): InputStream = new FileInputStream(toFile) + def length(): Long = toFile.length() } trait SttpFileCompanionExtensions { diff --git a/core/src/main/scalajvm/sttp/client4/internal/httpclient/BodyToHttpClient.scala b/core/src/main/scalajvm/sttp/client4/internal/httpclient/BodyToHttpClient.scala index 7bdc750acf..524ae7cf28 100644 --- a/core/src/main/scalajvm/sttp/client4/internal/httpclient/BodyToHttpClient.scala +++ b/core/src/main/scalajvm/sttp/client4/internal/httpclient/BodyToHttpClient.scala @@ -3,6 +3,7 @@ package sttp.client4.internal.httpclient import sttp.capabilities.Streams import sttp.client4.internal.SttpToJavaConverters.toJavaSupplier import sttp.client4.internal.{throwNestedMultipartNotAllowed, Utf8} +import sttp.client4.compression.Compressor import sttp.client4._ import sttp.model.{Header, HeaderNames, Part} import sttp.monad.MonadError @@ -16,16 +17,17 @@ import java.util.concurrent.Flow import java.util.function.Supplier import scala.collection.JavaConverters._ -private[client4] trait BodyToHttpClient[F[_], S] { +private[client4] trait BodyToHttpClient[F[_], S, R] { val streams: Streams[S] implicit def monad: MonadError[F] def apply[T]( - request: GenericRequest[T, _], + request: GenericRequest[T, R], builder: HttpRequest.Builder, contentType: Option[String] ): F[BodyPublisher] = { - val body = request.body match { + val (maybeCompressedBody, contentLength) = Compressor.compressIfNeeded(request, compressors) + val body = maybeCompressedBody match { case NoBody => BodyPublishers.noBody().unit case StringBody(b, _, _) => BodyPublishers.ofString(b).unit case ByteArrayBody(b, _) => BodyPublishers.ofByteArray(b).unit @@ -42,13 +44,14 @@ private[client4] trait BodyToHttpClient[F[_], S] { multipartBodyPublisher.build().unit } - (request.contentLength: Option[Long]) match { + contentLength match { case None => body case Some(cl) => body.map(b => withKnownContentLength(b, cl)) } } def streamToPublisher(stream: streams.BinaryStream): F[BodyPublisher] + def compressors: List[Compressor[R]] private def multipartBody[T](parts: Seq[Part[GenericRequestBody[_]]]) = { val multipartBuilder = new MultiPartBodyPublisher() diff --git a/core/src/main/scalanative/sttp/client4/compression/CompressorExtensions.scala b/core/src/main/scalanative/sttp/client4/compression/CompressorExtensions.scala new file mode 100644 index 0000000000..9831ced732 --- /dev/null +++ b/core/src/main/scalanative/sttp/client4/compression/CompressorExtensions.scala @@ -0,0 +1,5 @@ +package sttp.client4.compression + +trait CompressorExtensions { + def default[R]: List[Compressor[R]] = Nil +} diff --git a/core/src/main/scalanative/sttp/client4/compression/DecompressorExtensions.scala b/core/src/main/scalanative/sttp/client4/compression/DecompressorExtensions.scala new file mode 100644 index 0000000000..386c2203f1 --- /dev/null +++ b/core/src/main/scalanative/sttp/client4/compression/DecompressorExtensions.scala @@ -0,0 +1,7 @@ +package sttp.client4.compression + +import java.io.InputStream + +trait DecompressorExtensions { + def defaultInputStream: List[Decompressor[InputStream]] = Nil +} diff --git a/core/src/main/scalanative/sttp/client4/internal/SttpFileExtensions.scala b/core/src/main/scalanative/sttp/client4/internal/SttpFileExtensions.scala index f5902158b8..e50dbbecaa 100644 --- a/core/src/main/scalanative/sttp/client4/internal/SttpFileExtensions.scala +++ b/core/src/main/scalanative/sttp/client4/internal/SttpFileExtensions.scala @@ -5,17 +5,21 @@ import java.nio.file.Path import scala.io.Source +import java.io.FileInputStream +import java.io.InputStream + trait SttpFileExtensions { self: SttpFile => def toPath: Path = underlying.asInstanceOf[Path] def toFile: java.io.File = toPath.toFile - def readAsString: String = { + def readAsString(): String = { val s = Source.fromFile(toFile, "UTF-8"); try s.getLines().mkString("\n") finally s.close() } - - def readAsByteArray: Array[Byte] = Files.readAllBytes(toPath) + def readAsByteArray(): Array[Byte] = Files.readAllBytes(toPath) + def openStream(): InputStream = new FileInputStream(toFile) + def length(): Long = toFile.length() } trait SttpFileCompanionExtensions { diff --git a/core/src/test/scala/sttp/client4/RequestTests.scala b/core/src/test/scala/sttp/client4/RequestTests.scala index 2e3c61f29a..27e3536b02 100644 --- a/core/src/test/scala/sttp/client4/RequestTests.scala +++ b/core/src/test/scala/sttp/client4/RequestTests.scala @@ -3,6 +3,7 @@ package sttp.client4 import sttp.model.{Header, HeaderNames, StatusCode} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers +import sttp.model.Encodings class RequestTests extends AnyFlatSpec with Matchers { @@ -108,4 +109,10 @@ class RequestTests extends AnyFlatSpec with Matchers { Header("Content-Length", "4") ) } + + "compressBody" should "add the content encoding header" in { + emptyRequest.compressBody(Encodings.Gzip).headers.toSet shouldBe Set( + Header(HeaderNames.ContentEncoding, Encodings.Gzip) + ) + } } diff --git a/core/src/test/scalajvm/sttp/client4/HttpURLConnectionBackendHttpTest.scala b/core/src/test/scalajvm/sttp/client4/HttpURLConnectionBackendHttpTest.scala index 5b7da964b7..2fd99f65fc 100644 --- a/core/src/test/scalajvm/sttp/client4/HttpURLConnectionBackendHttpTest.scala +++ b/core/src/test/scalajvm/sttp/client4/HttpURLConnectionBackendHttpTest.scala @@ -5,10 +5,18 @@ import sttp.client4.testing.{ConvertToFuture, HttpTest} import sttp.shared.Identity import java.io.ByteArrayInputStream +import sttp.client4.compression.Decompressor +import java.io.InputStream class HttpURLConnectionBackendHttpTest extends HttpTest[Identity] { override val backend: SyncBackend = HttpURLConnectionBackend( - customEncodingHandler = { case (_, "custom") => new ByteArrayInputStream(customEncodedData.getBytes()) } + compressionHandlers = + HttpURLConnectionBackend.DefaultCompressionHandlers.addDecompressor(new Decompressor[InputStream] { + override val encoding: String = "custom" + override def apply(inputStream: InputStream): InputStream = + new ByteArrayInputStream(customEncodedData.getBytes()) + + }) ) override implicit val convertToFuture: ConvertToFuture[Identity] = ConvertToFuture.id diff --git a/core/src/test/scalajvm/sttp/client4/compression/GZIPCompressingInputStreamTest.scala b/core/src/test/scalajvm/sttp/client4/compression/GZIPCompressingInputStreamTest.scala new file mode 100644 index 0000000000..dda8bde9ca --- /dev/null +++ b/core/src/test/scalajvm/sttp/client4/compression/GZIPCompressingInputStreamTest.scala @@ -0,0 +1,41 @@ +package sttp.client4.compression + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks + +import java.io.ByteArrayInputStream +import java.util.zip.GZIPInputStream +import java.nio.file.Files +import java.io.File +import java.io.FileInputStream + +class GZIPCompressingInputStreamTest extends AnyFlatSpec with Matchers with ScalaCheckPropertyChecks { + implicit override val generatorDrivenConfig = + PropertyCheckConfiguration(minSuccessful = 1000, minSize = 0, sizeRange = 10000) + + it should "compress data correctly" in { + forAll { (input: Array[Byte]) => + val compressedStream = new GZIPCompressingInputStream(new ByteArrayInputStream(input)) + val decompressed = new GZIPInputStream(compressedStream).readAllBytes() + decompressed shouldEqual input + } + } + + it should "compress data from a file" in { + val testFileContent = "test file content" + withTemporaryFile(testFileContent.getBytes()) { file => + val gzipInputStream = new GZIPInputStream(new GZIPCompressingInputStream(new FileInputStream(file))) + val decompressedBytes = gzipInputStream.readAllBytes() + decompressedBytes shouldEqual testFileContent.getBytes() + } + } + + private def withTemporaryFile[T](content: Array[Byte])(f: File => T): T = { + val file = Files.createTempFile("sttp", "sttp") + Files.write(file, content) + + try f(file.toFile) + finally { val _ = Files.deleteIfExists(file) } + } +} diff --git a/core/src/test/scalajvm/sttp/client4/testing/HttpTestExtensions.scala b/core/src/test/scalajvm/sttp/client4/testing/HttpTestExtensions.scala index 9f82a32ac3..0645b6513d 100644 --- a/core/src/test/scalajvm/sttp/client4/testing/HttpTestExtensions.scala +++ b/core/src/test/scalajvm/sttp/client4/testing/HttpTestExtensions.scala @@ -12,6 +12,10 @@ import HttpTest.endpoint import org.scalatest.freespec.AsyncFreeSpecLike import sttp.client4.wrappers.{DigestAuthenticationBackend, FollowRedirectsBackend, TooManyRedirectsException} import sttp.model.headers.CookieWithMeta +import sttp.model.Encodings +import java.util.zip.GZIPInputStream +import java.io.ByteArrayInputStream +import java.util.zip.InflaterInputStream trait HttpTestExtensions[F[_]] extends AsyncFreeSpecLike { self: HttpTest[F] => protected def supportsResponseAsInputStream = true @@ -198,6 +202,59 @@ trait HttpTestExtensions[F[_]] extends AsyncFreeSpecLike { self: HttpTest[F] => } } + "compression" - { + "should compress request body using gzip" in { + val req = basicRequest + .compressBody(Encodings.Gzip) + .response(asByteArrayAlways) + .post(uri"$endpoint/echo/exact") + .body("I'm not compressed") + req.send(backend).toFuture().map { resp => + resp.code shouldBe StatusCode.Ok + + val gzipInputStream = new GZIPInputStream(new ByteArrayInputStream(resp.body)) + val decompressedBytes = gzipInputStream.readAllBytes() + + new String(decompressedBytes) shouldBe "I'm not compressed" + } + } + + "should compress request body using deflate" in { + val req = basicRequest + .compressBody(Encodings.Deflate) + .response(asByteArrayAlways) + .post(uri"$endpoint/echo/exact") + .body("I'm not compressed") + req.send(backend).toFuture().map { resp => + resp.code shouldBe StatusCode.Ok + + val inflaterInputStream = new InflaterInputStream(new ByteArrayInputStream(resp.body)) + val decompressedBytes = inflaterInputStream.readAllBytes() + + new String(decompressedBytes) shouldBe "I'm not compressed" + } + } + + "should compress a file-based request body using deflate" in { + val testFileContent = "test file content" + withTemporaryFile(Some(testFileContent.getBytes())) { file => + val req = basicRequest + .compressBody(Encodings.Deflate) + .response(asByteArrayAlways) + .post(uri"$endpoint/echo/exact") + .body(file) + req.send(backend).toFuture().map { resp => + resp.code shouldBe StatusCode.Ok + + val inflaterInputStream = new InflaterInputStream(new ByteArrayInputStream(resp.body)) + val decompressedBytes = inflaterInputStream.readAllBytes() + + new String(decompressedBytes) shouldBe testFileContent + } + } + } + } + private def withTemporaryFile[T](content: Option[Array[Byte]])(f: File => Future[T]): Future[T] = { val file = Files.createTempFile("sttp", "sttp") val result = Future { diff --git a/docs/requests/body.md b/docs/requests/body.md index e0c907713b..a9242a5aeb 100644 --- a/docs/requests/body.md +++ b/docs/requests/body.md @@ -106,3 +106,13 @@ basicRequest.body(serializePerson(Person("mary", "smith", 67))) ``` See the implementations of the `BasicBody` trait for more options. + +## Compressing bodies + +Request bodies can be compressed, using an algorithm that's supported by the backend. By default, all backends support the `gzip` and `deflate` compression algorithms. + +To compress a request body, use the `request.compressBody(encoding)` method. This will set the the `Content-Encoding` header on the request, as well as compress the body when the request is sent. If the given encoding is not supported by the backend, an exception will be thrown / a failed effect will be returned. + +Support for custom compression algorithms can be added at backend creation time, by customising the `compressionHandlers` parameter, and adding a `Compressor` implementation. Such an implementation has to specify the encoding, which it handles, as well as appropriate body transformation (which is backend-specific). + +Note that clients often don't know upfront which compression algorithms (if at all) the server supports, and that's why requests are often sent uncompressed. Sending an encoded (compressed) body, when the server doesn't support decompression, might lead to 4xx or 5xx errors. \ No newline at end of file diff --git a/docs/responses/body.md b/docs/responses/body.md index 88aa6d9f57..1f7689c9e9 100644 --- a/docs/responses/body.md +++ b/docs/responses/body.md @@ -245,3 +245,14 @@ val response: Future[Response[Either[String, Source[ByteString, Any]]]] = ``` It's also possible to parse the received stream as server-sent events (SSE), using an implementation-specific mapping function. Refer to the documentation for particular backends for more details. + +## Decompressing bodies (handling the Conent-Encoding header) + +If the response body is compressed using `gzip` or `deflate` algorithms, it will be decompressed if the `decompressResponseBody` request option is set. By default this is set to `true`, and can be disabled using the `request.disableAutoDecompression` method. + +The encoding of the response body is determined by the encodings that are accepted by the client. That's why `basicRequest` and `quickRequest` both have the `Accept-Encoding` header set to `gzip, deflate`. That's in contrast to `emptyRequest`, which has no headers set by default. + +If you'd like to use additional decompression algorithms, you'll need to: + +* amend the `Accept-Encoding` header that's set on the request +* add a decompression algorithm to the backend; that can be done on backend creation time, by customising the `compressionHandlers` parameter, and adding a `Decompressor` implementation. Such an implementation has to specify the encoding, which it handles, as well as appropriate body transformation (which is backend-specific). \ No newline at end of file diff --git a/effects/cats/src/main/scalajvm/sttp/client4/httpclient/cats/HttpClientCatsBackend.scala b/effects/cats/src/main/scalajvm/sttp/client4/httpclient/cats/HttpClientCatsBackend.scala index 77e7fadaaa..25da7fb66e 100644 --- a/effects/cats/src/main/scalajvm/sttp/client4/httpclient/cats/HttpClientCatsBackend.scala +++ b/effects/cats/src/main/scalajvm/sttp/client4/httpclient/cats/HttpClientCatsBackend.scala @@ -4,35 +4,35 @@ import cats.effect.kernel.{Async, Resource, Sync} import cats.effect.std.{Dispatcher, Queue} import cats.implicits.{toFlatMapOps, toFunctorOps} import sttp.client4.httpclient.{HttpClientAsyncBackend, HttpClientBackend} -import sttp.client4.httpclient.HttpClientBackend.EncodingHandler import sttp.client4.impl.cats.CatsMonadAsyncError import sttp.client4.internal.httpclient._ import sttp.client4.internal.ws.SimpleQueue import sttp.client4.internal.{emptyInputStream, NoStreams} import sttp.client4.testing.WebSocketBackendStub -import sttp.client4.wrappers.FollowRedirectsBackend import sttp.client4.{wrappers, BackendOptions, WebSocketBackend} import sttp.monad.MonadError import sttp.ws.{WebSocket, WebSocketFrame} -import java.io.{InputStream, UnsupportedEncodingException} +import java.io.InputStream import java.net.http.HttpRequest.BodyPublisher import java.net.http.HttpResponse.BodyHandlers import java.net.http.{HttpClient, HttpRequest, HttpResponse} -import java.util.zip.{GZIPInputStream, InflaterInputStream} +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Compressor +import sttp.client4.compression.Decompressor class HttpClientCatsBackend[F[_]: Async] private ( client: HttpClient, closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: EncodingHandler[InputStream], + compressionHandlers: CompressionHandlers[Any, InputStream], dispatcher: Dispatcher[F] ) extends HttpClientAsyncBackend[F, Nothing, InputStream, InputStream]( client, new CatsMonadAsyncError[F], closeClient, customizeRequest, - customEncodingHandler + compressionHandlers ) { self => override protected def createSimpleQueue[T]: F[SimpleQueue[F, T]] = @@ -40,11 +40,12 @@ class HttpClientCatsBackend[F[_]: Async] private ( override protected def createSequencer: F[Sequencer[F]] = CatsSequencer.create - override protected val bodyToHttpClient: BodyToHttpClient[F, Nothing] = new BodyToHttpClient[F, Nothing] { + override protected val bodyToHttpClient: BodyToHttpClient[F, Nothing, R] = new BodyToHttpClient[F, Nothing, R] { override val streams: NoStreams = NoStreams override implicit val monad: MonadError[F] = self.monad override def streamToPublisher(stream: Nothing): F[BodyPublisher] = stream // nothing is everything + override def compressors: List[Compressor[R]] = compressionHandlers.compressors } override protected def bodyFromHttpClient: BodyFromHttpClient[F, Nothing, InputStream] = @@ -60,12 +61,6 @@ class HttpClientCatsBackend[F[_]: Async] private ( ): F[Unit] = pipe } - override protected def standardEncoding: (InputStream, String) => InputStream = { - case (body, "gzip") => new GZIPInputStream(body) - case (body, "deflate") => new InflaterInputStream(body) - case (_, ce) => throw new UnsupportedEncodingException(s"Unsupported encoding: $ce") - } - override val streams: NoStreams = NoStreams override protected def createBodyHandler: HttpResponse.BodyHandler[InputStream] = BodyHandlers.ofInputStream() @@ -76,23 +71,25 @@ class HttpClientCatsBackend[F[_]: Async] private ( } object HttpClientCatsBackend { + val DefaultCompressionHandlers: CompressionHandlers[Any, InputStream] = + CompressionHandlers(Compressor.default[Any], Decompressor.defaultInputStream) private def apply[F[_]: Async]( client: HttpClient, closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: EncodingHandler[InputStream], + compressionHandlers: CompressionHandlers[Any, InputStream], dispatcher: Dispatcher[F] ): WebSocketBackend[F] = wrappers.FollowRedirectsBackend( - new HttpClientCatsBackend(client, closeClient, customizeRequest, customEncodingHandler, dispatcher) + new HttpClientCatsBackend(client, closeClient, customizeRequest, compressionHandlers, dispatcher) ) def apply[F[_]: Async]( dispatcher: Dispatcher[F], options: BackendOptions = BackendOptions.Default, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: EncodingHandler[InputStream] = PartialFunction.empty + compressionHandlers: CompressionHandlers[Any, InputStream] = DefaultCompressionHandlers ): F[WebSocketBackend[F]] = Async[F].executor.flatMap(executor => Sync[F].delay( @@ -100,7 +97,7 @@ object HttpClientCatsBackend { HttpClientBackend.defaultClient(options, Some(executor)), closeClient = false, // we don't want to close the underlying executor customizeRequest, - customEncodingHandler, + compressionHandlers, dispatcher ) ) @@ -109,25 +106,25 @@ object HttpClientCatsBackend { def resource[F[_]: Async]( options: BackendOptions = BackendOptions.Default, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: EncodingHandler[InputStream] = PartialFunction.empty + compressionHandlers: CompressionHandlers[Any, InputStream] = DefaultCompressionHandlers ): Resource[F, WebSocketBackend[F]] = Dispatcher .parallel[F] .flatMap(dispatcher => - Resource.make(apply(dispatcher, options, customizeRequest, customEncodingHandler))(_.close()) + Resource.make(apply(dispatcher, options, customizeRequest, compressionHandlers))(_.close()) ) def resourceUsingClient[F[_]: Async]( client: HttpClient, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: EncodingHandler[InputStream] = PartialFunction.empty + compressionHandlers: CompressionHandlers[Any, InputStream] = DefaultCompressionHandlers ): Resource[F, WebSocketBackend[F]] = Dispatcher .parallel[F] .flatMap(dispatcher => Resource.make( Sync[F].delay( - HttpClientCatsBackend(client, closeClient = true, customizeRequest, customEncodingHandler, dispatcher) + HttpClientCatsBackend(client, closeClient = true, customizeRequest, compressionHandlers, dispatcher) ) )(_.close()) ) @@ -136,9 +133,9 @@ object HttpClientCatsBackend { client: HttpClient, dispatcher: Dispatcher[F], customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: EncodingHandler[InputStream] = PartialFunction.empty + compressionHandlers: CompressionHandlers[Any, InputStream] = DefaultCompressionHandlers ): WebSocketBackend[F] = - HttpClientCatsBackend(client, closeClient = false, customizeRequest, customEncodingHandler, dispatcher) + HttpClientCatsBackend(client, closeClient = false, customizeRequest, compressionHandlers, dispatcher) /** Create a stub backend for testing, which uses the [[F]] response wrapper. * diff --git a/effects/fs2-ce2/src/main/scala/sttp/client4/impl/fs2/Fs2Compression.scala b/effects/fs2-ce2/src/main/scala/sttp/client4/impl/fs2/Fs2Compression.scala deleted file mode 100644 index ec4ab6f9cc..0000000000 --- a/effects/fs2-ce2/src/main/scala/sttp/client4/impl/fs2/Fs2Compression.scala +++ /dev/null @@ -1,19 +0,0 @@ -package sttp.client4.impl.fs2 - -import cats.effect.Sync -import fs2.{Pipe, Pull} - -object Fs2Compression { - - def inflateCheckHeader[F[_]: Sync]: Pipe[F, Byte, Byte] = stream => - stream.pull.uncons1 - .flatMap { - case None => Pull.done - case Some((byte, stream)) => Pull.output1((byte, stream)) - } - .stream - .flatMap { case (byte, stream) => - val wrapped = (byte & 0x0f) == 0x08 - stream.cons1(byte).through(fs2.compression.inflate(nowrap = !wrapped)) - } -} diff --git a/effects/fs2-ce2/src/main/scala/sttp/client4/impl/fs2/fs2Decompressors.scala b/effects/fs2-ce2/src/main/scala/sttp/client4/impl/fs2/fs2Decompressors.scala new file mode 100644 index 0000000000..7b716fb3e9 --- /dev/null +++ b/effects/fs2-ce2/src/main/scala/sttp/client4/impl/fs2/fs2Decompressors.scala @@ -0,0 +1,37 @@ +package sttp.client4.impl.fs2 + +import fs2.Stream +import sttp.client4.compression.Decompressor +import sttp.model.Encodings +import fs2.Pipe +import fs2.Pull +import fs2.compression.ZLibParams +import fs2.compression.InflateParams +import cats.effect.Sync + +class GZipFs2Decompressor[F[_]: Sync] extends Decompressor[Stream[F, Byte]] { + override val encoding: String = Encodings.Gzip + override def apply(body: Stream[F, Byte]): Stream[F, Byte] = + body.through(fs2.compression.gunzip()).flatMap(_.content) +} + +class DeflateFs2Decompressor[F[_]: Sync] extends Decompressor[Stream[F, Byte]] { + override val encoding: String = Encodings.Deflate + override def apply(body: Stream[F, Byte]): Stream[F, Byte] = + body.through(DeflateFs2Decompressor.inflateCheckHeader[F]) +} + +object DeflateFs2Decompressor { + def inflateCheckHeader[F[_]: Sync]: Pipe[F, Byte, Byte] = stream => + stream.pull.uncons1 + .flatMap { + case None => Pull.done + case Some((byte, stream)) => Pull.output1((byte, stream)) + } + .stream + .flatMap { case (byte, stream) => + val header = if ((byte & 0x0f) == 0x08) ZLibParams.Header.ZLIB else ZLibParams.Header.GZIP + val params = InflateParams(header = header) + stream.cons1(byte).through(fs2.compression.inflate(params)) + } +} diff --git a/effects/fs2-ce2/src/main/scalajvm/sttp/client4/httpclient/fs2/HttpClientFs2Backend.scala b/effects/fs2-ce2/src/main/scalajvm/sttp/client4/httpclient/fs2/HttpClientFs2Backend.scala index dc9488480e..73da9b9d40 100644 --- a/effects/fs2-ce2/src/main/scalajvm/sttp/client4/httpclient/fs2/HttpClientFs2Backend.scala +++ b/effects/fs2-ce2/src/main/scalajvm/sttp/client4/httpclient/fs2/HttpClientFs2Backend.scala @@ -1,6 +1,5 @@ package sttp.client4.httpclient.fs2 -import java.io.UnsupportedEncodingException import java.net.http.HttpRequest.BodyPublishers import java.net.http.{HttpClient, HttpRequest, HttpResponse} import java.nio.ByteBuffer @@ -13,11 +12,9 @@ import fs2.concurrent.InspectableQueue import fs2.interop.reactivestreams._ import org.reactivestreams.FlowAdapters import sttp.capabilities.fs2.Fs2Streams -import sttp.client4.httpclient.HttpClientBackend.EncodingHandler -import sttp.client4.httpclient.fs2.HttpClientFs2Backend.Fs2EncodingHandler import sttp.client4.internal.httpclient.{BodyFromHttpClient, BodyToHttpClient, Sequencer} import sttp.client4.impl.cats.implicits._ -import sttp.client4.impl.fs2.{Fs2Compression, Fs2SimpleQueue} +import sttp.client4.impl.fs2.Fs2SimpleQueue import sttp.client4.internal.ws.SimpleQueue import sttp.client4.testing.WebSocketStreamBackendStub import sttp.client4._ @@ -28,19 +25,23 @@ import sttp.monad.MonadError import java.net.http.HttpResponse.BodyHandlers import java.util.concurrent.Flow.Publisher import scala.collection.JavaConverters._ +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Compressor +import sttp.client4.impl.fs2.GZipFs2Decompressor +import sttp.client4.impl.fs2.DeflateFs2Decompressor class HttpClientFs2Backend[F[_]: ConcurrentEffect: ContextShift] private ( client: HttpClient, blocker: Blocker, closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: Fs2EncodingHandler[F] + compressionHandlers: CompressionHandlers[Fs2Streams[F], Stream[F, Byte]] ) extends HttpClientAsyncBackend[F, Fs2Streams[F], Publisher[ju.List[ByteBuffer]], Stream[F, Byte]]( client, implicitly, closeClient, customizeRequest, - customEncodingHandler + compressionHandlers ) with WebSocketStreamBackend[F, Fs2Streams[F]] { self => @@ -49,8 +50,8 @@ class HttpClientFs2Backend[F[_]: ConcurrentEffect: ContextShift] private ( override def send[T](request: GenericRequest[T, R]): F[Response[T]] = super.send(request).guarantee(ContextShift[F].shift) - override protected val bodyToHttpClient: BodyToHttpClient[F, Fs2Streams[F]] = - new BodyToHttpClient[F, Fs2Streams[F]] { + override protected val bodyToHttpClient: BodyToHttpClient[F, Fs2Streams[F], R] = + new BodyToHttpClient[F, Fs2Streams[F], R] { override val streams: Fs2Streams[F] = Fs2Streams[F] override implicit def monad: MonadError[F] = self.monad override def streamToPublisher(stream: Stream[F, Byte]): F[HttpRequest.BodyPublisher] = @@ -59,6 +60,7 @@ class HttpClientFs2Backend[F[_]: ConcurrentEffect: ContextShift] private ( FlowAdapters.toFlowPublisher(stream.chunks.map(_.toByteBuffer).toUnicastPublisher) ) ) + override def compressors: List[Compressor[R]] = compressionHandlers.compressors } override protected def createBodyHandler: HttpResponse.BodyHandler[Publisher[ju.List[ByteBuffer]]] = @@ -79,33 +81,32 @@ class HttpClientFs2Backend[F[_]: ConcurrentEffect: ContextShift] private ( .flatMap(data => Stream.emits(data.asScala.map(Chunk.byteBuffer)).flatMap(Stream.chunk)) override protected def emptyBody(): Stream[F, Byte] = Stream.empty - - override protected def standardEncoding: (Stream[F, Byte], String) => Stream[F, Byte] = { - case (body, "gzip") => body.through(fs2.compression.gunzip()).flatMap(_.content) - case (body, "deflate") => body.through(Fs2Compression.inflateCheckHeader) - case (_, ce) => Stream.raiseError[F](new UnsupportedEncodingException(s"Unsupported encoding: $ce")) - } } object HttpClientFs2Backend { - type Fs2EncodingHandler[F[_]] = EncodingHandler[Stream[F, Byte]] + def defaultCompressionHandlers[F[_]: Sync]: CompressionHandlers[Fs2Streams[F], Stream[F, Byte]] = + CompressionHandlers( + Compressor.default[Fs2Streams[F]], + List(new GZipFs2Decompressor, new DeflateFs2Decompressor) + ) private def apply[F[_]: ConcurrentEffect: ContextShift]( client: HttpClient, blocker: Blocker, closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: Fs2EncodingHandler[F] + compressionHandlers: CompressionHandlers[Fs2Streams[F], Stream[F, Byte]] ): WebSocketStreamBackend[F, Fs2Streams[F]] = FollowRedirectsBackend( - new HttpClientFs2Backend(client, blocker, closeClient, customizeRequest, customEncodingHandler) + new HttpClientFs2Backend(client, blocker, closeClient, customizeRequest, compressionHandlers) ) def apply[F[_]: ConcurrentEffect: ContextShift]( blocker: Blocker, options: BackendOptions = BackendOptions.Default, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: Fs2EncodingHandler[F] = PartialFunction.empty + compressionHandlers: Sync[F] => CompressionHandlers[Fs2Streams[F], Stream[F, Byte]] = + defaultCompressionHandlers[F](_: Sync[F]) ): F[WebSocketStreamBackend[F, Fs2Streams[F]]] = Sync[F].delay( HttpClientFs2Backend( @@ -113,7 +114,7 @@ object HttpClientFs2Backend { blocker, closeClient = true, customizeRequest, - customEncodingHandler + compressionHandlers(implicitly) ) ) @@ -121,27 +122,32 @@ object HttpClientFs2Backend { blocker: Blocker, options: BackendOptions = BackendOptions.Default, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: Fs2EncodingHandler[F] = PartialFunction.empty + compressionHandlers: Sync[F] => CompressionHandlers[Fs2Streams[F], Stream[F, Byte]] = + defaultCompressionHandlers[F](_: Sync[F]) ): Resource[F, WebSocketStreamBackend[F, Fs2Streams[F]]] = - Resource.make(apply(blocker, options, customizeRequest, customEncodingHandler))(_.close()) + Resource.make(apply(blocker, options, customizeRequest, compressionHandlers))(_.close()) def resourceUsingClient[F[_]: ConcurrentEffect: ContextShift]( client: HttpClient, blocker: Blocker, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: Fs2EncodingHandler[F] = PartialFunction.empty + compressionHandlers: Sync[F] => CompressionHandlers[Fs2Streams[F], Stream[F, Byte]] = + defaultCompressionHandlers[F](_: Sync[F]) ): Resource[F, WebSocketStreamBackend[F, Fs2Streams[F]]] = Resource.make( - Sync[F].delay(HttpClientFs2Backend(client, blocker, closeClient = true, customizeRequest, customEncodingHandler)) + Sync[F].delay( + HttpClientFs2Backend(client, blocker, closeClient = true, customizeRequest, compressionHandlers(implicitly)) + ) )(_.close()) def usingClient[F[_]: ConcurrentEffect: ContextShift]( client: HttpClient, blocker: Blocker, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: Fs2EncodingHandler[F] = PartialFunction.empty + compressionHandlers: Sync[F] => CompressionHandlers[Fs2Streams[F], Stream[F, Byte]] = + defaultCompressionHandlers[F](_: Sync[F]) ): WebSocketStreamBackend[F, Fs2Streams[F]] = - HttpClientFs2Backend(client, blocker, closeClient = false, customizeRequest, customEncodingHandler) + HttpClientFs2Backend(client, blocker, closeClient = false, customizeRequest, compressionHandlers(implicitly)) /** Create a stub backend for testing, which uses the [[F]] response wrapper, and supports `Stream[F, Byte]` * streaming. diff --git a/effects/fs2/src/main/scalajvm/sttp/client4/httpclient/fs2/Fs2Compression.scala b/effects/fs2/src/main/scalajvm/sttp/client4/httpclient/fs2/Fs2Compression.scala deleted file mode 100644 index ebb3a11d80..0000000000 --- a/effects/fs2/src/main/scalajvm/sttp/client4/httpclient/fs2/Fs2Compression.scala +++ /dev/null @@ -1,20 +0,0 @@ -package sttp.client4.httpclient.fs2 - -import fs2.{Pipe, Pull} -import fs2.compression.{Compression, InflateParams, ZLibParams} - -object Fs2Compression { - - def inflateCheckHeader[F[_]: Compression]: Pipe[F, Byte, Byte] = stream => - stream.pull.uncons1 - .flatMap { - case None => Pull.done - case Some((byte, stream)) => Pull.output1((byte, stream)) - } - .stream - .flatMap { case (byte, stream) => - val header = if ((byte & 0x0f) == 0x08) ZLibParams.Header.ZLIB else ZLibParams.Header.GZIP - val params = InflateParams(header = header) - stream.cons1(byte).through(fs2.compression.Compression[F].inflate(params)) - } -} diff --git a/effects/fs2/src/main/scalajvm/sttp/client4/httpclient/fs2/HttpClientFs2Backend.scala b/effects/fs2/src/main/scalajvm/sttp/client4/httpclient/fs2/HttpClientFs2Backend.scala index 38d967a350..e943b9c4cb 100644 --- a/effects/fs2/src/main/scalajvm/sttp/client4/httpclient/fs2/HttpClientFs2Backend.scala +++ b/effects/fs2/src/main/scalajvm/sttp/client4/httpclient/fs2/HttpClientFs2Backend.scala @@ -1,6 +1,5 @@ package sttp.client4.httpclient.fs2 -import java.io.UnsupportedEncodingException import java.net.http.HttpRequest.BodyPublishers import java.net.http.{HttpClient, HttpRequest, HttpResponse} import java.nio.ByteBuffer @@ -12,8 +11,6 @@ import fs2.interop.reactivestreams.{PublisherOps, StreamUnicastPublisher} import fs2.{Chunk, Stream} import org.reactivestreams.FlowAdapters import sttp.capabilities.fs2.Fs2Streams -import sttp.client4.httpclient.HttpClientBackend.EncodingHandler -import sttp.client4.httpclient.fs2.HttpClientFs2Backend.Fs2EncodingHandler import sttp.client4.internal.httpclient.{BodyFromHttpClient, BodyToHttpClient, Sequencer} import sttp.client4.impl.cats.implicits._ import sttp.client4.impl.fs2.Fs2SimpleQueue @@ -28,28 +25,33 @@ import java.net.http.HttpResponse.BodyHandlers import java.util.concurrent.Flow.Publisher import java.{util => ju} import scala.collection.JavaConverters._ +import sttp.client4.compression.Compressor +import sttp.client4.impl.fs2.{DeflateFs2Compressor, DeflateFs2Decompressor, GZipFs2Compressor, GZipFs2Decompressor} +import sttp.client4.compression.CompressionHandlers +import fs2.compression.Compression class HttpClientFs2Backend[F[_]: Async] private ( client: HttpClient, closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: Fs2EncodingHandler[F], + compressionHandlers: CompressionHandlers[Fs2Streams[F], Stream[F, Byte]], dispatcher: Dispatcher[F] ) extends HttpClientAsyncBackend[F, Fs2Streams[F], Publisher[ju.List[ByteBuffer]], Stream[F, Byte]]( client, implicitly, closeClient, customizeRequest, - customEncodingHandler + compressionHandlers ) with WebSocketStreamBackend[F, Fs2Streams[F]] { self => override val streams: Fs2Streams[F] = Fs2Streams[F] - override protected val bodyToHttpClient: BodyToHttpClient[F, Fs2Streams[F]] = - new BodyToHttpClient[F, Fs2Streams[F]] { + override protected val bodyToHttpClient: BodyToHttpClient[F, Fs2Streams[F], R] = + new BodyToHttpClient[F, Fs2Streams[F], R] { override val streams: Fs2Streams[F] = Fs2Streams[F] override implicit def monad: MonadError[F] = self.monad + override def compressors: List[Compressor[R]] = compressionHandlers.compressors override def streamToPublisher(stream: Stream[F, Byte]): F[HttpRequest.BodyPublisher] = monad.eval( BodyPublishers.fromPublisher( @@ -80,33 +82,32 @@ class HttpClientFs2Backend[F[_]: Async] private ( .flatMap(data => Stream.emits(data.asScala.map(Chunk.byteBuffer)).flatMap(Stream.chunk)) override protected def emptyBody(): Stream[F, Byte] = Stream.empty - - override protected def standardEncoding: (Stream[F, Byte], String) => Stream[F, Byte] = { - case (body, "gzip") => body.through(fs2.compression.Compression[F].gunzip()).flatMap(_.content) - case (body, "deflate") => body.through(Fs2Compression.inflateCheckHeader[F]) - case (_, ce) => Stream.raiseError[F](new UnsupportedEncodingException(s"Unsupported encoding: $ce")) - } } object HttpClientFs2Backend { - type Fs2EncodingHandler[F[_]] = EncodingHandler[Stream[F, Byte]] + def defaultCompressionHandlers[F[_]: Async]: CompressionHandlers[Fs2Streams[F], Stream[F, Byte]] = + CompressionHandlers( + List(new GZipFs2Compressor[F, Fs2Streams[F]](), new DeflateFs2Compressor[F, Fs2Streams[F]]()), + List(new GZipFs2Decompressor, new DeflateFs2Decompressor) + ) private def apply[F[_]: Async]( client: HttpClient, closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: Fs2EncodingHandler[F], + compressionHandlers: CompressionHandlers[Fs2Streams[F], Stream[F, Byte]], dispatcher: Dispatcher[F] ): WebSocketStreamBackend[F, Fs2Streams[F]] = FollowRedirectsBackend( - new HttpClientFs2Backend(client, closeClient, customizeRequest, customEncodingHandler, dispatcher) + new HttpClientFs2Backend(client, closeClient, customizeRequest, compressionHandlers, dispatcher) ) def apply[F[_]: Async]( dispatcher: Dispatcher[F], options: BackendOptions = BackendOptions.Default, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: Fs2EncodingHandler[F] = PartialFunction.empty + compressionHandlers: Async[F] => CompressionHandlers[Fs2Streams[F], Stream[F, Byte]] = + defaultCompressionHandlers[F](_: Async[F]) ): F[WebSocketStreamBackend[F, Fs2Streams[F]]] = Async[F].executor.flatMap(executor => Sync[F].delay( @@ -114,7 +115,7 @@ object HttpClientFs2Backend { HttpClientBackend.defaultClient(options, Some(executor)), closeClient = false, // we don't want to close the underlying executor customizeRequest, - customEncodingHandler, + compressionHandlers(implicitly), dispatcher ) ) @@ -123,24 +124,27 @@ object HttpClientFs2Backend { def resource[F[_]: Async]( options: BackendOptions = BackendOptions.Default, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: Fs2EncodingHandler[F] = PartialFunction.empty + compressionHandlers: Async[F] => CompressionHandlers[Fs2Streams[F], Stream[F, Byte]] = + defaultCompressionHandlers[F](_: Async[F]) ): Resource[F, WebSocketStreamBackend[F, Fs2Streams[F]]] = Dispatcher .parallel[F] .flatMap(dispatcher => - Resource.make(apply(dispatcher, options, customizeRequest, customEncodingHandler))(_.close()) + Resource.make(apply(dispatcher, options, customizeRequest, compressionHandlers))(_.close()) ) def resourceUsingClient[F[_]: Async]( client: HttpClient, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: Fs2EncodingHandler[F] = PartialFunction.empty + compressionHandlers: Async[F] => CompressionHandlers[Fs2Streams[F], Stream[F, Byte]] = + defaultCompressionHandlers[F](_: Async[F]) ): Resource[F, WebSocketStreamBackend[F, Fs2Streams[F]]] = Dispatcher .parallel[F] .flatMap(dispatcher => Resource.make( - Sync[F].delay(apply(client, closeClient = true, customizeRequest, customEncodingHandler, dispatcher)) + Sync[F] + .delay(apply(client, closeClient = true, customizeRequest, compressionHandlers(implicitly), dispatcher)) )(_.close()) ) @@ -148,9 +152,10 @@ object HttpClientFs2Backend { client: HttpClient, dispatcher: Dispatcher[F], customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: Fs2EncodingHandler[F] = PartialFunction.empty + compressionHandlers: Async[F] => CompressionHandlers[Fs2Streams[F], Stream[F, Byte]] = + defaultCompressionHandlers[F](_: Async[F]) ): WebSocketStreamBackend[F, Fs2Streams[F]] = - HttpClientFs2Backend(client, closeClient = false, customizeRequest, customEncodingHandler, dispatcher) + HttpClientFs2Backend(client, closeClient = false, customizeRequest, compressionHandlers(implicitly), dispatcher) /** Create a stub backend for testing, which uses the [[F]] response wrapper, and supports `Stream[F, Byte]` * streaming. diff --git a/effects/fs2/src/main/scalajvm/sttp/client4/impl/fs2/fs2Compressor.scala b/effects/fs2/src/main/scalajvm/sttp/client4/impl/fs2/fs2Compressor.scala new file mode 100644 index 0000000000..15ab608cbe --- /dev/null +++ b/effects/fs2/src/main/scalajvm/sttp/client4/impl/fs2/fs2Compressor.scala @@ -0,0 +1,48 @@ +package sttp.client4.impl.fs2 + +import sttp.client4._ +import sttp.client4.GenericRequestBody +import fs2._ +import fs2.compression.Compression +import cats.syntax.all._ +import fs2.io.file.Files +import cats.effect.Sync +import sttp.capabilities.fs2.Fs2Streams +import fs2.compression.DeflateParams +import sttp.client4.compression.{Compressor, DeflateDefaultCompressor, GZipDefaultCompressor} + +trait Fs2Compressor[F[_], R <: Fs2Streams[F]] extends Compressor[R] { + protected val fSync: Sync[F] + protected val fFiles: Files[F] + + override abstract def apply[R2 <: R](body: GenericRequestBody[R2]): GenericRequestBody[R] = + body match { + case InputStreamBody(b, _) => + StreamBody(Fs2Streams[F])(compressStream(fs2.io.readInputStream(b.pure[F](fSync), 1024)(fSync))) + case StreamBody(b) => StreamBody(Fs2Streams[F])(compressStream(b.asInstanceOf[fs2.Stream[F, Byte]])) + case FileBody(f, _) => StreamBody(Fs2Streams[F])(compressStream(Files[F](fFiles).readAll(f.toPath, 1024))) + case _ => super.apply(body) + } + + def compressStream(stream: fs2.Stream[F, Byte]): fs2.Stream[F, Byte] +} + +class GZipFs2Compressor[F[_]: Compression: Sync: Files, R <: Fs2Streams[F]] + extends GZipDefaultCompressor[R] + with Fs2Compressor[F, R] { + + override protected val fSync: Sync[F] = implicitly + override protected val fFiles: Files[F] = implicitly + + def compressStream(stream: Stream[F, Byte]): Stream[F, Byte] = stream.through(fs2.compression.Compression[F].gzip()) +} + +class DeflateFs2Compressor[F[_]: Compression: Sync: Files, R <: Fs2Streams[F]] + extends DeflateDefaultCompressor[R] + with Fs2Compressor[F, R] { + override protected val fSync: Sync[F] = implicitly + override protected val fFiles: Files[F] = implicitly + + def compressStream(stream: Stream[F, Byte]): Stream[F, Byte] = + stream.through(fs2.compression.Compression[F].deflate(DeflateParams())) +} diff --git a/effects/fs2/src/main/scalajvm/sttp/client4/impl/fs2/fs2Decompressors.scala b/effects/fs2/src/main/scalajvm/sttp/client4/impl/fs2/fs2Decompressors.scala new file mode 100644 index 0000000000..f5cdafa789 --- /dev/null +++ b/effects/fs2/src/main/scalajvm/sttp/client4/impl/fs2/fs2Decompressors.scala @@ -0,0 +1,34 @@ +package sttp.client4.impl.fs2 + +import fs2.compression.Compression +import fs2.Stream +import sttp.client4.compression.Decompressor +import sttp.model.Encodings +import fs2.Pipe +import fs2.Pull +import fs2.compression.ZLibParams +import fs2.compression.InflateParams + +class GZipFs2Decompressor[F[_]: Compression] extends Decompressor[Stream[F, Byte]] { + override val encoding: String = Encodings.Gzip + override def apply(body: Stream[F, Byte]): Stream[F, Byte] = + body.through(fs2.compression.Compression[F].gunzip()).flatMap(_.content) +} + +class DeflateFs2Decompressor[F[_]: Compression] extends Decompressor[Stream[F, Byte]] { + override val encoding: String = Encodings.Deflate + override def apply(body: Stream[F, Byte]): Stream[F, Byte] = body.through(inflateCheckHeader[F]) + + private def inflateCheckHeader[F[_]: Compression]: Pipe[F, Byte, Byte] = stream => + stream.pull.uncons1 + .flatMap { + case None => Pull.done + case Some((byte, stream)) => Pull.output1((byte, stream)) + } + .stream + .flatMap { case (byte, stream) => + val header = if ((byte & 0x0f) == 0x08) ZLibParams.Header.ZLIB else ZLibParams.Header.GZIP + val params = InflateParams(header = header) + stream.cons1(byte).through(fs2.compression.Compression[F].inflate(params)) + } +} diff --git a/effects/monix/src/main/scalajvm/sttp/client4/httpclient/monix/HttpClientMonixBackend.scala b/effects/monix/src/main/scalajvm/sttp/client4/httpclient/monix/HttpClientMonixBackend.scala index a76b436302..baa2e5792b 100644 --- a/effects/monix/src/main/scalajvm/sttp/client4/httpclient/monix/HttpClientMonixBackend.scala +++ b/effects/monix/src/main/scalajvm/sttp/client4/httpclient/monix/HttpClientMonixBackend.scala @@ -4,22 +4,17 @@ import cats.effect.Resource import monix.eval.Task import monix.execution.Scheduler import monix.reactive.Observable -import monix.reactive.compression._ import org.reactivestreams.FlowAdapters import sttp.capabilities.monix.MonixStreams import sttp.client4.httpclient.{HttpClientAsyncBackend, HttpClientBackend} -import sttp.client4.httpclient.HttpClientBackend.EncodingHandler -import sttp.client4.httpclient.monix.HttpClientMonixBackend.MonixEncodingHandler import sttp.client4.impl.monix.{MonixSimpleQueue, TaskMonadAsyncError} import sttp.client4.internal._ import sttp.client4.internal.httpclient.{BodyFromHttpClient, BodyToHttpClient, Sequencer} import sttp.client4.internal.ws.SimpleQueue import sttp.client4.testing.WebSocketStreamBackendStub -import sttp.client4.wrappers.FollowRedirectsBackend import sttp.client4.{wrappers, BackendOptions, WebSocketStreamBackend} import sttp.monad.MonadError -import java.io.UnsupportedEncodingException import java.net.http.HttpRequest.BodyPublishers import java.net.http.HttpResponse.BodyHandlers import java.net.http.{HttpClient, HttpRequest, HttpResponse} @@ -27,32 +22,35 @@ import java.nio.ByteBuffer import java.util.concurrent.Flow.Publisher import java.{util => ju} import scala.collection.JavaConverters._ +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Compressor class HttpClientMonixBackend private ( client: HttpClient, closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: MonixEncodingHandler + compressionHandlers: CompressionHandlers[MonixStreams, MonixStreams.BinaryStream] )(implicit s: Scheduler) extends HttpClientAsyncBackend[Task, MonixStreams, Publisher[ju.List[ByteBuffer]], MonixStreams.BinaryStream]( client, TaskMonadAsyncError, closeClient, customizeRequest, - customEncodingHandler + compressionHandlers ) with WebSocketStreamBackend[Task, MonixStreams] { self => override val streams: MonixStreams = MonixStreams - override protected val bodyToHttpClient: BodyToHttpClient[Task, MonixStreams] = - new BodyToHttpClient[Task, MonixStreams] { + override protected val bodyToHttpClient: BodyToHttpClient[Task, MonixStreams, R] = + new BodyToHttpClient[Task, MonixStreams, R] { override val streams: MonixStreams = MonixStreams override implicit def monad: MonadError[Task] = self.monad override def streamToPublisher(stream: Observable[Array[Byte]]): Task[HttpRequest.BodyPublisher] = monad.eval( BodyPublishers.fromPublisher(FlowAdapters.toFlowPublisher(stream.map(ByteBuffer.wrap).toReactivePublisher)) ) + override def compressors: List[Compressor[R]] = compressionHandlers.compressors } override protected val bodyFromHttpClient: BodyFromHttpClient[Task, MonixStreams, MonixStreams.BinaryStream] = @@ -76,33 +74,31 @@ class HttpClientMonixBackend private ( .map(_.safeRead()) override protected def emptyBody(): Observable[Array[Byte]] = Observable.empty - - override protected def standardEncoding: (Observable[Array[Byte]], String) => Observable[Array[Byte]] = { - case (body, "gzip") => body.transform(gunzip()) - case (body, "deflate") => body.transform(inflate()) - case (_, ce) => throw new UnsupportedEncodingException(s"Unsupported encoding: $ce") - } } object HttpClientMonixBackend { - type MonixEncodingHandler = EncodingHandler[MonixStreams.BinaryStream] + val DefaultCompressionHandlers: CompressionHandlers[MonixStreams, MonixStreams.BinaryStream] = + CompressionHandlers( + Compressor.default[MonixStreams], + List(GZipMonixDecompressor, DeflateMonixDecompressor) + ) private def apply( client: HttpClient, closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: MonixEncodingHandler + compressionHandlers: CompressionHandlers[MonixStreams, MonixStreams.BinaryStream] )(implicit s: Scheduler ): WebSocketStreamBackend[Task, MonixStreams] = wrappers.FollowRedirectsBackend( - new HttpClientMonixBackend(client, closeClient, customizeRequest, customEncodingHandler)(s) + new HttpClientMonixBackend(client, closeClient, customizeRequest, compressionHandlers)(s) ) def apply( options: BackendOptions = BackendOptions.Default, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: MonixEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[MonixStreams, MonixStreams.BinaryStream] = DefaultCompressionHandlers )(implicit s: Scheduler = Scheduler.global ): Task[WebSocketStreamBackend[Task, MonixStreams]] = @@ -111,36 +107,36 @@ object HttpClientMonixBackend { HttpClientBackend.defaultClient(options, Some(s)), closeClient = false, // we don't want to close Monix's scheduler customizeRequest, - customEncodingHandler + compressionHandlers )(s) ) def resource( options: BackendOptions = BackendOptions.Default, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: MonixEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[MonixStreams, MonixStreams.BinaryStream] = DefaultCompressionHandlers )(implicit s: Scheduler = Scheduler.global ): Resource[Task, WebSocketStreamBackend[Task, MonixStreams]] = - Resource.make(apply(options, customizeRequest, customEncodingHandler))(_.close()) + Resource.make(apply(options, customizeRequest, compressionHandlers))(_.close()) def resourceUsingClient( client: HttpClient, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: MonixEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[MonixStreams, MonixStreams.BinaryStream] = DefaultCompressionHandlers )(implicit s: Scheduler = Scheduler.global ): Resource[Task, WebSocketStreamBackend[Task, MonixStreams]] = Resource.make( - Task.eval(HttpClientMonixBackend(client, closeClient = true, customizeRequest, customEncodingHandler)(s)) + Task.eval(HttpClientMonixBackend(client, closeClient = true, customizeRequest, compressionHandlers)(s)) )(_.close()) def usingClient( client: HttpClient, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: MonixEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[MonixStreams, MonixStreams.BinaryStream] = DefaultCompressionHandlers )(implicit s: Scheduler = Scheduler.global): WebSocketStreamBackend[Task, MonixStreams] = - HttpClientMonixBackend(client, closeClient = false, customizeRequest, customEncodingHandler)(s) + HttpClientMonixBackend(client, closeClient = false, customizeRequest, compressionHandlers)(s) /** Create a stub backend for testing, which uses the [[Task]] response wrapper, and supports `Observable[ByteBuffer]` * streaming. diff --git a/effects/monix/src/main/scalajvm/sttp/client4/httpclient/monix/monixDecompressors.scala b/effects/monix/src/main/scalajvm/sttp/client4/httpclient/monix/monixDecompressors.scala new file mode 100644 index 0000000000..46f41255ea --- /dev/null +++ b/effects/monix/src/main/scalajvm/sttp/client4/httpclient/monix/monixDecompressors.scala @@ -0,0 +1,16 @@ +package sttp.client4.httpclient.monix + +import sttp.capabilities.monix.MonixStreams +import sttp.model.Encodings +import sttp.client4.compression.Decompressor +import monix.reactive.compression._ + +object GZipMonixDecompressor extends Decompressor[MonixStreams.BinaryStream] { + override val encoding: String = Encodings.Gzip + override def apply(body: MonixStreams.BinaryStream): MonixStreams.BinaryStream = body.transform(gunzip()) +} + +object DeflateMonixDecompressor extends Decompressor[MonixStreams.BinaryStream] { + override val encoding: String = Encodings.Deflate + override def apply(body: MonixStreams.BinaryStream): MonixStreams.BinaryStream = body.transform(inflate()) +} diff --git a/effects/zio/src/main/scalajvm/sttp/client4/httpclient/zio/HttpClientZioBackend.scala b/effects/zio/src/main/scalajvm/sttp/client4/httpclient/zio/HttpClientZioBackend.scala index 2bcc37f856..1470c1f6b9 100644 --- a/effects/zio/src/main/scalajvm/sttp/client4/httpclient/zio/HttpClientZioBackend.scala +++ b/effects/zio/src/main/scalajvm/sttp/client4/httpclient/zio/HttpClientZioBackend.scala @@ -4,20 +4,17 @@ import _root_.zio.interop.reactivestreams._ import org.reactivestreams.FlowAdapters import sttp.capabilities.zio.ZioStreams import sttp.client4.httpclient.{HttpClientAsyncBackend, HttpClientBackend} -import sttp.client4.httpclient.HttpClientBackend.EncodingHandler import sttp.client4.impl.zio.{RIOMonadAsyncError, ZioSimpleQueue} import sttp.client4.internal._ import sttp.client4.internal.httpclient.{BodyFromHttpClient, BodyToHttpClient, Sequencer} import sttp.client4.internal.ws.SimpleQueue import sttp.client4.testing.WebSocketStreamBackendStub -import sttp.client4.wrappers.FollowRedirectsBackend import sttp.client4.{wrappers, BackendOptions, GenericRequest, Response, WebSocketStreamBackend} import sttp.monad.MonadError import zio.Chunk.ByteArray import zio._ -import zio.stream.{ZPipeline, ZSink, ZStream} +import zio.stream.ZStream -import java.io.UnsupportedEncodingException import java.net.http.HttpRequest.{BodyPublisher, BodyPublishers} import java.net.http.HttpResponse.BodyHandlers import java.net.http.{HttpClient, HttpRequest, HttpResponse} @@ -25,12 +22,15 @@ import java.nio.ByteBuffer import java.util import java.util.concurrent.Flow.Publisher import java.{util => ju} +import sttp.client4.compression.Compressor +import sttp.client4.impl.zio.{DeflateZioCompressor, DeflateZioDecompressor, GZipZioCompressor, GZipZioDecompressor} +import sttp.client4.compression.CompressionHandlers class HttpClientZioBackend private ( client: HttpClient, closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: EncodingHandler[ZioStreams.BinaryStream] + compressionHandlers: CompressionHandlers[ZioStreams, ZioStreams.BinaryStream] ) extends HttpClientAsyncBackend[ Task, ZioStreams, @@ -41,7 +41,7 @@ class HttpClientZioBackend private ( new RIOMonadAsyncError[Any], closeClient, customizeRequest, - customEncodingHandler + compressionHandlers ) with WebSocketStreamBackend[Task, ZioStreams] { self => @@ -58,8 +58,8 @@ class HttpClientZioBackend private ( ByteArray(a, 0, a.length) } - override protected val bodyToHttpClient: BodyToHttpClient[Task, ZioStreams] = - new BodyToHttpClient[Task, ZioStreams] { + override protected val bodyToHttpClient: BodyToHttpClient[Task, ZioStreams, R] = + new BodyToHttpClient[Task, ZioStreams, R] { override val streams: ZioStreams = ZioStreams override implicit def monad: MonadError[Task] = self.monad override def streamToPublisher(stream: ZStream[Any, Throwable, Byte]): Task[BodyPublisher] = { @@ -69,6 +69,7 @@ class HttpClientZioBackend private ( BodyPublishers.fromPublisher(FlowAdapters.toFlowPublisher(pub)) } } + override def compressors: List[Compressor[R]] = compressionHandlers.compressors } override def send[T](request: GenericRequest[T, R]): Task[Response[T]] = @@ -84,36 +85,29 @@ class HttpClientZioBackend private ( } yield new ZioSimpleQueue(queue, runtime) override protected def createSequencer: Task[Sequencer[Task]] = ZioSequencer.create - - override protected def standardEncoding: (ZStream[Any, Throwable, Byte], String) => ZStream[Any, Throwable, Byte] = { - case (body, "gzip") => body.via(ZPipeline.gunzip()) - case (body, "deflate") => - ZStream.scoped(body.peel(ZSink.take[Byte](1))).flatMap { case (chunk, stream) => - val wrapped = chunk.headOption.exists(byte => (byte & 0x0f) == 0x08) - (ZStream.fromChunk(chunk) ++ stream).via(ZPipeline.inflate(noWrap = !wrapped)) - } - case (_, ce) => ZStream.fail(new UnsupportedEncodingException(s"Unsupported encoding: $ce")) - } } object HttpClientZioBackend { - - type ZioEncodingHandler = EncodingHandler[ZioStreams.BinaryStream] + val DefaultCompressionHandlers: CompressionHandlers[ZioStreams, ZioStreams.BinaryStream] = + CompressionHandlers( + List(GZipZioCompressor, DeflateZioCompressor), + List(GZipZioDecompressor, DeflateZioDecompressor) + ) private def apply( client: HttpClient, closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: ZioEncodingHandler + compressionHandlers: CompressionHandlers[ZioStreams, ZioStreams.BinaryStream] ): WebSocketStreamBackend[Task, ZioStreams] = wrappers.FollowRedirectsBackend( - new HttpClientZioBackend(client, closeClient, customizeRequest, customEncodingHandler) + new HttpClientZioBackend(client, closeClient, customizeRequest, compressionHandlers) ) def apply( options: BackendOptions = BackendOptions.Default, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: ZioEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[ZioStreams, ZioStreams.BinaryStream] = DefaultCompressionHandlers ): Task[WebSocketStreamBackend[Task, ZioStreams]] = ZIO.executor.flatMap(executor => ZIO.attempt( @@ -121,7 +115,7 @@ object HttpClientZioBackend { HttpClientBackend.defaultClient(options, Some(executor.asJava)), closeClient = false, // we don't want to close ZIO's executor customizeRequest, - customEncodingHandler + compressionHandlers ) ) ) @@ -129,32 +123,32 @@ object HttpClientZioBackend { def scoped( options: BackendOptions = BackendOptions.Default, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: ZioEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[ZioStreams, ZioStreams.BinaryStream] = DefaultCompressionHandlers ): ZIO[Scope, Throwable, WebSocketStreamBackend[Task, ZioStreams]] = - ZIO.acquireRelease(apply(options, customizeRequest, customEncodingHandler))( + ZIO.acquireRelease(apply(options, customizeRequest, compressionHandlers))( _.close().ignore ) def scopedUsingClient( client: HttpClient, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: ZioEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[ZioStreams, ZioStreams.BinaryStream] = DefaultCompressionHandlers ): ZIO[Scope, Throwable, WebSocketStreamBackend[Task, ZioStreams]] = ZIO.acquireRelease( - ZIO.attempt(HttpClientZioBackend(client, closeClient = true, customizeRequest, customEncodingHandler)) + ZIO.attempt(HttpClientZioBackend(client, closeClient = true, customizeRequest, compressionHandlers)) )(_.close().ignore) def layer( options: BackendOptions = BackendOptions.Default, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: ZioEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[ZioStreams, ZioStreams.BinaryStream] = DefaultCompressionHandlers ): ZLayer[Any, Throwable, SttpClient] = ZLayer.scoped( (for { backend <- HttpClientZioBackend( options, customizeRequest, - customEncodingHandler + compressionHandlers ) } yield backend).tap(client => ZIO.addFinalizer(client.close().ignore)) ) @@ -162,19 +156,19 @@ object HttpClientZioBackend { def usingClient( client: HttpClient, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: ZioEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[ZioStreams, ZioStreams.BinaryStream] = DefaultCompressionHandlers ): WebSocketStreamBackend[Task, ZioStreams] = HttpClientZioBackend( client, closeClient = false, customizeRequest, - customEncodingHandler + compressionHandlers ) def layerUsingClient( client: HttpClient, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: ZioEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[ZioStreams, ZioStreams.BinaryStream] = DefaultCompressionHandlers ): ZLayer[Any, Throwable, SttpClient] = ZLayer.scoped( ZIO @@ -183,7 +177,7 @@ object HttpClientZioBackend { usingClient( client, customizeRequest, - customEncodingHandler + compressionHandlers ) ) )(_.close().ignore) diff --git a/effects/zio/src/main/scalajvm/sttp/client4/impl/zio/zioCompressor.scala b/effects/zio/src/main/scalajvm/sttp/client4/impl/zio/zioCompressor.scala new file mode 100644 index 0000000000..edc3b73758 --- /dev/null +++ b/effects/zio/src/main/scalajvm/sttp/client4/impl/zio/zioCompressor.scala @@ -0,0 +1,31 @@ +package sttp.client4.impl.zio + +import sttp.client4._ +import sttp.client4.compression.Compressor +import sttp.capabilities.zio.ZioStreams + +import zio.stream.Stream +import sttp.client4.compression.GZipDefaultCompressor +import sttp.client4.compression.DeflateDefaultCompressor +import zio.stream.ZPipeline +import zio.stream.ZStream + +trait ZioCompressor extends Compressor[ZioStreams] { + override abstract def apply[R2 <: ZioStreams](body: GenericRequestBody[R2]): GenericRequestBody[ZioStreams] = + body match { + case InputStreamBody(b, _) => StreamBody(ZioStreams)(compressStream(ZStream.fromInputStream(b))) + case StreamBody(b) => StreamBody(ZioStreams)(compressStream(b.asInstanceOf[Stream[Throwable, Byte]])) + case FileBody(f, _) => StreamBody(ZioStreams)(compressStream(ZStream.fromFile(f.toFile))) + case _ => super.apply(body) + } + + def compressStream(stream: Stream[Throwable, Byte]): Stream[Throwable, Byte] +} + +object GZipZioCompressor extends GZipDefaultCompressor[ZioStreams] with ZioCompressor { + def compressStream(stream: Stream[Throwable, Byte]): Stream[Throwable, Byte] = stream.via(ZPipeline.gzip()) +} + +object DeflateZioCompressor extends DeflateDefaultCompressor[ZioStreams] with ZioCompressor { + def compressStream(stream: Stream[Throwable, Byte]): Stream[Throwable, Byte] = stream.via(ZPipeline.deflate()) +} diff --git a/effects/zio/src/main/scalajvm/sttp/client4/impl/zio/zioDecompressors.scala b/effects/zio/src/main/scalajvm/sttp/client4/impl/zio/zioDecompressors.scala new file mode 100644 index 0000000000..1894b7bcd2 --- /dev/null +++ b/effects/zio/src/main/scalajvm/sttp/client4/impl/zio/zioDecompressors.scala @@ -0,0 +1,22 @@ +package sttp.client4.impl.zio + +import sttp.client4.compression.Decompressor +import sttp.model.Encodings +import sttp.capabilities.zio.ZioStreams +import zio.stream.ZPipeline +import zio.stream.ZStream +import zio.stream.ZSink + +object GZipZioDecompressor extends Decompressor[ZioStreams.BinaryStream] { + override val encoding: String = Encodings.Gzip + override def apply(body: ZioStreams.BinaryStream): ZioStreams.BinaryStream = body.via(ZPipeline.gunzip()) +} + +object DeflateZioDecompressor extends Decompressor[ZioStreams.BinaryStream] { + override val encoding: String = Encodings.Deflate + override def apply(body: ZioStreams.BinaryStream): ZioStreams.BinaryStream = + ZStream.scoped[Any](body.peel(ZSink.take[Byte](1))).flatMap { case (chunk, stream) => + val wrapped = chunk.headOption.exists(byte => (byte & 0x0f) == 0x08) + (ZStream.fromChunk(chunk) ++ stream).via(ZPipeline.inflate(noWrap = !wrapped)) + } +} diff --git a/effects/zio1/src/main/scalajvm/sttp/client4/httpclient/zio/HttpClientZioBackend.scala b/effects/zio1/src/main/scalajvm/sttp/client4/httpclient/zio/HttpClientZioBackend.scala index bfa558352d..0b0b2b4084 100644 --- a/effects/zio1/src/main/scalajvm/sttp/client4/httpclient/zio/HttpClientZioBackend.scala +++ b/effects/zio1/src/main/scalajvm/sttp/client4/httpclient/zio/HttpClientZioBackend.scala @@ -4,20 +4,17 @@ import _root_.zio.interop.reactivestreams._ import org.reactivestreams.FlowAdapters import sttp.capabilities.zio.ZioStreams import sttp.client4.httpclient.{HttpClientAsyncBackend, HttpClientBackend} -import sttp.client4.httpclient.HttpClientBackend.EncodingHandler import sttp.client4.impl.zio.{RIOMonadAsyncError, ZioSimpleQueue} import sttp.client4.internal._ import sttp.client4.internal.httpclient.{BodyFromHttpClient, BodyToHttpClient, Sequencer} import sttp.client4.internal.ws.SimpleQueue import sttp.client4.testing.WebSocketStreamBackendStub -import sttp.client4.wrappers.FollowRedirectsBackend import sttp.client4.{wrappers, BackendOptions, GenericRequest, Response, WebSocketStreamBackend} import sttp.monad.MonadError import zio.Chunk.ByteArray import zio._ -import zio.stream.{ZSink, ZStream, ZTransducer} +import zio.stream.ZStream -import java.io.UnsupportedEncodingException import java.net.http.HttpRequest.{BodyPublisher, BodyPublishers} import java.net.http.HttpResponse.BodyHandlers import java.net.http.{HttpClient, HttpRequest, HttpResponse} @@ -26,12 +23,14 @@ import java.util import java.util.concurrent.Flow.Publisher import java.{util => ju} import scala.collection.JavaConverters._ +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Compressor class HttpClientZioBackend private ( client: HttpClient, closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: EncodingHandler[ZioStreams.BinaryStream] + compressionHandlers: CompressionHandlers[ZioStreams, ZioStreams.BinaryStream] ) extends HttpClientAsyncBackend[ Task, ZioStreams, @@ -42,7 +41,7 @@ class HttpClientZioBackend private ( new RIOMonadAsyncError[Any], closeClient, customizeRequest, - customEncodingHandler + compressionHandlers ) with WebSocketStreamBackend[Task, ZioStreams] { self => @@ -59,8 +58,8 @@ class HttpClientZioBackend private ( .toStream() .mapConcatChunk(list => ByteArray(list.asScala.toList.flatMap(_.safeRead()).toArray)) - override protected val bodyToHttpClient: BodyToHttpClient[Task, ZioStreams] = - new BodyToHttpClient[Task, ZioStreams] { + override protected val bodyToHttpClient: BodyToHttpClient[Task, ZioStreams, R] = + new BodyToHttpClient[Task, ZioStreams, R] { override val streams: ZioStreams = ZioStreams override implicit def monad: MonadError[Task] = self.monad override def streamToPublisher(stream: ZStream[Any, Throwable, Byte]): Task[BodyPublisher] = { @@ -70,6 +69,7 @@ class HttpClientZioBackend private ( BodyPublishers.fromPublisher(FlowAdapters.toFlowPublisher(pub)) } } + override def compressors: List[Compressor[R]] = compressionHandlers.compressors } override def send[T](request: GenericRequest[T, R]): Task[Response[T]] = @@ -85,36 +85,29 @@ class HttpClientZioBackend private ( } yield new ZioSimpleQueue(queue, runtime) override protected def createSequencer: Task[Sequencer[Task]] = ZioSequencer.create - - override protected def standardEncoding: (ZStream[Any, Throwable, Byte], String) => ZStream[Any, Throwable, Byte] = { - case (body, "gzip") => body.transduce(ZTransducer.gunzip()) - case (body, "deflate") => - ZStream.managed(body.peel(ZSink.take[Byte](1))).flatMap { case (chunk, stream) => - val wrapped = chunk.headOption.exists(byte => (byte & 0x0f) == 0x08) - (ZStream.fromChunk(chunk) ++ stream).transduce(ZTransducer.inflate(noWrap = !wrapped)) - } - case (_, ce) => ZStream.fail(new UnsupportedEncodingException(s"Unsupported encoding: $ce")) - } } object HttpClientZioBackend { - - type ZioEncodingHandler = EncodingHandler[ZioStreams.BinaryStream] + val DefaultCompressionHandlers: CompressionHandlers[ZioStreams, ZioStreams.BinaryStream] = + CompressionHandlers( + Compressor.default, + List(GZipZioDecompressor, DeflateZioDecompressor) + ) private def apply( client: HttpClient, closeClient: Boolean, customizeRequest: HttpRequest => HttpRequest, - customEncodingHandler: ZioEncodingHandler + compressionHandlers: CompressionHandlers[ZioStreams, ZioStreams.BinaryStream] ): WebSocketStreamBackend[Task, ZioStreams] = wrappers.FollowRedirectsBackend( - new HttpClientZioBackend(client, closeClient, customizeRequest, customEncodingHandler) + new HttpClientZioBackend(client, closeClient, customizeRequest, compressionHandlers) ) def apply( options: BackendOptions = BackendOptions.Default, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: ZioEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[ZioStreams, ZioStreams.BinaryStream] = DefaultCompressionHandlers ): Task[WebSocketStreamBackend[Task, ZioStreams]] = UIO.executor.flatMap(executor => Task.effect( @@ -122,7 +115,7 @@ object HttpClientZioBackend { HttpClientBackend.defaultClient(options, Some(executor.asJava)), closeClient = false, // we don't want to close ZIO's executor customizeRequest, - customEncodingHandler + compressionHandlers ) ) ) @@ -130,23 +123,23 @@ object HttpClientZioBackend { def managed( options: BackendOptions = BackendOptions.Default, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: ZioEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[ZioStreams, ZioStreams.BinaryStream] = DefaultCompressionHandlers ): ZManaged[Any, Throwable, WebSocketStreamBackend[Task, ZioStreams]] = - ZManaged.make(apply(options, customizeRequest, customEncodingHandler))( + ZManaged.make(apply(options, customizeRequest, compressionHandlers))( _.close().ignore ) def layer( options: BackendOptions = BackendOptions.Default, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: ZioEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[ZioStreams, ZioStreams.BinaryStream] = DefaultCompressionHandlers ): ZLayer[Any, Throwable, SttpClient] = ZLayer.fromManaged( (for { backend <- HttpClientZioBackend( options, customizeRequest, - customEncodingHandler + compressionHandlers ) } yield backend).toManaged(_.close().ignore) ) @@ -154,19 +147,19 @@ object HttpClientZioBackend { def usingClient( client: HttpClient, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: ZioEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[ZioStreams, ZioStreams.BinaryStream] = DefaultCompressionHandlers ): WebSocketStreamBackend[Task, ZioStreams] = HttpClientZioBackend( client, closeClient = false, customizeRequest, - customEncodingHandler + compressionHandlers ) def layerUsingClient( client: HttpClient, customizeRequest: HttpRequest => HttpRequest = identity, - customEncodingHandler: ZioEncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[ZioStreams, ZioStreams.BinaryStream] = DefaultCompressionHandlers ): ZLayer[Any, Throwable, SttpClient] = ZLayer.fromManaged( ZManaged @@ -174,7 +167,7 @@ object HttpClientZioBackend { usingClient( client, customizeRequest, - customEncodingHandler + compressionHandlers ) )(_.close().ignore) ) diff --git a/effects/zio1/src/main/scalajvm/sttp/client4/httpclient/zio/zioDecompressors.scala b/effects/zio1/src/main/scalajvm/sttp/client4/httpclient/zio/zioDecompressors.scala new file mode 100644 index 0000000000..cda7d5ec45 --- /dev/null +++ b/effects/zio1/src/main/scalajvm/sttp/client4/httpclient/zio/zioDecompressors.scala @@ -0,0 +1,22 @@ +package sttp.client4.httpclient.zio + +import sttp.client4.compression.Decompressor +import sttp.model.Encodings +import sttp.capabilities.zio.ZioStreams +import zio.stream.ZStream +import zio.stream.ZSink +import zio.stream.ZTransducer + +object GZipZioDecompressor extends Decompressor[ZioStreams.BinaryStream] { + override val encoding: String = Encodings.Gzip + override def apply(body: ZioStreams.BinaryStream): ZioStreams.BinaryStream = body.transduce(ZTransducer.gunzip()) +} + +object DeflateZioDecompressor extends Decompressor[ZioStreams.BinaryStream] { + override val encoding: String = Encodings.Deflate + override def apply(body: ZioStreams.BinaryStream): ZioStreams.BinaryStream = + ZStream.managed(body.peel(ZSink.take[Byte](1))).flatMap { case (chunk, stream) => + val wrapped = chunk.headOption.exists(byte => (byte & 0x0f) == 0x08) + (ZStream.fromChunk(chunk) ++ stream).transduce(ZTransducer.inflate(noWrap = !wrapped)) + } +} diff --git a/finagle-backend/src/main/scala/sttp/client4/finagle/FinagleBackend.scala b/finagle-backend/src/main/scala/sttp/client4/finagle/FinagleBackend.scala index 86325267c2..4025ab2db5 100644 --- a/finagle-backend/src/main/scala/sttp/client4/finagle/FinagleBackend.scala +++ b/finagle-backend/src/main/scala/sttp/client4/finagle/FinagleBackend.scala @@ -24,11 +24,17 @@ import sttp.model.HttpVersion.HTTP_1 import sttp.model._ import sttp.monad.MonadError import sttp.monad.syntax._ +import sttp.client4.compression.Compressor +import sttp.client4.compression.DeflateDefaultCompressor +import sttp.client4.compression.GZipDefaultCompressor import scala.io.Source class FinagleBackend(client: Option[Client] = None) extends Backend[TFuture] { type R = Any with Effect[TFuture] + + private val compressors: List[Compressor[R]] = List(new GZipDefaultCompressor(), new DeflateDefaultCompressor) + override def send[T](request: GenericRequest[T, R]): TFuture[Response[T]] = adjustExceptions(request) { val service = getClient(client, request) @@ -71,12 +77,16 @@ class FinagleBackend(client: Option[Client] = None) extends Backend[TFuture] { case _ => FMethod(m.method) } - private def requestBodyToFinagle(r: GenericRequest[_, Nothing]): http.Request = { + private def requestBodyToFinagle(r: GenericRequest[_, R]): http.Request = { val finagleMethod = methodToFinagle(r.method) val url = r.uri.toString - val headers = headersToMap(r.headers) + val (body, contentLength) = Compressor.compressIfNeeded(r, compressors) + val headers = { + val hh = headersToMap(r.headers) - HeaderNames.ContentLength + contentLength.fold(hh)(cl => hh.updated(HeaderNames.ContentLength, cl.toString)) + } - r.body match { + body match { case FileBody(f, _) => val content: String = Source.fromFile(f.toFile).mkString buildRequest(url, headers, finagleMethod, Some(ByteArray(content.getBytes: _*)), r.httpVersion) diff --git a/http4s-backend/src/main/scala/sttp/client4/http4s/Http4sBackend.scala b/http4s-backend/src/main/scala/sttp/client4/http4s/Http4sBackend.scala index 041057a1ca..0856709b0e 100644 --- a/http4s-backend/src/main/scala/sttp/client4/http4s/Http4sBackend.scala +++ b/http4s-backend/src/main/scala/sttp/client4/http4s/Http4sBackend.scala @@ -1,48 +1,65 @@ package sttp.client4.http4s -import java.io.{InputStream, UnsupportedEncodingException} +import java.io.InputStream import java.nio.charset.Charset -import cats.data.NonEmptyList import cats.effect.{Async, Deferred, Resource} import cats.implicits._ import cats.effect.implicits._ import fs2.io.file.Files import fs2.{Chunk, Stream} -import org.http4s.{ContentCoding, EntityBody, Request => Http4sRequest, Status} +import org.http4s.{EntityBody, Request => Http4sRequest, Status} import org.http4s import org.http4s.blaze.client.BlazeClientBuilder import org.http4s.client.Client import org.http4s.ember.client.EmberClientBuilder import org.typelevel.ci.CIString import sttp.capabilities.fs2.Fs2Streams -import sttp.client4.http4s.Http4sBackend.EncodingHandler -import sttp.client4.httpclient.fs2.Fs2Compression import sttp.client4.impl.cats.CatsMonadAsyncError -import sttp.client4.internal.{throwNestedMultipartNotAllowed, BodyFromResponseAs, IOBufferSize, SttpFile} +import sttp.client4.internal.{BodyFromResponseAs, IOBufferSize, SttpFile} import sttp.model._ import sttp.monad.MonadError import sttp.client4.testing.StreamBackendStub import sttp.client4.ws.{GotAWebSocketException, NotAWebSocketException} import sttp.client4._ import sttp.client4.wrappers.FollowRedirectsBackend +import sttp.client4.compression.Compressor +import sttp.client4.impl.fs2.GZipFs2Compressor +import sttp.client4.impl.fs2.DeflateFs2Compressor +import sttp.client4.compression.CompressionHandlers +import sttp.client4.impl.fs2.GZipFs2Decompressor +import sttp.client4.impl.fs2.DeflateFs2Decompressor +import sttp.client4.compression.Decompressor // needs http4s using cats-effect class Http4sBackend[F[_]: Async]( client: Client[F], customizeRequest: Http4sRequest[F] => Http4sRequest[F], - customEncodingHandler: EncodingHandler[F] + compressionHandlers: CompressionHandlers[Fs2Streams[F], EntityBody[F]] ) extends StreamBackend[F, Fs2Streams[F]] { type R = Fs2Streams[F] with sttp.capabilities.Effect[F] + override def send[T](r: GenericRequest[T, R]): F[Response[T]] = adjustExceptions(r) { - val (entity, extraHeaders) = bodyToHttp4s(r, r.body) + val (body, contentLength) = Compressor.compressIfNeeded(r, compressionHandlers.compressors) + val (entity, extraHeaders) = bodyToHttp4s(body, contentLength) + val headers = + http4s.Headers { + val nonClHeaders = r.headers + .filterNot(_.is(HeaderNames.ContentLength)) + .map(h => http4s.Header.Raw(CIString(h.name), h.value)) + .toList + + val clHeader = contentLength + .map(cl => http4s.Header.Raw(CIString(HeaderNames.ContentLength), cl.toString)) + + nonClHeaders ++ clHeader + } ++ extraHeaders val request = r.httpVersion match { case Some(version) => Http4sRequest( method = methodToHttp4s(r.method), uri = http4s.Uri.unsafeFromString(r.uri.toString), - headers = - http4s.Headers(r.headers.map(h => http4s.Header.Raw(CIString(h.name), h.value)).toList) ++ extraHeaders, + headers = headers, body = entity.body, httpVersion = versionToHttp4s(version) ) @@ -50,8 +67,7 @@ class Http4sBackend[F[_]: Async]( Http4sRequest( method = methodToHttp4s(r.method), uri = http4s.Uri.unsafeFromString(r.uri.toString), - headers = - http4s.Headers(r.headers.map(h => http4s.Header.Raw(CIString(h.name), h.value)).toList) ++ extraHeaders, + headers = headers, body = entity.body ) } @@ -74,7 +90,7 @@ class Http4sBackend[F[_]: Async]( responseMetadata, Left( onFinalizeSignal( - decompressResponseBodyIfNotHead(r.method, response, r.autoDecompressionDisabled), + decompressResponseBodyIfNotHead(r.method, response, r.autoDecompressionEnabled), signalBodyComplete ) ) @@ -138,8 +154,8 @@ class Http4sBackend[F[_]: Async]( } private def bodyToHttp4s[R]( - r: GenericRequest[_, R], - body: GenericRequestBody[R] + body: GenericRequestBody[R], + contentLength: Option[Long] ): (http4s.Entity[F], http4s.Headers) = body match { case NoBody => (http4s.Entity(http4s.EmptyBody: http4s.EntityBody[F]), http4s.Headers.empty) @@ -147,10 +163,7 @@ class Http4sBackend[F[_]: Async]( case b: BasicBodyPart => (basicBodyToHttp4s(b), http4s.Headers.empty) case StreamBody(s) => - val cl = r.headers - .find(_.is(HeaderNames.ContentLength)) - .map(_.value.toLong) - (http4s.Entity(s.asInstanceOf[Stream[F, Byte]], cl), http4s.Headers.empty) + (http4s.Entity(s.asInstanceOf[Stream[F, Byte]], contentLength), http4s.Headers.empty) case m: MultipartBody[_] => val parts = m.parts.toVector.map(multipartToHttp4s) @@ -178,33 +191,19 @@ class Http4sBackend[F[_]: Async]( private def decompressResponseBodyIfNotHead[T]( m: Method, hr: http4s.Response[F], - disableAutoDecompression: Boolean + enableAutoDecompression: Boolean ): http4s.Response[F] = - if (m == Method.HEAD || disableAutoDecompression) hr else decompressResponseBody(hr) + if (m == Method.HEAD || !enableAutoDecompression) hr else decompressResponseBody(hr) private def decompressResponseBody(hr: http4s.Response[F]): http4s.Response[F] = { val body = hr.headers .get[http4s.headers.`Content-Encoding`] .filterNot(_ => hr.status.equals(Status.NoContent)) - .map(e => customEncodingHandler.orElse(EncodingHandler(standardEncodingHandler))(hr.body -> e.contentCoding)) + .map(e => Decompressor.decompressIfPossible(hr.body, e.contentCoding.coding, compressionHandlers.decompressors)) .getOrElse(hr.body) hr.copy(body = body) } - private def standardEncodingHandler: (EntityBody[F], ContentCoding) => EntityBody[F] = { - case (body, contentCoding) - if http4s.headers - .`Accept-Encoding`(NonEmptyList.of(http4s.ContentCoding.deflate)) - .satisfiedBy(contentCoding) => - body.through(Fs2Compression.inflateCheckHeader[F]) - case (body, contentCoding) - if http4s.headers - .`Accept-Encoding`(NonEmptyList.of(http4s.ContentCoding.gzip, http4s.ContentCoding.`x-gzip`)) - .satisfiedBy(contentCoding) => - body.through(fs2.compression.Compression[F].gunzip(4096)).flatMap(_.content) - case (_, contentCoding) => throw new UnsupportedEncodingException(s"Unsupported encoding: ${contentCoding.coding}") - } - private def bodyFromResponseAs(signalBodyComplete: F[Unit]) = new BodyFromResponseAs[F, http4s.Response[F], Nothing, EntityBody[F]] { override protected def withReplayableBody( @@ -268,51 +267,53 @@ class Http4sBackend[F[_]: Async]( } object Http4sBackend { - - type EncodingHandler[F[_]] = PartialFunction[(EntityBody[F], ContentCoding), EntityBody[F]] - - object EncodingHandler { - def apply[F[_]](f: (EntityBody[F], ContentCoding) => EntityBody[F]): EncodingHandler[F] = { case (b, c) => - f(b, c) - } - } + def defaultCompressionHandlers[F[_]: Async]: CompressionHandlers[Fs2Streams[F], Stream[F, Byte]] = + CompressionHandlers( + List(new GZipFs2Compressor[F, Fs2Streams[F]](), new DeflateFs2Compressor[F, Fs2Streams[F]]()), + List(new GZipFs2Decompressor, new DeflateFs2Decompressor) + ) def usingClient[F[_]: Async]( client: Client[F], customizeRequest: Http4sRequest[F] => Http4sRequest[F] = identity[Http4sRequest[F]] _, - customEncodingHandler: EncodingHandler[F] = PartialFunction.empty + compressionHandlers: Async[F] => CompressionHandlers[Fs2Streams[F], EntityBody[F]] = + defaultCompressionHandlers[F](_: Async[F]) ): StreamBackend[F, Fs2Streams[F]] = - FollowRedirectsBackend(new Http4sBackend[F](client, customizeRequest, customEncodingHandler)) + FollowRedirectsBackend(new Http4sBackend[F](client, customizeRequest, compressionHandlers(implicitly))) def usingBlazeClientBuilder[F[_]: Async]( blazeClientBuilder: BlazeClientBuilder[F], customizeRequest: Http4sRequest[F] => Http4sRequest[F] = identity[Http4sRequest[F]] _, - customEncodingHandler: EncodingHandler[F] = PartialFunction.empty + compressionHandlers: Async[F] => CompressionHandlers[Fs2Streams[F], EntityBody[F]] = + defaultCompressionHandlers[F](_: Async[F]) ): Resource[F, StreamBackend[F, Fs2Streams[F]]] = - blazeClientBuilder.resource.map(c => usingClient(c, customizeRequest, customEncodingHandler)) + blazeClientBuilder.resource.map(c => usingClient(c, customizeRequest, compressionHandlers)) def usingDefaultBlazeClientBuilder[F[_]: Async]( customizeRequest: Http4sRequest[F] => Http4sRequest[F] = identity[Http4sRequest[F]] _, - customEncodingHandler: EncodingHandler[F] = PartialFunction.empty + compressionHandlers: Async[F] => CompressionHandlers[Fs2Streams[F], EntityBody[F]] = + defaultCompressionHandlers[F](_: Async[F]) ): Resource[F, StreamBackend[F, Fs2Streams[F]]] = usingBlazeClientBuilder( BlazeClientBuilder[F], customizeRequest, - customEncodingHandler + compressionHandlers ) def usingEmberClientBuilder[F[_]: Async]( emberClientBuilder: EmberClientBuilder[F], customizeRequest: Http4sRequest[F] => Http4sRequest[F] = identity[Http4sRequest[F]] _, - customEncodingHandler: EncodingHandler[F] = PartialFunction.empty + compressionHandlers: Async[F] => CompressionHandlers[Fs2Streams[F], EntityBody[F]] = + defaultCompressionHandlers[F](_: Async[F]) ): Resource[F, StreamBackend[F, Fs2Streams[F]]] = - emberClientBuilder.build.map(c => usingClient(c, customizeRequest, customEncodingHandler)) + emberClientBuilder.build.map(c => usingClient(c, customizeRequest, compressionHandlers)) def usingDefaultEmberClientBuilder[F[_]: Async]( customizeRequest: Http4sRequest[F] => Http4sRequest[F] = identity[Http4sRequest[F]] _, - customEncodingHandler: EncodingHandler[F] = PartialFunction.empty + compressionHandlers: Async[F] => CompressionHandlers[Fs2Streams[F], EntityBody[F]] = + defaultCompressionHandlers[F](_: Async[F]) ): Resource[F, StreamBackend[F, Fs2Streams[F]]] = - usingEmberClientBuilder(EmberClientBuilder.default[F], customizeRequest, customEncodingHandler) + usingEmberClientBuilder(EmberClientBuilder.default[F], customizeRequest, compressionHandlers) /** Create a stub backend for testing, which uses the `F` response wrapper, and supports `Stream[F, Byte]` streaming. * diff --git a/http4s-ce2-backend/src/main/scala/sttp/client4/http4s/Http4sBackend.scala b/http4s-ce2-backend/src/main/scala/sttp/client4/http4s/Http4sBackend.scala index 81cfd1109b..4d29d8bdc2 100644 --- a/http4s-ce2-backend/src/main/scala/sttp/client4/http4s/Http4sBackend.scala +++ b/http4s-ce2-backend/src/main/scala/sttp/client4/http4s/Http4sBackend.scala @@ -1,30 +1,32 @@ package sttp.client4.http4s -import java.io.{InputStream, UnsupportedEncodingException} +import java.io.InputStream import java.nio.charset.Charset -import cats.data.NonEmptyList import cats.effect.concurrent.MVar -import cats.effect.{Blocker, Concurrent, ConcurrentEffect, ContextShift, Resource} +import cats.effect.{Blocker, Concurrent, ConcurrentEffect, ContextShift, Resource, Sync} import cats.implicits._ import cats.effect.implicits._ import fs2.{Chunk, Stream} -import org.http4s.{ContentCoding, EntityBody, Request => Http4sRequest, Status} +import org.http4s.{EntityBody, Request => Http4sRequest, Status} import org.http4s import org.http4s.blaze.client.BlazeClientBuilder import org.http4s.client.Client import org.http4s.headers.`Content-Encoding` import org.typelevel.ci.CIString import sttp.capabilities.fs2.Fs2Streams -import sttp.client4.http4s.Http4sBackend.EncodingHandler import sttp.client4.impl.cats.CatsMonadAsyncError -import sttp.client4.impl.fs2.Fs2Compression -import sttp.client4.internal.{throwNestedMultipartNotAllowed, BodyFromResponseAs, IOBufferSize, SttpFile} +import sttp.client4.internal.{BodyFromResponseAs, IOBufferSize, SttpFile} import sttp.model._ import sttp.monad.MonadError import sttp.client4.testing.StreamBackendStub import sttp.client4.ws.{GotAWebSocketException, NotAWebSocketException} import sttp.client4._ import sttp.client4.wrappers.FollowRedirectsBackend +import sttp.client4.compression.Compressor +import sttp.client4.compression.CompressionHandlers +import sttp.client4.impl.fs2.GZipFs2Decompressor +import sttp.client4.impl.fs2.DeflateFs2Decompressor +import sttp.client4.compression.Decompressor import scala.concurrent.ExecutionContext @@ -32,19 +34,33 @@ class Http4sBackend[F[_]: ConcurrentEffect: ContextShift]( client: Client[F], blocker: Blocker, customizeRequest: Http4sRequest[F] => Http4sRequest[F], - customEncodingHandler: EncodingHandler[F] + compressionHandlers: CompressionHandlers[Fs2Streams[F], EntityBody[F]] ) extends StreamBackend[F, Fs2Streams[F]] { type R = Fs2Streams[F] with sttp.capabilities.Effect[F] + override def send[T](r: GenericRequest[T, R]): F[Response[T]] = adjustExceptions(r) { - val (entity, extraHeaders) = bodyToHttp4s(r, r.body) + val (body, contentLength) = Compressor.compressIfNeeded(r, compressionHandlers.compressors) + val (entity, extraHeaders) = bodyToHttp4s(body, contentLength) + val headers = + http4s.Headers { + val nonClHeaders = r.headers + .filterNot(_.is(HeaderNames.ContentLength)) + .map(h => http4s.Header.Raw(CIString(h.name), h.value)) + .toList + + val clHeader = contentLength + .map(cl => http4s.Header.Raw(CIString(HeaderNames.ContentLength), cl.toString)) + + nonClHeaders ++ clHeader + } ++ extraHeaders + val request = r.httpVersion match { case Some(version) => Http4sRequest( method = methodToHttp4s(r.method), uri = http4s.Uri.unsafeFromString(r.uri.toString), - headers = - http4s.Headers(r.headers.map(h => http4s.Header.Raw(CIString(h.name), h.value)).toList) ++ extraHeaders, + headers = headers, body = entity.body, httpVersion = versionToHttp4s(version) ) @@ -52,8 +68,7 @@ class Http4sBackend[F[_]: ConcurrentEffect: ContextShift]( Http4sRequest( method = methodToHttp4s(r.method), uri = http4s.Uri.unsafeFromString(r.uri.toString), - headers = - http4s.Headers(r.headers.map(h => http4s.Header.Raw(CIString(h.name), h.value)).toList) ++ extraHeaders, + headers = headers, body = entity.body ) } @@ -76,7 +91,7 @@ class Http4sBackend[F[_]: ConcurrentEffect: ContextShift]( responseMetadata, Left( onFinalizeSignal( - decompressResponseBodyIfNotHead(r.method, response, r.autoDecompressionDisabled), + decompressResponseBodyIfNotHead(r.method, response, r.autoDecompressionEnabled), signalBodyComplete ) ) @@ -138,17 +153,17 @@ class Http4sBackend[F[_]: ConcurrentEffect: ContextShift]( http4s.EntityEncoder.fileEncoder(blocker).toEntity(b.toFile) } - private def bodyToHttp4s(r: GenericRequest[_, R], body: GenericRequestBody[R]): (http4s.Entity[F], http4s.Headers) = + private def bodyToHttp4s( + body: GenericRequestBody[R], + contentLength: Option[Long] + ): (http4s.Entity[F], http4s.Headers) = body match { case NoBody => (http4s.Entity(http4s.EmptyBody: http4s.EntityBody[F]), http4s.Headers.empty) case b: BasicBodyPart => (basicBodyToHttp4s(b), http4s.Headers.empty) case StreamBody(s) => - val cl = r.headers - .find(_.is(HeaderNames.ContentLength)) - .map(_.value.toLong) - (http4s.Entity(s.asInstanceOf[Stream[F, Byte]], cl), http4s.Headers.empty) + (http4s.Entity(s.asInstanceOf[Stream[F, Byte]], contentLength), http4s.Headers.empty) case m: MultipartBody[_] => val parts = m.parts.toVector.map(multipartToHttp4s) @@ -176,33 +191,19 @@ class Http4sBackend[F[_]: ConcurrentEffect: ContextShift]( private def decompressResponseBodyIfNotHead[T]( m: Method, hr: http4s.Response[F], - disableAutoDecompression: Boolean + autoDecompressionEnabled: Boolean ): http4s.Response[F] = - if (m == Method.HEAD || disableAutoDecompression) hr else decompressResponseBody(hr) + if (m == Method.HEAD || !autoDecompressionEnabled) hr else decompressResponseBody(hr) private def decompressResponseBody(hr: http4s.Response[F]): http4s.Response[F] = { val body = hr.headers .get[`Content-Encoding`] .filterNot(_ => hr.status.equals(Status.NoContent)) - .map(e => customEncodingHandler.orElse(EncodingHandler(standardEncodingHandler))(hr.body -> e.contentCoding)) + .map(e => Decompressor.decompressIfPossible(hr.body, e.contentCoding.coding, compressionHandlers.decompressors)) .getOrElse(hr.body) hr.copy(body = body) } - private def standardEncodingHandler: (EntityBody[F], ContentCoding) => EntityBody[F] = { - case (body, contentCoding) - if http4s.headers - .`Accept-Encoding`(NonEmptyList.of(http4s.ContentCoding.deflate)) - .satisfiedBy(contentCoding) => - body.through(Fs2Compression.inflateCheckHeader[F]) - case (body, contentCoding) - if http4s.headers - .`Accept-Encoding`(NonEmptyList.of(http4s.ContentCoding.gzip, http4s.ContentCoding.`x-gzip`)) - .satisfiedBy(contentCoding) => - body.through(fs2.compression.gunzip(4096)).flatMap(_.content) - case (_, contentCoding) => throw new UnsupportedEncodingException(s"Unsupported encoding: ${contentCoding.coding}") - } - private def bodyFromResponseAs(signalBodyComplete: F[Unit]) = new BodyFromResponseAs[F, http4s.Response[F], Nothing, EntityBody[F]] { override protected def withReplayableBody( @@ -266,43 +267,42 @@ class Http4sBackend[F[_]: ConcurrentEffect: ContextShift]( } object Http4sBackend { - - type EncodingHandler[F[_]] = PartialFunction[(EntityBody[F], ContentCoding), EntityBody[F]] - object EncodingHandler { - def apply[F[_]](f: (EntityBody[F], ContentCoding) => EntityBody[F]): EncodingHandler[F] = { case (b, c) => - f(b, c) - } - } + def defaultCompressionHandlers[F[_]: Sync]: CompressionHandlers[Fs2Streams[F], Stream[F, Byte]] = + CompressionHandlers( + Compressor.default[Fs2Streams[F]], + List(new GZipFs2Decompressor, new DeflateFs2Decompressor) + ) def usingClient[F[_]: ConcurrentEffect: ContextShift]( client: Client[F], blocker: Blocker, customizeRequest: Http4sRequest[F] => Http4sRequest[F] = identity[Http4sRequest[F]] _, - customEncodingHandler: EncodingHandler[F] = PartialFunction.empty + compressionHandlers: Sync[F] => CompressionHandlers[Fs2Streams[F], EntityBody[F]] = + defaultCompressionHandlers[F](_: Sync[F]) ): StreamBackend[F, Fs2Streams[F]] = - FollowRedirectsBackend( - new Http4sBackend[F](client, blocker, customizeRequest, customEncodingHandler) - ) + FollowRedirectsBackend(new Http4sBackend[F](client, blocker, customizeRequest, compressionHandlers(implicitly))) def usingBlazeClientBuilder[F[_]: ConcurrentEffect: ContextShift]( blazeClientBuilder: BlazeClientBuilder[F], blocker: Blocker, customizeRequest: Http4sRequest[F] => Http4sRequest[F] = identity[Http4sRequest[F]] _, - customEncodingHandler: EncodingHandler[F] = PartialFunction.empty + compressionHandlers: Sync[F] => CompressionHandlers[Fs2Streams[F], EntityBody[F]] = + defaultCompressionHandlers[F](_: Sync[F]) ): Resource[F, StreamBackend[F, Fs2Streams[F]]] = - blazeClientBuilder.resource.map(c => usingClient(c, blocker, customizeRequest, customEncodingHandler)) + blazeClientBuilder.resource.map(c => usingClient(c, blocker, customizeRequest, compressionHandlers)) def usingDefaultBlazeClientBuilder[F[_]: ConcurrentEffect: ContextShift]( blocker: Blocker, clientExecutionContext: ExecutionContext = ExecutionContext.global, customizeRequest: Http4sRequest[F] => Http4sRequest[F] = identity[Http4sRequest[F]] _, - customEncodingHandler: EncodingHandler[F] = PartialFunction.empty + compressionHandlers: Sync[F] => CompressionHandlers[Fs2Streams[F], EntityBody[F]] = + defaultCompressionHandlers[F](_: Sync[F]) ): Resource[F, StreamBackend[F, Fs2Streams[F]]] = usingBlazeClientBuilder( BlazeClientBuilder[F](clientExecutionContext), blocker, customizeRequest, - customEncodingHandler + compressionHandlers ) /** Create a stub backend for testing, which uses the `F` response wrapper, and supports `Stream[F, Byte]` streaming. diff --git a/okhttp-backend/monix/src/main/scala/sttp/client4/okhttp/monix/OkHttpMonixBackend.scala b/okhttp-backend/monix/src/main/scala/sttp/client4/okhttp/monix/OkHttpMonixBackend.scala index 3e54958067..7af1837595 100644 --- a/okhttp-backend/monix/src/main/scala/sttp/client4/okhttp/monix/OkHttpMonixBackend.scala +++ b/okhttp-backend/monix/src/main/scala/sttp/client4/okhttp/monix/OkHttpMonixBackend.scala @@ -15,7 +15,6 @@ import sttp.capabilities.monix.MonixStreams import sttp.client4.impl.monix.{MonixSimpleQueue, MonixWebSockets, TaskMonadAsyncError} import sttp.client4.internal.ws.SimpleQueue import sttp.monad.MonadError -import sttp.client4.okhttp.OkHttpBackend.EncodingHandler import sttp.client4.okhttp.{BodyFromOkHttp, BodyToOkHttp, OkHttpAsyncBackend, OkHttpBackend} import sttp.client4.testing.WebSocketStreamBackendStub import sttp.client4._ @@ -23,11 +22,14 @@ import sttp.client4.wrappers.FollowRedirectsBackend import sttp.ws.{WebSocket, WebSocketFrame} import scala.concurrent.Future +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Compressor +import sttp.client4.compression.Decompressor class OkHttpMonixBackend private ( client: OkHttpClient, closeClient: Boolean, - customEncodingHandler: EncodingHandler, + compressionHandlers: CompressionHandlers[MonixStreams, InputStream], webSocketBufferCapacity: Option[Int] )(implicit s: Scheduler @@ -35,7 +37,7 @@ class OkHttpMonixBackend private ( client, TaskMonadAsyncError, closeClient, - customEncodingHandler + compressionHandlers ) with WebSocketStreamBackend[Task, MonixStreams] { override val streams: MonixStreams = MonixStreams @@ -113,21 +115,24 @@ class OkHttpMonixBackend private ( } object OkHttpMonixBackend { + val DefaultCompressionHandlers: CompressionHandlers[Any, InputStream] = + CompressionHandlers(Compressor.default[Any], Decompressor.defaultInputStream) + private def apply( client: OkHttpClient, closeClient: Boolean, - customEncodingHandler: EncodingHandler, + compressionHandlers: CompressionHandlers[Any, InputStream], webSocketBufferCapacity: Option[Int] )(implicit s: Scheduler ): WebSocketStreamBackend[Task, MonixStreams] = FollowRedirectsBackend( - new OkHttpMonixBackend(client, closeClient, customEncodingHandler, webSocketBufferCapacity)(s) + new OkHttpMonixBackend(client, closeClient, compressionHandlers, webSocketBufferCapacity)(s) ) def apply( options: BackendOptions = BackendOptions.Default, - customEncodingHandler: EncodingHandler = PartialFunction.empty, + compressionHandlers: CompressionHandlers[Any, InputStream] = DefaultCompressionHandlers, webSocketBufferCapacity: Option[Int] = OkHttpBackend.DefaultWebSocketBufferCapacity )(implicit s: Scheduler = Scheduler.global @@ -136,26 +141,26 @@ object OkHttpMonixBackend { OkHttpMonixBackend( OkHttpBackend.defaultClient(DefaultReadTimeout.toMillis, options), closeClient = true, - customEncodingHandler, + compressionHandlers, webSocketBufferCapacity )(s) ) def resource( options: BackendOptions = BackendOptions.Default, - customEncodingHandler: EncodingHandler = PartialFunction.empty, + compressionHandlers: CompressionHandlers[Any, InputStream] = DefaultCompressionHandlers, webSocketBufferCapacity: Option[Int] = OkHttpBackend.DefaultWebSocketBufferCapacity )(implicit s: Scheduler = Scheduler.global ): Resource[Task, WebSocketStreamBackend[Task, MonixStreams]] = - Resource.make(apply(options, customEncodingHandler, webSocketBufferCapacity))(_.close()) + Resource.make(apply(options, compressionHandlers, webSocketBufferCapacity))(_.close()) def usingClient( client: OkHttpClient, - customEncodingHandler: EncodingHandler = PartialFunction.empty, + compressionHandlers: CompressionHandlers[Any, InputStream] = DefaultCompressionHandlers, webSocketBufferCapacity: Option[Int] = OkHttpBackend.DefaultWebSocketBufferCapacity )(implicit s: Scheduler = Scheduler.global): WebSocketStreamBackend[Task, MonixStreams] = - OkHttpMonixBackend(client, closeClient = false, customEncodingHandler, webSocketBufferCapacity)(s) + OkHttpMonixBackend(client, closeClient = false, compressionHandlers, webSocketBufferCapacity)(s) /** Create a stub backend for testing, which uses the [[Task]] response wrapper, and supports `Observable[ByteBuffer]` * streaming. diff --git a/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpAsyncBackend.scala b/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpAsyncBackend.scala index fd2c672deb..39993d9291 100644 --- a/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpAsyncBackend.scala +++ b/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpAsyncBackend.scala @@ -7,16 +7,17 @@ import okhttp3.{Call, Callback, OkHttpClient, Response => OkHttpResponse, WebSoc import sttp.capabilities.Streams import sttp.client4.internal.ws.{SimpleQueue, WebSocketEvent} import sttp.monad.syntax._ -import sttp.client4.okhttp.OkHttpBackend.EncodingHandler import sttp.client4.{ignore, GenericRequest, Response} import sttp.monad.{Canceler, MonadAsyncError} +import sttp.client4.compression.CompressionHandlers +import java.io.InputStream abstract class OkHttpAsyncBackend[F[_], S <: Streams[S], P]( client: OkHttpClient, _monad: MonadAsyncError[F], closeClient: Boolean, - customEncodingHandler: EncodingHandler -) extends OkHttpBackend[F, S, P](client, closeClient, customEncodingHandler) { + compressionHandlers: CompressionHandlers[P, InputStream] +) extends OkHttpBackend[F, S, P](client, closeClient, compressionHandlers) { override protected def sendRegular[T](request: GenericRequest[T, R]): F[Response[T]] = { val nativeRequest = convertRequest(request) diff --git a/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpBackend.scala b/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpBackend.scala index e5f2d07235..14a2a1993a 100644 --- a/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpBackend.scala +++ b/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpBackend.scala @@ -1,8 +1,7 @@ package sttp.client4.okhttp -import java.io.{InputStream, UnsupportedEncodingException} +import java.io.InputStream import java.util.concurrent.TimeUnit -import java.util.zip.{GZIPInputStream, InflaterInputStream} import okhttp3.internal.http.HttpMethod import okhttp3.{ Authenticator, @@ -17,16 +16,18 @@ import sttp.capabilities.{Effect, Streams} import sttp.client4.BackendOptions.Proxy import sttp.client4.SttpClientException.ReadException import sttp.client4.internal.ws.SimpleQueue -import sttp.client4.okhttp.OkHttpBackend.EncodingHandler import sttp.client4._ import sttp.model._ import scala.collection.JavaConverters._ +import sttp.client4.compression.Compressor +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Decompressor abstract class OkHttpBackend[F[_], S <: Streams[S], P]( client: OkHttpClient, closeClient: Boolean, - customEncodingHandler: EncodingHandler + compressionHandlers: CompressionHandlers[P, InputStream] ) extends GenericBackend[F, P] with Backend[F] { @@ -54,7 +55,8 @@ abstract class OkHttpBackend[F[_], S <: Streams[S], P]( val builder = new OkHttpRequest.Builder() .url(request.uri.toString) - val body = bodyToOkHttp(request.body, request.contentType, request.contentLength) + val (maybeCompressedBody, contentLength) = Compressor.compressIfNeeded(request, compressionHandlers.compressors) + val body = bodyToOkHttp(maybeCompressedBody, request.contentType, contentLength) builder.method( request.method.method, body.getOrElse { @@ -64,7 +66,13 @@ abstract class OkHttpBackend[F[_], S <: Streams[S], P]( } ) - request.headers.foreach(header => builder.addHeader(header.name, header.value)) + // the content-length header's value might have changed due to compression + request.headers.foreach(header => + if (!header.is(HeaderNames.ContentLength)) { + val _ = builder.addHeader(header.name, header.value) + } + ) + contentLength.foreach(cl => builder.addHeader(HeaderNames.ContentLength, cl.toString)) builder.build() } @@ -85,14 +93,11 @@ abstract class OkHttpBackend[F[_], S <: Streams[S], P]( if ( method != Method.HEAD && !res .code() - .equals(StatusCode.NoContent.code) && !request.autoDecompressionDisabled + .equals(StatusCode.NoContent.code) && request.autoDecompressionEnabled ) { encoding .filterNot(_.isEmpty) - .map(e => - customEncodingHandler // There is no PartialFunction.fromFunction in scala 2.12 - .orElse(EncodingHandler(standardEncoding))(res.body().byteStream() -> e) - ) + .map(e => Decompressor.decompressIfPossible(res.body().byteStream(), e, compressionHandlers.decompressors)) .getOrElse(res.body().byteStream()) } else { res.body().byteStream() @@ -110,12 +115,6 @@ abstract class OkHttpBackend[F[_], S <: Streams[S], P]( .flatMap(name => res.headers().values(name).asScala.map(Header(name, _))) .toList - private def standardEncoding: (InputStream, String) => InputStream = { - case (body, "gzip") => new GZIPInputStream(body) - case (body, "deflate") => new InflaterInputStream(body) - case (_, ce) => throw new UnsupportedEncodingException(s"Unsupported encoding: $ce") - } - override def close(): F[Unit] = if (closeClient) { monad.eval(client.dispatcher().executorService().shutdown()) @@ -127,11 +126,6 @@ abstract class OkHttpBackend[F[_], S <: Streams[S], P]( object OkHttpBackend { val DefaultWebSocketBufferCapacity: Option[Int] = Some(1024) - type EncodingHandler = PartialFunction[(InputStream, String), InputStream] - - object EncodingHandler { - def apply(f: (InputStream, String) => InputStream): EncodingHandler = { case (i, s) => f(i, s) } - } private class ProxyAuthenticator(auth: BackendOptions.ProxyAuth) extends Authenticator { override def authenticate(route: Route, response: OkHttpResponse): OkHttpRequest = { diff --git a/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpFutureBackend.scala b/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpFutureBackend.scala index 784ccf062b..04d56559c5 100644 --- a/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpFutureBackend.scala +++ b/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpFutureBackend.scala @@ -5,7 +5,6 @@ import okhttp3.{MediaType, OkHttpClient, RequestBody => OkHttpRequestBody} import sttp.capabilities.{Streams, WebSockets} import sttp.client4.internal.NoStreams import sttp.client4.internal.ws.{FutureSimpleQueue, SimpleQueue} -import sttp.client4.okhttp.OkHttpBackend.EncodingHandler import sttp.client4.testing.WebSocketBackendStub import sttp.client4.wrappers.FollowRedirectsBackend import sttp.client4.{wrappers, BackendOptions, DefaultReadTimeout, WebSocketBackend} @@ -13,15 +12,18 @@ import sttp.monad.{FutureMonad, MonadError} import sttp.ws.WebSocket import scala.concurrent.{ExecutionContext, Future} +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Compressor +import sttp.client4.compression.Decompressor class OkHttpFutureBackend private ( client: OkHttpClient, closeClient: Boolean, - customEncodingHandler: EncodingHandler, + compressionHandlers: CompressionHandlers[Any, InputStream], webSocketBufferCapacity: Option[Int] )(implicit ec: ExecutionContext -) extends OkHttpAsyncBackend[Future, Nothing, WebSockets](client, new FutureMonad, closeClient, customEncodingHandler) +) extends OkHttpAsyncBackend[Future, Nothing, WebSockets](client, new FutureMonad, closeClient, compressionHandlers) with WebSocketBackend[Future] { override val streams: Streams[Nothing] = NoStreams @@ -43,36 +45,39 @@ class OkHttpFutureBackend private ( } object OkHttpFutureBackend { + val DefaultCompressionHandlers: CompressionHandlers[Any, InputStream] = + CompressionHandlers(Compressor.default[Any], Decompressor.defaultInputStream) + private def apply( client: OkHttpClient, closeClient: Boolean, - customEncodingHandler: EncodingHandler, + compressionHandlers: CompressionHandlers[Any, InputStream], webSocketBufferCapacity: Option[Int] )(implicit ec: ExecutionContext ): WebSocketBackend[Future] = wrappers.FollowRedirectsBackend( - new OkHttpFutureBackend(client, closeClient, customEncodingHandler, webSocketBufferCapacity) + new OkHttpFutureBackend(client, closeClient, compressionHandlers, webSocketBufferCapacity) ) def apply( options: BackendOptions = BackendOptions.Default, - customEncodingHandler: EncodingHandler = PartialFunction.empty, + compressionHandlers: CompressionHandlers[Any, InputStream] = DefaultCompressionHandlers, webSocketBufferCapacity: Option[Int] = OkHttpBackend.DefaultWebSocketBufferCapacity )(implicit ec: ExecutionContext = ExecutionContext.global): WebSocketBackend[Future] = OkHttpFutureBackend( OkHttpBackend.defaultClient(DefaultReadTimeout.toMillis, options), closeClient = true, - customEncodingHandler, + compressionHandlers, webSocketBufferCapacity ) def usingClient( client: OkHttpClient, - customEncodingHandler: EncodingHandler = PartialFunction.empty, + compressionHandlers: CompressionHandlers[Any, InputStream] = DefaultCompressionHandlers, webSocketBufferCapacity: Option[Int] = OkHttpBackend.DefaultWebSocketBufferCapacity )(implicit ec: ExecutionContext = ExecutionContext.global): WebSocketBackend[Future] = - OkHttpFutureBackend(client, closeClient = false, customEncodingHandler, webSocketBufferCapacity) + OkHttpFutureBackend(client, closeClient = false, compressionHandlers, webSocketBufferCapacity) /** Create a stub backend for testing, which uses the [[Future]] response wrapper, and doesn't support streaming. * diff --git a/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpSyncBackend.scala b/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpSyncBackend.scala index 2b9cc3f60e..4704d8d726 100644 --- a/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpSyncBackend.scala +++ b/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpSyncBackend.scala @@ -4,7 +4,6 @@ import okhttp3.{MediaType, OkHttpClient, RequestBody => OkHttpRequestBody} import sttp.capabilities.{Streams, WebSockets} import sttp.client4.internal.NoStreams import sttp.client4.internal.ws.{SimpleQueue, SyncQueue, WebSocketEvent} -import sttp.client4.okhttp.OkHttpBackend.EncodingHandler import sttp.client4.testing.WebSocketSyncBackendStub import sttp.client4.{ ignore, @@ -24,13 +23,16 @@ import java.util.concurrent.ArrayBlockingQueue import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.duration.Duration import scala.concurrent.{blocking, Await, ExecutionContext, Future} +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Compressor +import sttp.client4.compression.Decompressor class OkHttpSyncBackend private ( client: OkHttpClient, closeClient: Boolean, - customEncodingHandler: EncodingHandler, + compressionHandlers: CompressionHandlers[Any, InputStream], webSocketBufferCapacity: Option[Int] -) extends OkHttpBackend[Identity, Nothing, WebSockets](client, closeClient, customEncodingHandler) +) extends OkHttpBackend[Identity, Nothing, WebSockets](client, closeClient, compressionHandlers) with WebSocketSyncBackend { private implicit val ec: ExecutionContext = ExecutionContext.global override val streams: Streams[Nothing] = NoStreams @@ -98,34 +100,37 @@ class OkHttpSyncBackend private ( } object OkHttpSyncBackend { + val DefaultCompressionHandlers: CompressionHandlers[Any, InputStream] = + CompressionHandlers(Compressor.default[Any], Decompressor.defaultInputStream) + private def apply( client: OkHttpClient, closeClient: Boolean, - customEncodingHandler: EncodingHandler, + compressionHandlers: CompressionHandlers[Any, InputStream], webSocketBufferCapacity: Option[Int] ): WebSocketSyncBackend = wrappers.FollowRedirectsBackend( - new OkHttpSyncBackend(client, closeClient, customEncodingHandler, webSocketBufferCapacity) + new OkHttpSyncBackend(client, closeClient, compressionHandlers, webSocketBufferCapacity) ) def apply( options: BackendOptions = BackendOptions.Default, - customEncodingHandler: EncodingHandler = PartialFunction.empty, + compressionHandlers: CompressionHandlers[Any, InputStream] = DefaultCompressionHandlers, webSocketBufferCapacity: Option[Int] = OkHttpBackend.DefaultWebSocketBufferCapacity ): WebSocketSyncBackend = OkHttpSyncBackend( OkHttpBackend.defaultClient(DefaultReadTimeout.toMillis, options), closeClient = true, - customEncodingHandler, + compressionHandlers, webSocketBufferCapacity ) def usingClient( client: OkHttpClient, - customEncodingHandler: EncodingHandler = PartialFunction.empty, + compressionHandlers: CompressionHandlers[Any, InputStream] = DefaultCompressionHandlers, webSocketBufferCapacity: Option[Int] = OkHttpBackend.DefaultWebSocketBufferCapacity ): WebSocketSyncBackend = - OkHttpSyncBackend(client, closeClient = false, customEncodingHandler, webSocketBufferCapacity) + OkHttpSyncBackend(client, closeClient = false, compressionHandlers, webSocketBufferCapacity) /** Create a stub backend for testing, which uses the [[Identity]] response wrapper, and doesn't support streaming. * diff --git a/pekko-http-backend/src/main/scala/sttp/client4/pekkohttp/BodyToPekko.scala b/pekko-http-backend/src/main/scala/sttp/client4/pekkohttp/BodyToPekko.scala index f320df3577..a157b5a624 100644 --- a/pekko-http-backend/src/main/scala/sttp/client4/pekkohttp/BodyToPekko.scala +++ b/pekko-http-backend/src/main/scala/sttp/client4/pekkohttp/BodyToPekko.scala @@ -13,18 +13,18 @@ import pekko.http.scaladsl.model.{ import pekko.stream.scaladsl.{Source, StreamConverters} import pekko.util.ByteString import sttp.capabilities.pekko.PekkoStreams -import sttp.client4.internal.throwNestedMultipartNotAllowed import sttp.client4._ -import sttp.model.{HeaderNames, Part} +import sttp.model.Part import scala.collection.immutable.Seq import scala.util.{Failure, Success, Try} +import sttp.client4.compression.Compressor private[pekkohttp] object BodyToPekko { def apply[R]( r: GenericRequest[_, R], - body: GenericRequestBody[R], - ar: HttpRequest + ar: HttpRequest, + compressors: List[Compressor[R]] ): Try[HttpRequest] = { def ctWithCharset(ct: ContentType, charset: String) = HttpCharsets @@ -32,7 +32,7 @@ private[pekkohttp] object BodyToPekko { .map(hc => ContentType.apply(ct.mediaType, () => hc)) .getOrElse(ct) - def contentLength = r.headers.find(_.is(HeaderNames.ContentLength)).flatMap(h => Try(h.value.toLong).toOption) + val (body, contentLength) = Compressor.compressIfNeeded(r, compressors) def toBodyPart(mp: Part[BodyPart[_]]): Try[PekkoMultipart.FormData.BodyPart] = { def streamPartEntity(contentType: ContentType, s: PekkoStreams.BinaryStream) = diff --git a/pekko-http-backend/src/main/scala/sttp/client4/pekkohttp/PekkoCompressor.scala b/pekko-http-backend/src/main/scala/sttp/client4/pekkohttp/PekkoCompressor.scala new file mode 100644 index 0000000000..2672d121b4 --- /dev/null +++ b/pekko-http-backend/src/main/scala/sttp/client4/pekkohttp/PekkoCompressor.scala @@ -0,0 +1,32 @@ +package sttp.client4.pekkohttp + +import org.apache.pekko.util.ByteString +import org.apache.pekko.stream.scaladsl.Compression +import sttp.capabilities.pekko.PekkoStreams +import org.apache.pekko.stream.scaladsl.Source +import sttp.client4._ +import sttp.client4.compression.DeflateDefaultCompressor +import sttp.client4.compression.GZipDefaultCompressor +import sttp.client4.compression.Compressor +import org.apache.pekko.stream.scaladsl.StreamConverters +import org.apache.pekko.stream.scaladsl.FileIO + +trait PekkoCompressor extends Compressor[PekkoStreams] { + override abstract def apply[R2 <: PekkoStreams](body: GenericRequestBody[R2]): GenericRequestBody[PekkoStreams] = + body match { + case InputStreamBody(b, _) => StreamBody(PekkoStreams)(compressStream(StreamConverters.fromInputStream(() => b))) + case StreamBody(b) => StreamBody(PekkoStreams)(compressStream(b.asInstanceOf[Source[ByteString, Any]])) + case FileBody(f, _) => StreamBody(PekkoStreams)(compressStream(FileIO.fromPath(f.toPath))) + case _ => super.apply(body) + } + + def compressStream(stream: Source[ByteString, Any]): Source[ByteString, Any] +} + +object GZipPekkoCompressor extends GZipDefaultCompressor[PekkoStreams] with PekkoCompressor { + def compressStream(stream: Source[ByteString, Any]): Source[ByteString, Any] = stream.via(Compression.gzip) +} + +object DeflatePekkoCompressor extends DeflateDefaultCompressor[PekkoStreams] with PekkoCompressor { + def compressStream(stream: Source[ByteString, Any]): Source[ByteString, Any] = stream.via(Compression.deflate) +} diff --git a/pekko-http-backend/src/main/scala/sttp/client4/pekkohttp/PekkoHttpBackend.scala b/pekko-http-backend/src/main/scala/sttp/client4/pekkohttp/PekkoHttpBackend.scala index d3937216ab..7e1c3d3d5b 100644 --- a/pekko-http-backend/src/main/scala/sttp/client4/pekkohttp/PekkoHttpBackend.scala +++ b/pekko-http-backend/src/main/scala/sttp/client4/pekkohttp/PekkoHttpBackend.scala @@ -1,22 +1,19 @@ package sttp.client4.pekkohttp import java.io.UnsupportedEncodingException -import org.apache.pekko -import pekko.{Done, NotUsed} -import pekko.actor.{ActorSystem, CoordinatedShutdown} -import pekko.event.LoggingAdapter -import pekko.http.scaladsl.coding.Coders -import pekko.http.scaladsl.model.headers.{BasicHttpCredentials, HttpEncoding, HttpEncodings} -import pekko.http.scaladsl.model.ws.{InvalidUpgradeResponse, Message, ValidUpgrade, WebSocketRequest} -import pekko.http.scaladsl.model.{StatusCode => _, _} -import pekko.http.scaladsl.settings.ConnectionPoolSettings -import pekko.http.scaladsl.{ClientTransport, Http, HttpsConnectionContext} -import pekko.stream.Materializer -import pekko.stream.scaladsl.{Flow, Sink} +import org.apache.pekko.{Done, NotUsed} +import org.apache.pekko.actor.{ActorSystem, CoordinatedShutdown} +import org.apache.pekko.event.LoggingAdapter +import org.apache.pekko.http.scaladsl.coding.Coders +import org.apache.pekko.http.scaladsl.model.headers.{BasicHttpCredentials, HttpEncoding, HttpEncodings} +import org.apache.pekko.http.scaladsl.model.ws.{InvalidUpgradeResponse, Message, ValidUpgrade, WebSocketRequest} +import org.apache.pekko.http.scaladsl.model.{StatusCode => _, _} +import org.apache.pekko.http.scaladsl.settings.ConnectionPoolSettings +import org.apache.pekko.http.scaladsl.{ClientTransport, Http, HttpsConnectionContext} +import org.apache.pekko.stream.Materializer +import org.apache.pekko.stream.scaladsl.{Flow, Sink} import sttp.capabilities.pekko.PekkoStreams import sttp.capabilities.{Effect, WebSockets} -import sttp.client4 -import sttp.client4.pekkohttp.PekkoHttpBackend.EncodingHandler import sttp.client4.testing.WebSocketStreamBackendStub import sttp.client4._ import sttp.client4.wrappers.FollowRedirectsBackend @@ -24,6 +21,8 @@ import sttp.model.{ResponseMetadata, StatusCode} import sttp.monad.{FutureMonad, MonadError} import scala.concurrent.{ExecutionContext, Future, Promise} +import sttp.client4.compression.CompressionHandlers +import sttp.client4.compression.Decompressor class PekkoHttpBackend private ( actorSystem: ActorSystem, @@ -35,7 +34,7 @@ class PekkoHttpBackend private ( customizeRequest: HttpRequest => HttpRequest, customizeWebsocketRequest: WebSocketRequest => WebSocketRequest, customizeResponse: (HttpRequest, HttpResponse) => HttpResponse, - customEncodingHandler: EncodingHandler + compressionHandlers: CompressionHandlers[PekkoStreams, HttpResponse] ) extends WebSocketStreamBackend[Future, PekkoStreams] { type R = PekkoStreams with WebSockets with Effect[Future] @@ -53,7 +52,7 @@ class PekkoHttpBackend private ( private def sendRegular[T](r: GenericRequest[T, R]): Future[Response[T]] = Future - .fromTry(ToPekko.request(r).flatMap(BodyToPekko(r, r.body, _))) + .fromTry(ToPekko.request(r).flatMap(BodyToPekko(r, _, compressionHandlers.compressors))) .map(customizeRequest) .flatMap(request => http @@ -134,23 +133,21 @@ class PekkoHttpBackend private ( val body = bodyFromPekko( r.response, responseMetadata, - wsFlow.map(Right(_)).getOrElse(Left(decodePekkoResponse(hr, r.autoDecompressionDisabled))) + wsFlow.map(Right(_)).getOrElse(Left(decodePekkoResponse(hr, r.autoDecompressionEnabled))) ) - body.map(client4.Response(_, code, statusText, headers, Nil, r.onlyMetadata)) + body.map(sttp.client4.Response(_, code, statusText, headers, Nil, r.onlyMetadata)) } // http://doc.akka.io/docs/akka-http/10.0.7/scala/http/common/de-coding.html - private def decodePekkoResponse(response: HttpResponse, disableAutoDecompression: Boolean): HttpResponse = - if (!response.status.allowsEntity() || disableAutoDecompression) response - else customEncodingHandler.orElse(EncodingHandler(standardEncoding)).apply(response -> response.encoding) - - private def standardEncoding: (HttpResponse, HttpEncoding) => HttpResponse = { - case (body, HttpEncodings.gzip) => Coders.Gzip.decodeMessage(body) - case (body, HttpEncodings.deflate) => Coders.Deflate.decodeMessage(body) - case (body, HttpEncodings.identity) => Coders.NoCoding.decodeMessage(body) - case (_, ce) => throw new UnsupportedEncodingException(s"Unsupported encoding: $ce") - } + private def decodePekkoResponse(response: HttpResponse, enableAutoDecompression: Boolean): HttpResponse = + if (!response.status.allowsEntity() || !enableAutoDecompression) response + else + response.encoding match { + case HttpEncodings.identity => response + case encoding: HttpEncoding => + Decompressor.decompressIfPossible(response, encoding.value, compressionHandlers.decompressors) + } private def adjustExceptions[T](request: GenericRequest[_, _])(t: => Future[T]): Future[T] = SttpClientException.adjustExceptions(monad)(t)(FromPekko.exception(request, _)) @@ -166,12 +163,11 @@ class PekkoHttpBackend private ( } object PekkoHttpBackend { - type EncodingHandler = PartialFunction[(HttpResponse, HttpEncoding), HttpResponse] - object EncodingHandler { - def apply(f: (HttpResponse, HttpEncoding) => HttpResponse): EncodingHandler = { case (body, encoding) => - f(body, encoding) - } - } + val DefaultCompressionHandlers: CompressionHandlers[PekkoStreams, HttpResponse] = + CompressionHandlers( + List(GZipPekkoCompressor, DeflatePekkoCompressor), + List(GZipPekkoDecompressor, DeflatePekkoDecompressor) + ) private def make( actorSystem: ActorSystem, @@ -183,7 +179,7 @@ object PekkoHttpBackend { customizeRequest: HttpRequest => HttpRequest, customizeWebsocketRequest: WebSocketRequest => WebSocketRequest = identity, customizeResponse: (HttpRequest, HttpResponse) => HttpResponse = (_, r) => r, - customEncodingHandler: EncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[PekkoStreams, HttpResponse] ): WebSocketStreamBackend[Future, PekkoStreams] = FollowRedirectsBackend( new PekkoHttpBackend( @@ -196,7 +192,7 @@ object PekkoHttpBackend { customizeRequest, customizeWebsocketRequest, customizeResponse, - customEncodingHandler + compressionHandlers ) ) @@ -212,7 +208,7 @@ object PekkoHttpBackend { customizeRequest: HttpRequest => HttpRequest = identity, customizeWebsocketRequest: WebSocketRequest => WebSocketRequest = identity, customizeResponse: (HttpRequest, HttpResponse) => HttpResponse = (_, r) => r, - customEncodingHandler: EncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[PekkoStreams, HttpResponse] = DefaultCompressionHandlers )(implicit ec: Option[ExecutionContext] = None ): WebSocketStreamBackend[Future, PekkoStreams] = { @@ -228,7 +224,7 @@ object PekkoHttpBackend { customizeRequest, customizeWebsocketRequest, customizeResponse, - customEncodingHandler + compressionHandlers ) } @@ -247,7 +243,7 @@ object PekkoHttpBackend { customizeRequest: HttpRequest => HttpRequest = identity, customizeWebsocketRequest: WebSocketRequest => WebSocketRequest = identity, customizeResponse: (HttpRequest, HttpResponse) => HttpResponse = (_, r) => r, - customEncodingHandler: EncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[PekkoStreams, HttpResponse] = DefaultCompressionHandlers )(implicit ec: Option[ExecutionContext] = None ): WebSocketStreamBackend[Future, PekkoStreams] = @@ -259,7 +255,7 @@ object PekkoHttpBackend { customizeRequest, customizeWebsocketRequest, customizeResponse, - customEncodingHandler + compressionHandlers ) /** @param actorSystem @@ -276,7 +272,7 @@ object PekkoHttpBackend { customizeRequest: HttpRequest => HttpRequest = identity, customizeWebsocketRequest: WebSocketRequest => WebSocketRequest = identity, customizeResponse: (HttpRequest, HttpResponse) => HttpResponse = (_, r) => r, - customEncodingHandler: EncodingHandler = PartialFunction.empty + compressionHandlers: CompressionHandlers[PekkoStreams, HttpResponse] = DefaultCompressionHandlers )(implicit ec: Option[ExecutionContext] = None ): WebSocketStreamBackend[Future, PekkoStreams] = @@ -290,7 +286,7 @@ object PekkoHttpBackend { customizeRequest, customizeWebsocketRequest, customizeResponse, - customEncodingHandler + compressionHandlers ) /** Create a stub backend for testing, which uses the [[Future]] response wrapper, and doesn't support streaming. diff --git a/pekko-http-backend/src/main/scala/sttp/client4/pekkohttp/pekkoDecompressors.scala b/pekko-http-backend/src/main/scala/sttp/client4/pekkohttp/pekkoDecompressors.scala new file mode 100644 index 0000000000..026cf79127 --- /dev/null +++ b/pekko-http-backend/src/main/scala/sttp/client4/pekkohttp/pekkoDecompressors.scala @@ -0,0 +1,16 @@ +package sttp.client4.pekkohttp + +import sttp.client4.compression.Decompressor +import sttp.model.Encodings +import org.apache.pekko.http.scaladsl.model.HttpResponse +import org.apache.pekko.http.scaladsl.coding.Coders + +object GZipPekkoDecompressor extends Decompressor[HttpResponse] { + override val encoding: String = Encodings.Gzip + override def apply(body: HttpResponse): HttpResponse = Coders.Gzip.decodeMessage(body) +} + +object DeflatePekkoDecompressor extends Decompressor[HttpResponse] { + override val encoding: String = Encodings.Deflate + override def apply(body: HttpResponse): HttpResponse = Coders.Deflate.decodeMessage(body) +}