Skip to content

Commit

Permalink
Add request body compression support (#2381)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw authored Dec 31, 2024
1 parent 28a7b50 commit 1a30f3a
Show file tree
Hide file tree
Showing 71 changed files with 1,431 additions and 585 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package sttp.client4.akkahttp

import akka.util.ByteString
import akka.stream.scaladsl.Compression
import sttp.capabilities.akka.AkkaStreams
import akka.stream.scaladsl.Source
import sttp.client4._
import sttp.client4.compression.DeflateDefaultCompressor
import sttp.client4.compression.GZipDefaultCompressor
import sttp.client4.compression.Compressor
import akka.stream.scaladsl.StreamConverters
import akka.stream.scaladsl.FileIO

trait AkkaCompressor extends Compressor[AkkaStreams] {
override abstract def apply[R2 <: AkkaStreams](body: GenericRequestBody[R2]): GenericRequestBody[AkkaStreams] =
body match {
case InputStreamBody(b, _) => StreamBody(AkkaStreams)(compressStream(StreamConverters.fromInputStream(() => b)))
case StreamBody(b) => StreamBody(AkkaStreams)(compressStream(b.asInstanceOf[Source[ByteString, Any]]))
case FileBody(f, _) => StreamBody(AkkaStreams)(compressStream(FileIO.fromPath(f.toPath)))
case _ => super.apply(body)
}

def compressStream(stream: Source[ByteString, Any]): Source[ByteString, Any]
}

object GZipAkkaCompressor extends GZipDefaultCompressor[AkkaStreams] with AkkaCompressor {
def compressStream(stream: Source[ByteString, Any]): Source[ByteString, Any] = stream.via(Compression.gzip)
}

object DeflateAkkaCompressor extends DeflateDefaultCompressor[AkkaStreams] with AkkaCompressor {
def compressStream(stream: Source[ByteString, Any]): Source[ByteString, Any] = stream.via(Compression.deflate)
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package sttp.client4.akkahttp

import java.io.UnsupportedEncodingException
import akka.{Done, NotUsed}
import akka.actor.{ActorSystem, CoordinatedShutdown}
import akka.event.LoggingAdapter
import akka.http.scaladsl.coding.Coders
import akka.http.scaladsl.model.headers.{BasicHttpCredentials, HttpEncoding, HttpEncodings}
import akka.http.scaladsl.model.ws.{InvalidUpgradeResponse, Message, ValidUpgrade, WebSocketRequest}
import akka.http.scaladsl.model.{StatusCode => _, _}
Expand All @@ -14,13 +12,13 @@ import akka.stream.Materializer
import akka.stream.scaladsl.{Flow, Sink}
import sttp.capabilities.akka.AkkaStreams
import sttp.capabilities.{Effect, WebSockets}
import sttp.client4
import sttp.client4.akkahttp.AkkaHttpBackend.EncodingHandler
import sttp.client4.testing.WebSocketStreamBackendStub
import sttp.client4._
import sttp.client4.wrappers.FollowRedirectsBackend
import sttp.model.{ResponseMetadata, StatusCode}
import sttp.monad.{FutureMonad, MonadError}
import sttp.client4.compression.CompressionHandlers
import sttp.client4.compression.Decompressor

import scala.concurrent.{ExecutionContext, Future, Promise}

Expand All @@ -34,7 +32,7 @@ class AkkaHttpBackend private (
customizeRequest: HttpRequest => HttpRequest,
customizeWebsocketRequest: WebSocketRequest => WebSocketRequest,
customizeResponse: (HttpRequest, HttpResponse) => HttpResponse,
customEncodingHandler: EncodingHandler
compressionHandlers: CompressionHandlers[AkkaStreams, HttpResponse]
) extends WebSocketStreamBackend[Future, AkkaStreams] {
type R = AkkaStreams with WebSockets with Effect[Future]

Expand All @@ -52,7 +50,7 @@ class AkkaHttpBackend private (

private def sendRegular[T](r: GenericRequest[T, R]): Future[Response[T]] =
Future
.fromTry(ToAkka.request(r).flatMap(BodyToAkka(r, r.body, _)))
.fromTry(ToAkka.request(r).flatMap(BodyToAkka(r, _, compressionHandlers.compressors)))
.map(customizeRequest)
.flatMap(request =>
http
Expand Down Expand Up @@ -133,23 +131,21 @@ class AkkaHttpBackend private (
val body = bodyFromAkka(
r.response,
responseMetadata,
wsFlow.map(Right(_)).getOrElse(Left(decodeAkkaResponse(hr, r.autoDecompressionDisabled)))
wsFlow.map(Right(_)).getOrElse(Left(decodeAkkaResponse(hr, r.autoDecompressionEnabled)))
)

body.map(client4.Response(_, code, statusText, headers, Nil, r.onlyMetadata))
body.map(sttp.client4.Response(_, code, statusText, headers, Nil, r.onlyMetadata))
}

// http://doc.akka.io/docs/akka-http/10.0.7/scala/http/common/de-coding.html
private def decodeAkkaResponse(response: HttpResponse, disableAutoDecompression: Boolean): HttpResponse =
if (!response.status.allowsEntity() || disableAutoDecompression) response
else customEncodingHandler.orElse(EncodingHandler(standardEncoding)).apply(response -> response.encoding)

private def standardEncoding: (HttpResponse, HttpEncoding) => HttpResponse = {
case (body, HttpEncodings.gzip) => Coders.Gzip.decodeMessage(body)
case (body, HttpEncodings.deflate) => Coders.Deflate.decodeMessage(body)
case (body, HttpEncodings.identity) => Coders.NoCoding.decodeMessage(body)
case (_, ce) => throw new UnsupportedEncodingException(s"Unsupported encoding: $ce")
}
private def decodeAkkaResponse(response: HttpResponse, autoDecompressionEnabled: Boolean): HttpResponse =
if (!response.status.allowsEntity() || !autoDecompressionEnabled) response
else
response.encoding match {
case HttpEncodings.identity => response
case encoding: HttpEncoding =>
Decompressor.decompressIfPossible(response, encoding.value, compressionHandlers.decompressors)
}

private def adjustExceptions[T](request: GenericRequest[_, _])(t: => Future[T]): Future[T] =
SttpClientException.adjustExceptions(monad)(t)(FromAkka.exception(request, _))
Expand All @@ -165,12 +161,11 @@ class AkkaHttpBackend private (
}

object AkkaHttpBackend {
type EncodingHandler = PartialFunction[(HttpResponse, HttpEncoding), HttpResponse]
object EncodingHandler {
def apply(f: (HttpResponse, HttpEncoding) => HttpResponse): EncodingHandler = { case (body, encoding) =>
f(body, encoding)
}
}
val DefaultCompressionHandlers: CompressionHandlers[AkkaStreams, HttpResponse] =
CompressionHandlers(
List(GZipAkkaCompressor, DeflateAkkaCompressor),
List(GZipAkkaDecompressor, DeflateAkkaDecompressor)
)

private def make(
actorSystem: ActorSystem,
Expand All @@ -182,7 +177,7 @@ object AkkaHttpBackend {
customizeRequest: HttpRequest => HttpRequest,
customizeWebsocketRequest: WebSocketRequest => WebSocketRequest = identity,
customizeResponse: (HttpRequest, HttpResponse) => HttpResponse = (_, r) => r,
customEncodingHandler: EncodingHandler = PartialFunction.empty
compressionHandlers: CompressionHandlers[AkkaStreams, HttpResponse]
): WebSocketStreamBackend[Future, AkkaStreams] =
FollowRedirectsBackend(
new AkkaHttpBackend(
Expand All @@ -195,7 +190,7 @@ object AkkaHttpBackend {
customizeRequest,
customizeWebsocketRequest,
customizeResponse,
customEncodingHandler
compressionHandlers
)
)

Expand All @@ -211,7 +206,7 @@ object AkkaHttpBackend {
customizeRequest: HttpRequest => HttpRequest = identity,
customizeWebsocketRequest: WebSocketRequest => WebSocketRequest = identity,
customizeResponse: (HttpRequest, HttpResponse) => HttpResponse = (_, r) => r,
customEncodingHandler: EncodingHandler = PartialFunction.empty
compressionHandlers: CompressionHandlers[AkkaStreams, HttpResponse] = DefaultCompressionHandlers
)(implicit
ec: Option[ExecutionContext] = None
): WebSocketStreamBackend[Future, AkkaStreams] = {
Expand All @@ -227,7 +222,7 @@ object AkkaHttpBackend {
customizeRequest,
customizeWebsocketRequest,
customizeResponse,
customEncodingHandler
compressionHandlers
)
}

Expand All @@ -246,7 +241,7 @@ object AkkaHttpBackend {
customizeRequest: HttpRequest => HttpRequest = identity,
customizeWebsocketRequest: WebSocketRequest => WebSocketRequest = identity,
customizeResponse: (HttpRequest, HttpResponse) => HttpResponse = (_, r) => r,
customEncodingHandler: EncodingHandler = PartialFunction.empty
compressionHandlers: CompressionHandlers[AkkaStreams, HttpResponse] = DefaultCompressionHandlers
)(implicit
ec: Option[ExecutionContext] = None
): WebSocketStreamBackend[Future, AkkaStreams] =
Expand All @@ -258,7 +253,7 @@ object AkkaHttpBackend {
customizeRequest,
customizeWebsocketRequest,
customizeResponse,
customEncodingHandler
compressionHandlers
)

/** @param actorSystem
Expand All @@ -275,7 +270,7 @@ object AkkaHttpBackend {
customizeRequest: HttpRequest => HttpRequest = identity,
customizeWebsocketRequest: WebSocketRequest => WebSocketRequest = identity,
customizeResponse: (HttpRequest, HttpResponse) => HttpResponse = (_, r) => r,
customEncodingHandler: EncodingHandler = PartialFunction.empty
compressionHandlers: CompressionHandlers[AkkaStreams, HttpResponse] = DefaultCompressionHandlers
)(implicit
ec: Option[ExecutionContext] = None
): WebSocketStreamBackend[Future, AkkaStreams] =
Expand All @@ -289,7 +284,7 @@ object AkkaHttpBackend {
customizeRequest,
customizeWebsocketRequest,
customizeResponse,
customEncodingHandler
compressionHandlers
)

/** Create a stub backend for testing, which uses the [[Future]] response wrapper, and doesn't support streaming.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,26 @@ import akka.http.scaladsl.model.{
import akka.stream.scaladsl.{Source, StreamConverters}
import akka.util.ByteString
import sttp.capabilities.akka.AkkaStreams
import sttp.client4.internal.throwNestedMultipartNotAllowed
import sttp.client4._
import sttp.model.{HeaderNames, Part}
import sttp.model.Part

import scala.collection.immutable.Seq
import scala.util.{Failure, Success, Try}
import sttp.client4.compression.Compressor

private[akkahttp] object BodyToAkka {
def apply[R](
r: GenericRequest[_, R],
body: GenericRequestBody[R],
ar: HttpRequest
ar: HttpRequest,
compressors: List[Compressor[R]]
): Try[HttpRequest] = {
def ctWithCharset(ct: ContentType, charset: String) =
HttpCharsets
.getForKey(charset)
.map(hc => ContentType.apply(ct.mediaType, () => hc))
.getOrElse(ct)

def contentLength = r.headers.find(_.is(HeaderNames.ContentLength)).flatMap(h => Try(h.value.toLong).toOption)
val (body, contentLength) = Compressor.compressIfNeeded(r, compressors)

def toBodyPart(mp: Part[BodyPart[_]]): Try[AkkaMultipart.FormData.BodyPart] = {
def streamPartEntity(contentType: ContentType, s: AkkaStreams.BinaryStream) =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package sttp.client4.akkahttp

import sttp.client4.compression.Decompressor
import sttp.model.Encodings
import akka.http.scaladsl.coding.Coders
import akka.http.scaladsl.model.HttpResponse

object GZipAkkaDecompressor extends Decompressor[HttpResponse] {
override val encoding: String = Encodings.Gzip
override def apply(body: HttpResponse): HttpResponse = Coders.Gzip.decodeMessage(body)
}

object DeflateAkkaDecompressor extends Decompressor[HttpResponse] {
override val encoding: String = Encodings.Deflate
override def apply(body: HttpResponse): HttpResponse = Coders.Deflate.decodeMessage(body)
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ import sttp.capabilities.fs2.Fs2Streams
import sttp.client4.armeria.ArmeriaWebClient.newClient
import sttp.client4.armeria.{AbstractArmeriaBackend, BodyFromStreamMessage}
import sttp.client4.impl.cats.CatsMonadAsyncError
import sttp.client4.wrappers.FollowRedirectsBackend
import sttp.client4.{wrappers, BackendOptions, StreamBackend}
import sttp.monad.MonadAsyncError
import sttp.client4.compression.Compressor
import sttp.client4.impl.fs2.DeflateFs2Compressor
import sttp.client4.impl.fs2.GZipFs2Compressor

private final class ArmeriaFs2Backend[F[_]: Async](client: WebClient, closeFactory: Boolean, dispatcher: Dispatcher[F])
extends AbstractArmeriaBackend[F, Fs2Streams[F]](client, closeFactory, new CatsMonadAsyncError) {
Expand All @@ -41,6 +43,9 @@ private final class ArmeriaFs2Backend[F[_]: Async](client: WebClient, closeFacto
},
dispatcher
)

override protected def compressors: List[Compressor[R]] =
List(new GZipFs2Compressor[F, R](), new DeflateFs2Compressor[F, R]())
}

object ArmeriaFs2Backend {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import sttp.client4.internal.toByteArray
import sttp.model._
import sttp.monad.syntax._
import sttp.monad.{Canceler, MonadAsyncError}
import sttp.client4.compression.Compressor

abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]](
client: WebClient = WebClient.of(),
Expand All @@ -54,6 +55,8 @@ abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]](

protected def streamToPublisher(stream: streams.BinaryStream): Publisher[HttpData]

protected def compressors: List[Compressor[R]] = Compressor.default[R]

override def send[T](request: GenericRequest[T, R]): F[Response[T]] =
monad.suspend(adjustExceptions(request)(execute(request)))

Expand Down Expand Up @@ -87,7 +90,7 @@ abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]](
} finally captor.close()
}

private def requestToArmeria(request: GenericRequest[_, Nothing]): WebClientRequestPreparation = {
private def requestToArmeria(request: GenericRequest[_, R]): WebClientRequestPreparation = {
val requestPreparation = client
.prepare()
.disablePathParams()
Expand All @@ -102,19 +105,24 @@ abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]](
requestPreparation.responseTimeoutMillis(Long.MaxValue)
}

val (body, contentLength) = Compressor.compressIfNeeded(request, compressors)

var customContentType: Option[ArmeriaMediaType] = None
request.headers.foreach { header =>
if (header.name.equalsIgnoreCase(HeaderNames.ContentType)) {
if (header.is(HeaderNames.ContentType)) {
// A Content-Type will be set with the body content
customContentType = Some(ArmeriaMediaType.parse(header.value))
} else {
requestPreparation.header(header.name, header.value)
} else if (!header.is(HeaderNames.ContentLength)) {
val _ = requestPreparation.header(header.name, header.value)
}
}
contentLength.foreach { cl =>
requestPreparation.header(HeaderNames.ContentLength, cl.toString)
}

val contentType = customContentType.getOrElse(ArmeriaMediaType.parse(request.body.defaultContentType.toString()))

request.body match {
body match {
case NoBody => requestPreparation
case StringBody(s, encoding, _) =>
val charset =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import sttp.client4.armeria.ArmeriaWebClient.newClient
import sttp.client4.armeria.{AbstractArmeriaBackend, BodyFromStreamMessage}
import sttp.client4.internal.NoStreams
import sttp.client4.wrappers.FollowRedirectsBackend
import sttp.client4.{wrappers, Backend, BackendOptions}
import sttp.client4.{Backend, BackendOptions}
import sttp.monad.{FutureMonad, MonadAsyncError}

import scala.concurrent.ExecutionContext.Implicits.global
Expand Down Expand Up @@ -50,5 +50,5 @@ object ArmeriaFutureBackend {
apply(newClient(), closeFactory = false)

private def apply(client: WebClient, closeFactory: Boolean): Backend[Future] =
wrappers.FollowRedirectsBackend(new ArmeriaFutureBackend(client, closeFactory))
FollowRedirectsBackend(new ArmeriaFutureBackend(client, closeFactory))
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ import sttp.capabilities.zio.ZioStreams
import sttp.client4.armeria.ArmeriaWebClient.newClient
import sttp.client4.armeria.{AbstractArmeriaBackend, BodyFromStreamMessage}
import sttp.client4.impl.zio.RIOMonadAsyncError
import sttp.client4.wrappers.FollowRedirectsBackend
import sttp.client4.{wrappers, BackendOptions, StreamBackend}
import sttp.monad.MonadAsyncError
import zio.stream.Stream
import sttp.client4.compression.Compressor
import sttp.client4.impl.zio.GZipZioCompressor
import sttp.client4.impl.zio.DeflateZioCompressor

private final class ArmeriaZioBackend(runtime: Runtime[Any], client: WebClient, closeFactory: Boolean)
extends AbstractArmeriaBackend[Task, ZioStreams](client, closeFactory, new RIOMonadAsyncError[Any]) {
Expand All @@ -40,6 +42,8 @@ private final class ArmeriaZioBackend(runtime: Runtime[Any], client: WebClient,
.run(stream.mapChunks(c => Chunk.single(HttpData.wrap(c.toArray))).toPublisher)
.getOrThrowFiberFailure()
}

override protected def compressors: List[Compressor[R]] = List(GZipZioCompressor, DeflateZioCompressor)
}

object ArmeriaZioBackend {
Expand Down
4 changes: 3 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ val pekkoStreams = "org.apache.pekko" %% "pekko-stream" % pekkoStreamVersion
val scalaTest = libraryDependencies ++= Seq("freespec", "funsuite", "flatspec", "wordspec", "shouldmatchers").map(m =>
"org.scalatest" %%% s"scalatest-$m" % "3.2.19" % Test
)
val scalaTestPlusScalaCheck = libraryDependencies += "org.scalatestplus" %% "scalacheck-1-18" % "3.2.19.0" % Test

val zio1Version = "1.0.18"
val zio2Version = "2.1.14"
Expand Down Expand Up @@ -318,7 +319,8 @@ lazy val core = (projectMatrix in file("core"))
"com.softwaremill.sttp.shared" %%% "core" % sttpSharedVersion,
"com.softwaremill.sttp.shared" %%% "ws" % sttpSharedVersion
),
scalaTest
scalaTest,
scalaTestPlusScalaCheck
)
.settings(testServerSettings)
.jvmPlatform(
Expand Down
Loading

0 comments on commit 1a30f3a

Please sign in to comment.