Skip to content

Commit

Permalink
WSRequest: Normalize URL
Browse files Browse the repository at this point in the history
  • Loading branch information
htmldoug committed Oct 29, 2018
1 parent 904f5a4 commit 64e02f9
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,25 @@
*/
package play.api.libs.ws.ahc

import javax.inject.Inject
import java.net.URLDecoder
import java.util.Collections

import akka.stream.Materializer
import akka.stream.scaladsl.Source
import akka.util.ByteString
import com.typesafe.sslconfig.ssl.SystemConfiguration
import com.typesafe.sslconfig.ssl.debug.DebugConfiguration
import javax.inject.Inject
import play.api.libs.ws.ahc.cache._
import play.api.libs.ws.{ EmptyBody, StandaloneWSClient, StandaloneWSRequest }
import play.shaded.ahc.org.asynchttpclient.uri.Uri
import play.shaded.ahc.org.asynchttpclient.util.UriEncoder
import play.shaded.ahc.org.asynchttpclient.{ Response => AHCResponse, _ }

import scala.collection.immutable.TreeMap
import scala.compat.java8.FunctionConverters
import scala.concurrent.{ Await, Future, Promise }
import scala.util.control.NonFatal

/**
* A WS client backed by an AsyncHttpClient.
Expand All @@ -39,8 +43,7 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(implici
}

def url(url: String): StandaloneWSRequest = {
validate(url)
StandaloneAhcWSRequest(
val req = StandaloneAhcWSRequest(
client = this,
url = url,
method = "GET",
Expand All @@ -56,6 +59,8 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(implici
proxyServer = None,
disableUrlEncoding = None
)

StandaloneAhcWSClient.normalize(req)
}

private[ahc] def execute(request: Request): Future[StandaloneAhcWSResponse] = {
Expand All @@ -75,18 +80,6 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(implici
result.future
}

private def validate(url: String): Unit = {
// Recover from https://github.com/AsyncHttpClient/async-http-client/issues/1149
try {
Uri.create(url)
} catch {
case iae: IllegalArgumentException =>
throw new IllegalArgumentException(s"Invalid URL $url", iae)
case npe: NullPointerException =>
throw new IllegalArgumentException(s"Invalid URL $url", npe)
}
}

private[ahc] def executeStream(request: Request): Future[StreamedResponse] = {
val promise = Promise[StreamedResponse]()

Expand Down Expand Up @@ -116,12 +109,12 @@ class StandaloneAhcWSClient @Inject() (asyncHttpClient: AsyncHttpClient)(implici

Await.result(result, StandaloneAhcWSClient.blockingTimeout)
}

}

object StandaloneAhcWSClient {

import scala.concurrent.duration._

val blockingTimeout = 50.milliseconds
val elementLimit = 13 // 13 8192k blocks is roughly 100k
private val logger = org.slf4j.LoggerFactory.getLogger(this.getClass)
Expand Down Expand Up @@ -163,5 +156,104 @@ object StandaloneAhcWSClient {
new SystemConfiguration(loggerFactory).configure(config.wsClientConfig.ssl)
wsClient
}

/**
* Ensures:
* 1. [[StandaloneWSRequest.url]] path is encoded.
* 2. Any query params present in the URL are moved to [[StandaloneWSRequest.queryString]].
*/
@throws[IllegalArgumentException]("if the url is unrepairable")
private[ahc] def normalize(req: StandaloneWSRequest): StandaloneWSRequest = {
try {
// Recover from https://github.com/AsyncHttpClient/async-http-client/issues/1149
Uri.create(req.url)
if (req.uri.getQuery == null) {
// happy path
req
} else {
// valid, but move query params into the Map
repair(req)
}
} catch {
case NonFatal(_) =>
// URI parsing error
repair(req)
}
}

@throws[IllegalArgumentException]("if the url is unrepairable")
private def repair(req: StandaloneWSRequest): StandaloneWSRequest = {
try {
val encodedAhcUri: Uri = toUri(req)
setUri(req, encodedAhcUri)
} catch {
case NonFatal(t) =>
throw new IllegalArgumentException(s"Invalid URL ${req.url}", t)
}
}

/**
* Builds an AHC [[Uri]] with all parts URL encoded.
* Combines both [[StandaloneWSRequest.url]] and [[StandaloneWSRequest.queryString]].
*/
private def toUri(req: StandaloneWSRequest): Uri = {
val combinedUri: Uri = {
val uri = Uri.create(req.url)

val params = req.queryString
if (params.nonEmpty) {
appendParamsToUri(uri, params)
} else {
uri
}
}

// FIXING.encode() encodes ONLY unencoded parts, leaving encoded parts untouched.
UriEncoder.FIXING.encode(combinedUri, Collections.emptyList())
}

/**
* Replace the [[StandaloneWSRequest.url]] and [[StandaloneWSRequest.queryString]]
* with the values of [[uri]], discarding originals.
*/
private def setUri(req: StandaloneWSRequest, uri: Uri): StandaloneWSRequest = {
val urlNoQueryParams = uri.withNewQuery(null).toUrl
uri.getHost

val queryParams: List[(String, String)] = for {
queryString <- Option(uri.getQuery).toList
// https://stackoverflow.com/a/13592567 for all of this.
pair <- queryString.split('&')
idx = pair.indexOf('=')
key = if (idx > 0) pair.substring(0, idx) else pair
value = if (idx > 0) URLDecoder.decode(pair.substring(idx + 1)) else ""
} yield key -> value

req.withUrl(urlNoQueryParams)
.withQueryStringParameters(queryParams: _*)
}

private def appendParamsToUri(uri: Uri, params: Map[String, Seq[String]]): Uri = {
val sb = new StringBuilder
// Reminder: ahc.Uri does not start with '?' (unlike java.net.URI)
if (uri.getQuery != null) {
sb.append(uri.getQuery)
}

for {
(key, values) <- params
value <- values
} {
if (sb.length > 0) {
sb.append('&')
}
sb.append(key)
if (value.nonEmpty) {
sb.append('=').append(value)
}
}

uri.withNewQuery(sb.toString)
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import java.nio.charset.{ Charset, StandardCharsets }

import akka.stream.Materializer
import akka.stream.scaladsl.Sink
import play.api.libs.ws.{ StandaloneWSRequest, _ }
import play.api.libs.ws._
import play.shaded.ahc.io.netty.buffer.Unpooled
import play.shaded.ahc.io.netty.handler.codec.http.HttpHeaders
import play.shaded.ahc.org.asynchttpclient.Realm.AuthScheme
Expand Down Expand Up @@ -184,7 +184,10 @@ case class StandaloneAhcWSRequest(
withMethod(method).execute()
}

override def withUrl(url: String): Self = copy(url = url)
override def withUrl(url: String): Self = {
val unsafe = copy(url = url)
StandaloneAhcWSClient.normalize(unsafe)
}

override def withMethod(method: String): Self = copy(method = method)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,24 @@ class AhcWSRequestSpec extends Specification with Mockito with AfterAll with Def

}

"with unencoded values" in {
withClient { client =>
val request = client.url("http://www.example.com/|?!")
.addQueryStringParameters("#" -> "$")
.addQueryStringParameters("^" -> "*", "^" -> "(")

val uri = request.uri
uri.getPath === "/|"
uri.getQuery.split('&').toSeq must contain(exactly("!=", "#=$", "^=*", "^=("))

request.url === "http://www.example.com/%7C"
request.queryString must contain("!" -> Seq(""))
request.queryString must contain("#" -> Seq("$"))
request.queryString.get("^") must beSome.which(_ must contain(exactly("*", "(")))
}

}

}

"For Cookies" in {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ trait StandaloneWSClient extends Closeable {
* @param url The base URL to make HTTP requests to.
* @return a request
*/
@throws[IllegalArgumentException]
@throws[IllegalArgumentException]("if the URL is invalid")
def url(url: String): StandaloneWSRequest

/**
Expand Down

0 comments on commit 64e02f9

Please sign in to comment.