Skip to content

Commit

Permalink
Simplify schema based header codecs (#3232)
Browse files Browse the repository at this point in the history
  • Loading branch information
987Nabil committed Jan 31, 2025
1 parent 41c09b9 commit 4f6f3d2
Show file tree
Hide file tree
Showing 22 changed files with 1,013 additions and 739 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,9 @@ private[cli] object CliEndpoint {
}
CliEndpoint(body = HttpOptions.Body(name, codec.defaultMediaType, codec.defaultSchema) :: List())

case HttpCodec.Header(headerType, _) =>
CliEndpoint(headers = HttpOptions.Header(headerType.name, TextCodec.string) :: List())
case HttpCodec.HeaderCustom(codec, _) =>
CliEndpoint(headers = HttpOptions.Header(codec.name.get, TextCodec.string) :: List())
case HttpCodec.Method(codec, _) =>
case HttpCodec.Header(headerType, _) =>
CliEndpoint(headers = HttpOptions.Header(headerType.names.head, TextCodec.string) :: List())
case HttpCodec.Method(codec, _) =>
codec.asInstanceOf[SimpleCodec[_, _]] match {
case SimpleCodec.Specified(method: Method) =>
CliEndpoint(methods = method)
Expand All @@ -126,14 +124,9 @@ private[cli] object CliEndpoint {
CliEndpoint(url = HttpOptions.Path(pathCodec) :: List())

case HttpCodec.Query(codec, _) =>
if (codec.isPrimitive)
CliEndpoint(url = HttpOptions.Query(codec) :: List())
else if (codec.isRecord)
CliEndpoint(url = codec.recordFields.map { case (_, codec) =>
HttpOptions.Query(codec)
}.toList)
else
CliEndpoint(url = HttpOptions.Query(codec) :: List())
CliEndpoint(url = codec.recordFields.map { case (f, codec) =>
HttpOptions.Query(codec, f.fieldName)
}.toList)
case HttpCodec.Status(_, _) => CliEndpoint.empty

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import zio.schema.annotation.description
import zio.http._
import zio.http.codec.HttpCodec.SchemaCodec
import zio.http.codec._
import zio.http.internal.StringSchemaCodec
import zio.http.internal.StringSchemaCodec.PrimitiveCodec

/*
* HttpOptions is a wrapper of a transformation Options[CliRequest] => Options[CliRequest].
Expand Down Expand Up @@ -265,11 +267,10 @@ private[cli] object HttpOptions {

}

final case class Query(codec: SchemaCodec[_], doc: Doc = Doc.empty) extends URLOptions {
final case class Query(codec: PrimitiveCodec[_], name: String, doc: Doc = Doc.empty) extends URLOptions {
self =>
override val name = codec.name.get
override val tag = "?" + name
def options: Options[_] = optionsFromSchema(codec)(name)
def options: Options[_] = optionsFromSchema(codec.schema)(name)

override def ??(doc: Doc): Query = self.copy(doc = self.doc + doc)

Expand All @@ -293,8 +294,8 @@ private[cli] object HttpOptions {

}

private[cli] def optionsFromSchema[A](codec: SchemaCodec[A]): String => Options[A] =
codec.schema match {
private[cli] def optionsFromSchema[A](schema: Schema[A]): String => Options[A] =
schema match {
case Schema.Primitive(standardType, _) =>
standardType match {
case StandardType.UnitType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,20 @@ object CommandGen {
case _: HttpOptions.Constant => false
case _ => true
}.map {
case HttpOptions.Path(pathCodec, _) =>
case HttpOptions.Path(pathCodec, _) =>
pathCodec.segments.toList.flatMap { segment =>
getSegment(segment) match {
case (_, "") => Nil
case (name, "boolean") => s"[${getName(name, "")}]" :: Nil
case (name, codec) => s"${getName(name, "")} $codec" :: Nil
}
}
case HttpOptions.Query(codec, _) if codec.isPrimitive =>
case HttpOptions.Query(codec, name, _) =>
getType(codec.schema) match {
case "" => s"[${getName(codec.name.get, "")}]" :: Nil
case tpy => s"${getName(codec.name.get, "")} $tpy" :: Nil
case "" => s"[${getName(name, "")}]" :: Nil
case tpy => s"${getName(name, "")} $tpy" :: Nil
}
case _ => Nil
case _ => Nil
}.foldRight(List[String]())(_ ++ _)

val headersOptions = cliEndpoint.headers.filter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ object EndpointGen {
lazy val anyQuery: Gen[Any, CliReprOf[Codec[_]]] =
Gen.alphaNumericStringBounded(1, 30).zip(anyStandardType).map { case (name, schema0) =>
val schema = schema0.asInstanceOf[Schema[Any]]
val codec = SchemaCodec(Some(name), schema)
val codec = QueryCodec.query(name)(schema).asInstanceOf[HttpCodec.Query[Any]]
CliRepr(
HttpCodec.Query(codec),
CliEndpoint(url = HttpOptions.Query(codec) :: Nil),
codec,
CliEndpoint(url = HttpOptions.Query(codec.codec.recordFields.head._2, name) :: Nil),
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import zio.http.codec.HttpCodec.SchemaCodec
import zio.http.codec._
import zio.http.endpoint.cli.AuxGen._
import zio.http.endpoint.cli.CliRepr._
import zio.http.internal.StringSchemaCodec.PrimitiveCodec

/**
* Constructs a Gen[Options[CliRequest], CliEndpoint]
Expand All @@ -33,10 +34,10 @@ object OptionsGen {
.optionsFromTextCodec(textCodec)(name)
.map(value => textCodec.encode(value))

def encodeOptions[A](name: String, codec: SchemaCodec[A]): Options[String] =
def encodeOptions[A](name: String, codec: PrimitiveCodec[A], schema: Schema[A]): Options[String] =
HttpOptions
.optionsFromSchema(codec)(name)
.map(value => codec.stringCodec.encode(value))
.optionsFromSchema(schema)(name)
.map(value => codec.encode(value))

lazy val anyBodyOption: Gen[Any, CliReprOf[Options[Retriever]]] =
Gen
Expand Down Expand Up @@ -80,10 +81,10 @@ object OptionsGen {
.alphaNumericStringBounded(1, 30)
.zip(anyStandardType)
.map { case (name, schema) =>
val codec = SchemaCodec(Some(name), schema)
val codec = QueryCodec.query(name)(schema).asInstanceOf[HttpCodec.Query[Any]]
CliRepr(
encodeOptions(name, codec),
CliEndpoint(url = HttpOptions.Query(codec) :: Nil),
encodeOptions(name, codec.codec.recordFields.head._2, schema.asInstanceOf[Schema[Any]]),
CliEndpoint(url = HttpOptions.Query(codec.codec.recordFields.head._2, name) :: Nil),
)
},
)
Expand Down
2 changes: 1 addition & 1 deletion zio-http/jvm/src/main/scala/zio/http/netty/NettyBody.scala
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ object NettyBody extends BodyEncoding {
lazy val loop: ZChannel[Any, Any, Any, Any, E, Chunk[A], Unit] =
ZChannel.unwrap(
queue.take
.flatMap(_.done)
.flatMap(_.exit)
.fold(
maybeError =>
ZChannel.fromZIO(queue.shutdown) *>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
package zio.http.netty.model

import scala.collection.AbstractIterator
import scala.jdk.CollectionConverters.CollectionHasAsScala

import zio.Chunk

import zio.http.Server.Config.CompressionOptions
import zio.http._
Expand Down Expand Up @@ -58,10 +61,10 @@ private[netty] object Conversions {

def headersToNetty(headers: Headers): HttpHeaders =
headers match {
case Headers.FromIterable(_) => encodeHeaderListToNetty(headers)
case Headers.Native(value, _, _, _) => value.asInstanceOf[HttpHeaders]
case Headers.Concat(_, _) => encodeHeaderListToNetty(headers)
case Headers.Empty => new DefaultHttpHeaders()
case Headers.FromIterable(_) => encodeHeaderListToNetty(headers)
case Headers.Native(value, _, _, _, _) => value.asInstanceOf[HttpHeaders]
case Headers.Concat(_, _) => encodeHeaderListToNetty(headers)
case Headers.Empty => new DefaultHttpHeaders()
}

def urlToNetty(url: URL): String = {
Expand Down Expand Up @@ -89,6 +92,7 @@ private[netty] object Conversions {
(headers: HttpHeaders) => nettyHeadersIterator(headers),
// NOTE: Netty's headers.get is case-insensitive
(headers: HttpHeaders, key: CharSequence) => headers.get(key),
(headers: HttpHeaders, key: CharSequence) => Chunk.fromJavaIterable(headers.getAll(key)),
(headers: HttpHeaders, key: CharSequence) => headers.contains(key),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package zio.http.endpoint

import java.time.Instant

import scala.math.BigDecimal.javaBigDecimal2bigDecimal

import zio._
import zio.test._

Expand Down
141 changes: 137 additions & 4 deletions zio-http/shared/src/main/scala/zio/http/Header.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@ import scala.util.{Either, Failure, Success, Try}
import zio.Config.Secret
import zio._

import zio.http.codec.RichTextCodec
import zio.http.internal.DateEncoding
import zio.schema.Schema
import zio.schema.codec.DecodeError.ReadError

import zio.http.Header.HeaderTypeBase.Typed
import zio.http.codec.{HttpCodecError, RichTextCodec}
import zio.http.internal.{DateEncoding, ErrorConstructor, StringSchemaCodec}

sealed trait Header {
type Self <: Header
Expand All @@ -50,21 +54,150 @@ sealed trait Header {

object Header {

sealed trait HeaderType {
sealed trait HeaderTypeBase {
type HeaderValue

def names: Chunk[String]

def fromHeaders(headers: Headers): Either[String, HeaderValue]

private[http] def fromHeadersUnsafe(headers: Headers): HeaderValue

def toHeaders(value: HeaderValue): Headers
}

object HeaderTypeBase {
type Typed[HV] = HeaderTypeBase { type HeaderValue = HV }
}

sealed trait SchemaHeaderType extends HeaderTypeBase {
def schema: Schema[HeaderValue]

def optional: HeaderTypeBase.Typed[Option[HeaderValue]]
}

object SchemaHeaderType {
type Typed[H] = SchemaHeaderType { type HeaderValue = H }

def apply[H](implicit schema0: Schema[H]): SchemaHeaderType.Typed[H] = {
new SchemaHeaderType {
type HeaderValue = H
val schema: Schema[H] = schema0
val codec: StringSchemaCodec[H, Headers] = StringSchemaCodec.fromSchema(
schema,
(h: Headers, k: String, v: String) => h.addHeader(k, v),
(h: Headers, kvs: Iterable[(String, String)]) => h.addHeaders(kvs),
(h: Headers, k: String) => h.contains(k),
(h: Headers, k: String) => h.getUnsafe(k),
(h: Headers, k: String) => h.rawHeaders(k),
(h: Headers, k: String) => h.rawHeaders(k).size,
ErrorConstructor(
param => HttpCodecError.MissingHeader(param),
params => HttpCodecError.MissingHeaders(params),
validationErrors => HttpCodecError.InvalidEntity.wrap(validationErrors),
(param, value) => HttpCodecError.DecodingErrorHeader(param, value),
(param, expected, actual) => HttpCodecError.InvalidHeaderCount(param, expected, actual),
),
isKebabCase = true,
null,
)

override def names: Chunk[String] =
codec.recordFields.map(_._1.fieldName)

override def optional: SchemaHeaderType.Typed[Option[H]] =
apply(schema.optional)

override def fromHeaders(headers: Headers): Either[String, H] =
try Right(codec.decode(headers))
catch {
case NonFatal(e) => Left(e.getMessage)
}

private[http] override def fromHeadersUnsafe(headers: Headers): H =
codec.decode(headers)

override def toHeaders(value: H): Headers =
codec.encode(value, Headers.empty)
}
}

def apply[H](name: String)(implicit schema0: Schema[H]): SchemaHeaderType.Typed[H] = {
new SchemaHeaderType {
type HeaderValue = H
val schema: Schema[H] = schema0
val codec: StringSchemaCodec[H, Headers] = StringSchemaCodec.fromSchema(
schema,
(h: Headers, k: String, v: String) => h.addHeader(k, v),
(h: Headers, kvs: Iterable[(String, String)]) => h.addHeaders(kvs),
(h: Headers, k: String) => h.contains(k),
(h: Headers, k: String) => h.getUnsafe(k),
(h: Headers, k: String) => h.rawHeaders(k),
(h: Headers, k: String) => h.rawHeaders(k).size,
ErrorConstructor(
header => HttpCodecError.MissingHeader(header),
headers => HttpCodecError.MissingHeaders(headers),
validationErrors => HttpCodecError.InvalidEntity.wrap(validationErrors),
(header, value) => HttpCodecError.DecodingErrorHeader(header, value),
(header, expected, actual) => HttpCodecError.InvalidHeaderCount(header, expected, actual),
),
isKebabCase = true,
name,
)

override def names: Chunk[String] =
codec.recordFields.map(_._1.fieldName)

override def optional: SchemaHeaderType.Typed[Option[H]] =
apply(name)(schema.optional)

override def fromHeaders(headers: Headers): Either[String, H] =
try Right(codec.decode(headers))
catch {
case NonFatal(e) => Left(e.getMessage)
}

private[http] override def fromHeadersUnsafe(headers: Headers): H =
codec.decode(headers)

override def toHeaders(value: H): Headers =
codec.encode(value, Headers.empty)
}
}
}

sealed trait HeaderType extends HeaderTypeBase {
type HeaderValue <: Header

def names: Chunk[String] = Chunk.single(name)

def name: String

def parse(value: String): Either[String, HeaderValue]

def render(value: HeaderValue): String

def fromHeaders(headers: Headers): Either[String, HeaderValue] =
headers.getUnsafe(name) match {
case null => Left(s"Header $name not found")
case value => parse(value)
}

def fromHeadersUnsafe(headers: Headers): HeaderValue =
fromHeaders(headers).fold(
e => throw HttpCodecError.DecodingErrorHeader(name, ReadError(Cause.empty, e)),
identity,
)

def toHeaders(value: HeaderValue): Headers =
Headers.FromIterable(Iterable(value))

}

object HeaderType {
type Typed[HV] = HeaderType { type HeaderValue = HV }
}

// @deprecated("Use Schema based header codecs instead", "3.1.0")
final case class Custom(customName: CharSequence, value: CharSequence) extends Header {
override type Self = Custom
override def self: Self = this
Expand Down
4 changes: 4 additions & 0 deletions zio-http/shared/src/main/scala/zio/http/Headers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,17 @@ object Headers {
value: T,
iterate: T => Iterator[Header],
unsafeGet: (T, CharSequence) => String,
getAll: (T, CharSequence) => Chunk[String],
contains: (T, CharSequence) => Boolean,
) extends Headers {
override def contains(key: CharSequence): Boolean = contains(value, key)

override def iterator: Iterator[Header] = iterate(value)

override private[http] def getUnsafe(key: CharSequence): String = unsafeGet(value, key)

override def rawHeaders(name: CharSequence): Chunk[String] = getAll(value, name)

}

private[zio] final case class Concat(first: Headers, second: Headers) extends Headers {
Expand Down
Loading

0 comments on commit 4f6f3d2

Please sign in to comment.