diff --git a/core/src/main/scala/sttp/client4/RequestOptions.scala b/core/src/main/scala/sttp/client4/RequestOptions.scala index 32c3805e6c..8ed0cb1f99 100644 --- a/core/src/main/scala/sttp/client4/RequestOptions.scala +++ b/core/src/main/scala/sttp/client4/RequestOptions.scala @@ -1,10 +1,13 @@ package sttp.client4 +import sttp.client4.internal.ContentEncoding + import scala.concurrent.duration.Duration case class RequestOptions( followRedirects: Boolean, readTimeout: Duration, // TODO: Use FiniteDuration while migrating to sttp-4 maxRedirects: Int, - redirectToGet: Boolean + redirectToGet: Boolean, + encoding: List[ContentEncoding] = List.empty ) diff --git a/core/src/main/scala/sttp/client4/SttpClientException.scala b/core/src/main/scala/sttp/client4/SttpClientException.scala index 2290950030..1410074436 100644 --- a/core/src/main/scala/sttp/client4/SttpClientException.scala +++ b/core/src/main/scala/sttp/client4/SttpClientException.scala @@ -28,6 +28,8 @@ object SttpClientException extends SttpClientExceptionExtensions { class TimeoutException(request: GenericRequest[_, _], cause: Exception) extends ReadException(request, cause) + class EncodingException(request: GenericRequest[_, _], cause: Exception) extends SttpClientException(request, cause) + def adjustExceptions[F[_], T]( monadError: MonadError[F] )(t: => F[T])(usingFn: Exception => Option[Exception]): F[T] = diff --git a/core/src/main/scala/sttp/client4/internal/ContentEncoding.scala b/core/src/main/scala/sttp/client4/internal/ContentEncoding.scala new file mode 100644 index 0000000000..8dcc9b64fb --- /dev/null +++ b/core/src/main/scala/sttp/client4/internal/ContentEncoding.scala @@ -0,0 +1,32 @@ +package sttp.client4.internal + +sealed trait ContentEncoding { + def name: String +} + +object ContentEncoding { + + val gzip = Gzip() + val deflate = Deflate() + + case class Gzip() extends ContentEncoding { + override def name: String = "gzip" + } + + case class Compress() extends ContentEncoding { + override def name: String = "compress" + } + + case class Deflate() extends ContentEncoding { + override def name: String = "deflate" + } + + case class Br() extends ContentEncoding { + override def name: String = "br" + } + + case class Zstd() extends ContentEncoding { + override def name: String = "zstd" + } + +} diff --git a/core/src/main/scala/sttp/client4/internal/encoders/ContentCodec.scala b/core/src/main/scala/sttp/client4/internal/encoders/ContentCodec.scala new file mode 100644 index 0000000000..9f3c15d435 --- /dev/null +++ b/core/src/main/scala/sttp/client4/internal/encoders/ContentCodec.scala @@ -0,0 +1,78 @@ +package sttp.client4.internal.encoders + +import sttp.client4.internal.ContentEncoding +import sttp.client4.internal.ContentEncoding.{Deflate, Gzip} +import sttp.client4.internal.encoders.EncoderError.UnsupportedEncoding +import sttp.client4.{BasicBodyPart, ByteArrayBody, ByteBufferBody, FileBody, InputStreamBody, StringBody} +import sttp.model.MediaType + +import scala.annotation.tailrec + +trait ContentCodec[C <: ContentEncoding] { + + type BodyWithLength = (BasicBodyPart, Int) + + def encode(body: BasicBodyPart): Either[EncoderError, BodyWithLength] + + def decode(body: BasicBodyPart): Either[EncoderError, BodyWithLength] + + def encoding: C + +} + +abstract class AbstractContentCodec[C <: ContentEncoding] extends ContentCodec[C] { + + override def encode(body: BasicBodyPart): Either[EncoderError, BodyWithLength] = + body match { + case StringBody(s, encoding, ct) => encode(s.getBytes(encoding), ct) + case ByteArrayBody(b, ct) => encode(b, ct) + case ByteBufferBody(b, ct) => encode(b.array(), ct) + case InputStreamBody(b, ct) => encode(b.readAllBytes(), ct) + case FileBody(f, ct) => encode(f.readAsByteArray, ct) + } + + private def encode(bytes: Array[Byte], ct: MediaType): Either[EncoderError, BodyWithLength] = + encode(bytes).map(r => ByteArrayBody(r, ct) -> r.length) + + override def decode(body: BasicBodyPart): Either[EncoderError, BodyWithLength] = body match { + case StringBody(s, encoding, ct) => decode(s.getBytes(encoding), ct) + case ByteArrayBody(b, ct) => decode(b, ct) + case ByteBufferBody(b, ct) => decode(b.array(), ct) + case InputStreamBody(b, ct) => decode(b.readAllBytes(), ct) + case FileBody(f, ct) => decode(f.readAsByteArray, ct) + } + + private def decode(bytes: Array[Byte], ct: MediaType): Either[EncoderError, BodyWithLength] = + decode(bytes).map(r => ByteArrayBody(r, ct) -> r.length) + + def encode(bytes: Array[Byte]): Either[EncoderError, Array[Byte]] + def decode(bytes: Array[Byte]): Either[EncoderError, Array[Byte]] +} + +object ContentCodec { + + private val gzipCodec = new GzipContentCodec + + private val deflateCodec = new DeflateContentCodec + + def encode(b: BasicBodyPart, codec: List[ContentEncoding]): Either[EncoderError, (BasicBodyPart, Int)] = + foldLeftInEither(codec, b -> 0) { case ((l, _), r) => + r match { + case _: Gzip => gzipCodec.encode(l) + case _: Deflate => deflateCodec.encode(l) + case e => Left(UnsupportedEncoding(e)) + } + } + + @tailrec + private def foldLeftInEither[T, R, E](elems: List[T], zero: R)(f: (R, T) => Either[E, R]): Either[E, R] = + elems match { + case Nil => Right[E, R](zero) + case head :: tail => + f(zero, head) match { + case l: Left[E, R] => l + case Right(v) => foldLeftInEither(tail, v)(f) + } + } + +} diff --git a/core/src/main/scala/sttp/client4/internal/encoders/DeflateContentCodec.scala b/core/src/main/scala/sttp/client4/internal/encoders/DeflateContentCodec.scala new file mode 100644 index 0000000000..d3f26bf65a --- /dev/null +++ b/core/src/main/scala/sttp/client4/internal/encoders/DeflateContentCodec.scala @@ -0,0 +1,36 @@ +package sttp.client4.internal.encoders + +import sttp.client4.internal.ContentEncoding +import sttp.client4.internal.ContentEncoding.Deflate + +import java.io.ByteArrayOutputStream +import java.util.zip.{Deflater, Inflater} +import scala.util.{Try, Using} + +class DeflateContentCodec extends AbstractContentCodec[Deflate] { + + override def encode(bytes: Array[Byte]): Either[EncoderError, Array[Byte]] = + Try { + val deflater: Deflater = new Deflater() + deflater.setInput(bytes) + deflater.finish() + val compressedData = new Array[Byte](bytes.length * 2) + val count: Int = deflater.deflate(compressedData) + compressedData.take(count) + }.toEither.left.map(ex => EncoderError.EncodingFailure(encoding, ex.getMessage)) + + override def decode(bytes: Array[Byte]): Either[EncoderError, Array[Byte]] = + Using(new ByteArrayOutputStream()) { bos => + val buf = new Array[Byte](1024) + val decompresser = new Inflater() + decompresser.setInput(bytes, 0, bytes.length) + while (!decompresser.finished) { + val resultLength = decompresser.inflate(buf) + bos.write(buf, 0, resultLength) + } + decompresser.end() + bos.toByteArray + }.toEither.left.map(ex => EncoderError.EncodingFailure(encoding, ex.getMessage)) + + override def encoding: Deflate = ContentEncoding.deflate +} diff --git a/core/src/main/scala/sttp/client4/internal/encoders/EncoderError.scala b/core/src/main/scala/sttp/client4/internal/encoders/EncoderError.scala new file mode 100644 index 0000000000..36e75d1c40 --- /dev/null +++ b/core/src/main/scala/sttp/client4/internal/encoders/EncoderError.scala @@ -0,0 +1,20 @@ +package sttp.client4.internal.encoders + +import sttp.client4.internal.ContentEncoding + +import scala.util.control.NoStackTrace + +sealed trait EncoderError extends Exception with NoStackTrace { + def reason: String +} + +object EncoderError { + case class UnsupportedEncoding(encoding: ContentEncoding) extends EncoderError { + override def reason: String = s"${encoding.name} is unsupported with this body" + } + + case class EncodingFailure(encoding: ContentEncoding, msg: String) extends EncoderError { + + override def reason: String = s"Can`t encode $encoding for body $msg" + } +} diff --git a/core/src/main/scala/sttp/client4/internal/encoders/GzipContentCodec.scala b/core/src/main/scala/sttp/client4/internal/encoders/GzipContentCodec.scala new file mode 100644 index 0000000000..a2c0913fa6 --- /dev/null +++ b/core/src/main/scala/sttp/client4/internal/encoders/GzipContentCodec.scala @@ -0,0 +1,29 @@ +package sttp.client4.internal.encoders + +import sttp.client4.internal.ContentEncoding +import sttp.client4.internal.ContentEncoding.Gzip +import sttp.client4.internal.encoders.EncoderError.EncodingFailure + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.util.zip.{GZIPInputStream, GZIPOutputStream} +import scala.util.Using + +class GzipContentCodec extends AbstractContentCodec[Gzip] { + + override def encode(bytes: Array[Byte]): Either[EncodingFailure, Array[Byte]] = + Using(new ByteArrayOutputStream) { baos => + Using(new GZIPOutputStream(baos)) { gzos => + gzos.write(bytes) + gzos.finish() + baos.toByteArray + } + }.flatMap(identity).toEither.left.map(ex => EncodingFailure(encoding, ex.getMessage)) + + override def decode(bytes: Array[Byte]): Either[EncodingFailure, Array[Byte]] = + Using(new GZIPInputStream(new ByteArrayInputStream(bytes))) { b => + b.readAllBytes() + }.toEither.left.map(ex => EncodingFailure(encoding, ex.getMessage)) + + override def encoding: Gzip = ContentEncoding.gzip + +} diff --git a/core/src/main/scala/sttp/client4/request.scala b/core/src/main/scala/sttp/client4/request.scala index 1077b56d29..51b72ab449 100644 --- a/core/src/main/scala/sttp/client4/request.scala +++ b/core/src/main/scala/sttp/client4/request.scala @@ -2,6 +2,8 @@ package sttp.client4 import sttp.model.{Header, Method, Part, RequestMetadata, Uri} import sttp.capabilities.{Effect, Streams, WebSockets} +import sttp.client4.SttpClientException.EncodingException +import sttp.client4.internal.encoders.ContentCodec import sttp.client4.internal.{ToCurlConverter, ToRfc2616Converter} import sttp.shared.Identity @@ -144,7 +146,8 @@ case class Request[T]( * Known exceptions are converted by backends to one of [[SttpClientException]]. Other exceptions are thrown * unchanged. */ - def send[F[_]](backend: Backend[F]): F[Response[T]] = backend.send(this) + def send[F[_]](backend: Backend[F]): F[Response[T]] = + backend.send(this) /** Sends the request synchronously, using the given backend. * @@ -156,7 +159,8 @@ case class Request[T]( * Known exceptions are converted by backends to one of [[SttpClientException]]. Other exceptions are thrown * unchanged. */ - def send(backend: SyncBackend): Response[T] = backend.send(this) + def send(backend: SyncBackend): Response[T] = + backend.send(this) } object Request { diff --git a/core/src/main/scala/sttp/client4/requestBuilder.scala b/core/src/main/scala/sttp/client4/requestBuilder.scala index 36d7f0d22d..869941ab81 100644 --- a/core/src/main/scala/sttp/client4/requestBuilder.scala +++ b/core/src/main/scala/sttp/client4/requestBuilder.scala @@ -1,8 +1,6 @@ package sttp.client4 -import sttp.client4.internal.SttpFile -import sttp.client4.internal.Utf8 -import sttp.client4.internal.contentTypeWithCharset +import sttp.client4.internal.{contentTypeWithCharset, ContentEncoding, SttpFile, Utf8} import sttp.client4.logging.LoggingOptions import sttp.client4.wrappers.DigestAuthenticationBackend import sttp.model.HasHeaders @@ -74,6 +72,10 @@ trait PartialRequestBuilder[+PR <: PartialRequestBuilder[PR, R], +R] header(HeaderNames.ContentType, contentTypeWithCharset(ct, encoding)) def contentLength(l: Long): PR = header(HeaderNames.ContentLength, l.toString) + def contentEncoding(encoding: ContentEncoding): PR = + header(HeaderNames.ContentEncoding, encoding.name, DuplicateHeaderBehavior.Add) + .withOptions(options.copy(encoding = options.encoding :+ encoding)) + /** Adds the given header to the headers of this request. If a header with the same name already exists, the default * is to replace it with the given one. * @@ -222,6 +224,8 @@ trait PartialRequestBuilder[+PR <: PartialRequestBuilder[PR, R], +R] */ def body(fs: Seq[(String, String)], encoding: String): PR = formDataBody(fs, encoding) + def body(b: BasicBody): PR = copyWithBody(b) + def multipartBody(ps: Seq[Part[BasicBodyPart]]): PR = copyWithBody(BasicMultipartBody(ps)) def multipartBody(p1: Part[BasicBodyPart], ps: Part[BasicBodyPart]*): PR = copyWithBody( @@ -254,8 +258,8 @@ trait PartialRequestBuilder[+PR <: PartialRequestBuilder[PR, R], +R] def followRedirects(fr: Boolean): PR = withOptions(options.copy(followRedirects = fr)) def maxRedirects(n: Int): PR = - if (n <= 0) withOptions(options.copy(followRedirects = false)) - else withOptions(options.copy(followRedirects = true, maxRedirects = n)) + if (n <= 0) withOptions(options.copy(followRedirects = false)) + else withOptions(options.copy(followRedirects = true, maxRedirects = n)) /** When a POST or PUT request is redirected, should the redirect be a POST/PUT as well (with the original body), or * should the request be converted to a GET without a body. diff --git a/core/src/main/scalajvm/sttp/client4/internal/httpclient/BodyToHttpClient.scala b/core/src/main/scalajvm/sttp/client4/internal/httpclient/BodyToHttpClient.scala index 7bdc750acf..eeca5e6bae 100644 --- a/core/src/main/scalajvm/sttp/client4/internal/httpclient/BodyToHttpClient.scala +++ b/core/src/main/scalajvm/sttp/client4/internal/httpclient/BodyToHttpClient.scala @@ -2,8 +2,9 @@ package sttp.client4.internal.httpclient import sttp.capabilities.Streams import sttp.client4.internal.SttpToJavaConverters.toJavaSupplier -import sttp.client4.internal.{throwNestedMultipartNotAllowed, Utf8} +import sttp.client4.internal.{throwNestedMultipartNotAllowed, ContentEncoding, Utf8} import sttp.client4._ +import sttp.client4.internal.encoders.ContentCodec import sttp.model.{Header, HeaderNames, Part} import sttp.monad.MonadError import sttp.monad.syntax._ @@ -25,26 +26,44 @@ private[client4] trait BodyToHttpClient[F[_], S] { builder: HttpRequest.Builder, contentType: Option[String] ): F[BodyPublisher] = { - val body = request.body match { - case NoBody => BodyPublishers.noBody().unit - case StringBody(b, _, _) => BodyPublishers.ofString(b).unit - case ByteArrayBody(b, _) => BodyPublishers.ofByteArray(b).unit - case ByteBufferBody(b, _) => + val body: F[BodyPublisher] = request.options.encoding -> request.body match { + case (_, NoBody) => BodyPublishers.noBody().unit + case (Nil, StringBody(b, _, _)) => BodyPublishers.ofString(b).unit + case (Nil, ByteArrayBody(b, _)) => BodyPublishers.ofByteArray(b).unit + case (Nil, ByteBufferBody(b, _)) => if (b.hasArray) BodyPublishers.ofByteArray(b.array(), 0, b.limit()).unit else { val a = new Array[Byte](b.remaining()); b.get(a); BodyPublishers.ofByteArray(a).unit } - case InputStreamBody(b, _) => BodyPublishers.ofInputStream(toJavaSupplier(() => b)).unit - case FileBody(f, _) => BodyPublishers.ofFile(f.toFile.toPath).unit - case StreamBody(s) => streamToPublisher(s.asInstanceOf[streams.BinaryStream]) - case m: MultipartBody[_] => + case (Nil, InputStreamBody(b, _)) => BodyPublishers.ofInputStream(toJavaSupplier(() => b)).unit + case (Nil, FileBody(f, _)) => BodyPublishers.ofFile(f.toFile.toPath).unit + case (_, StreamBody(s)) => streamToPublisher(s.asInstanceOf[streams.BinaryStream]) + case (_, m: MultipartBody[_]) => val multipartBodyPublisher = multipartBody(m.parts) val baseContentType = contentType.getOrElse("multipart/form-data") builder.header(HeaderNames.ContentType, s"$baseContentType; boundary=${multipartBodyPublisher.getBoundary}") multipartBodyPublisher.build().unit + + case (coders, r: BasicBodyPart) => + ContentCodec.encode(r, coders) match { + case Left(err) => monad.error(err) + case Right(newBody) => + val (body, length) = newBody + val newRequest = request + .contentLength(length.toLong) + .body(body) + .withOptions(request.options.copy(encoding = Nil)) + apply[T](newRequest, builder, contentType) // can we avoid recursion? + } } (request.contentLength: Option[Long]) match { - case None => body - case Some(cl) => body.map(b => withKnownContentLength(b, cl)) + case None => body + case Some(cl) => + body.map { b => + if (b.contentLength() >= 0) // see BodyPublisher.contentLength docs + withKnownContentLength(b, b.contentLength()) + else + withKnownContentLength(b, cl) + } } } diff --git a/core/src/test/scala/sttp/client4/testing/HttpTest.scala b/core/src/test/scala/sttp/client4/testing/HttpTest.scala index 90c2c62732..f7ea7ebf43 100644 --- a/core/src/test/scala/sttp/client4/testing/HttpTest.scala +++ b/core/src/test/scala/sttp/client4/testing/HttpTest.scala @@ -3,15 +3,18 @@ package sttp.client4.testing import org.scalatest._ import org.scalatest.freespec.AsyncFreeSpec import org.scalatest.matchers.should.Matchers -import sttp.client4.internal.{Iso88591, Utf8} +import sttp.client4.internal.{ContentEncoding, Iso88591, Utf8} import sttp.client4.testing.HttpTest.endpoint import sttp.client4._ +import sttp.client4.internal.encoders.EncoderError.EncodingFailure +import sttp.client4.internal.encoders.{DeflateContentCodec, GzipContentCodec} import sttp.model.StatusCode import sttp.monad.MonadError import sttp.monad.syntax._ import java.io.{ByteArrayInputStream, UnsupportedEncodingException} import java.nio.ByteBuffer +import scala.Right import scala.concurrent.Future import scala.concurrent.duration._ @@ -21,6 +24,7 @@ trait HttpTest[F[_]] with Matchers with ToFutureWrapper with OptionValues + with EitherValues with HttpTestExtensions[F] with AsyncRetries { @@ -420,6 +424,52 @@ trait HttpTest[F[_]] req.send(backend).toFuture().map(resp => resp.code shouldBe StatusCode.Ok) } + "should compress request body gzip" in { + val codec = new GzipContentCodec + val req = basicRequest.contentEncoding(ContentEncoding.gzip) + .response(asByteArrayAlways) + .post(uri"$endpoint/echo/exact") + .body("I`m not compressed") + req.send(backend).toFuture().map{resp => + resp.code shouldBe StatusCode.Ok + val res = codec.decode(resp.body) + res.isRight shouldBe true + res.right.value shouldBe "I`m not compressed".getBytes() + } + } + + "should compress request body deflate" in { + val codec = new DeflateContentCodec + val req = basicRequest.contentEncoding(ContentEncoding.deflate) + .response(asByteArrayAlways) + .post(uri"$endpoint/echo/exact") + .body("I`m not compressed") + req.send(backend).toFuture().map{resp => + resp.code shouldBe StatusCode.Ok + val res = codec.decode(resp.body) + res.isRight shouldBe true + res.right.value shouldBe "I`m not compressed".getBytes() + } + } + + "should compress request body multiple codecs" in { + val codecDeflate = new DeflateContentCodec + val codecGzip = new GzipContentCodec + val req = basicRequest + .contentEncoding(ContentEncoding.gzip) + .contentEncoding(ContentEncoding.deflate) + .response(asByteArrayAlways) + .post(uri"$endpoint/echo/exact") + .body("I`m not compressed") + req.send(backend).toFuture().map{resp => + resp.code shouldBe StatusCode.Ok + val res = codecDeflate.decode(resp.body) + .flatMap(b => codecGzip.decode(b)) + res.isRight shouldBe true + res.right.value shouldBe "I`m not compressed".getBytes() + } + } + if (supportsCustomContentEncoding) { "decompress using custom content encoding" in { val req =