diff --git a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala index 3228b85c3..298581198 100644 --- a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala +++ b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/CliEndpoint.scala @@ -111,11 +111,9 @@ private[cli] object CliEndpoint { } CliEndpoint(body = HttpOptions.Body(name, codec.defaultMediaType, codec.defaultSchema) :: List()) - case HttpCodec.Header(headerType, _) => - CliEndpoint(headers = HttpOptions.Header(headerType.name, TextCodec.string) :: List()) - case HttpCodec.HeaderCustom(codec, _) => - CliEndpoint(headers = HttpOptions.Header(codec.name.get, TextCodec.string) :: List()) - case HttpCodec.Method(codec, _) => + case HttpCodec.Header(headerType, _) => + CliEndpoint(headers = HttpOptions.Header(headerType.names.head, TextCodec.string) :: List()) + case HttpCodec.Method(codec, _) => codec.asInstanceOf[SimpleCodec[_, _]] match { case SimpleCodec.Specified(method: Method) => CliEndpoint(methods = method) @@ -126,14 +124,9 @@ private[cli] object CliEndpoint { CliEndpoint(url = HttpOptions.Path(pathCodec) :: List()) case HttpCodec.Query(codec, _) => - if (codec.isPrimitive) - CliEndpoint(url = HttpOptions.Query(codec) :: List()) - else if (codec.isRecord) - CliEndpoint(url = codec.recordFields.map { case (_, codec) => - HttpOptions.Query(codec) - }.toList) - else - CliEndpoint(url = HttpOptions.Query(codec) :: List()) + CliEndpoint(url = codec.recordFields.map { case (f, codec) => + HttpOptions.Query(codec, f.fieldName) + }.toList) case HttpCodec.Status(_, _) => CliEndpoint.empty } diff --git a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala index 2abb5704b..9d20c8154 100644 --- a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala +++ b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala @@ -13,6 +13,8 @@ import zio.schema.annotation.description import zio.http._ import zio.http.codec.HttpCodec.SchemaCodec import zio.http.codec._ +import zio.http.internal.StringSchemaCodec +import zio.http.internal.StringSchemaCodec.PrimitiveCodec /* * HttpOptions is a wrapper of a transformation Options[CliRequest] => Options[CliRequest]. @@ -265,11 +267,10 @@ private[cli] object HttpOptions { } - final case class Query(codec: SchemaCodec[_], doc: Doc = Doc.empty) extends URLOptions { + final case class Query(codec: PrimitiveCodec[_], name: String, doc: Doc = Doc.empty) extends URLOptions { self => - override val name = codec.name.get override val tag = "?" + name - def options: Options[_] = optionsFromSchema(codec)(name) + def options: Options[_] = optionsFromSchema(codec.schema)(name) override def ??(doc: Doc): Query = self.copy(doc = self.doc + doc) @@ -293,8 +294,8 @@ private[cli] object HttpOptions { } - private[cli] def optionsFromSchema[A](codec: SchemaCodec[A]): String => Options[A] = - codec.schema match { + private[cli] def optionsFromSchema[A](schema: Schema[A]): String => Options[A] = + schema match { case Schema.Primitive(standardType, _) => standardType match { case StandardType.UnitType => diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala index 9d84e04b9..c724eaa21 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala @@ -47,7 +47,7 @@ object CommandGen { case _: HttpOptions.Constant => false case _ => true }.map { - case HttpOptions.Path(pathCodec, _) => + case HttpOptions.Path(pathCodec, _) => pathCodec.segments.toList.flatMap { segment => getSegment(segment) match { case (_, "") => Nil @@ -55,12 +55,12 @@ object CommandGen { case (name, codec) => s"${getName(name, "")} $codec" :: Nil } } - case HttpOptions.Query(codec, _) if codec.isPrimitive => + case HttpOptions.Query(codec, name, _) => getType(codec.schema) match { - case "" => s"[${getName(codec.name.get, "")}]" :: Nil - case tpy => s"${getName(codec.name.get, "")} $tpy" :: Nil + case "" => s"[${getName(name, "")}]" :: Nil + case tpy => s"${getName(name, "")} $tpy" :: Nil } - case _ => Nil + case _ => Nil }.foldRight(List[String]())(_ ++ _) val headersOptions = cliEndpoint.headers.filter { diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala index 792cbdb2f..92b596f9f 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/EndpointGen.scala @@ -103,10 +103,10 @@ object EndpointGen { lazy val anyQuery: Gen[Any, CliReprOf[Codec[_]]] = Gen.alphaNumericStringBounded(1, 30).zip(anyStandardType).map { case (name, schema0) => val schema = schema0.asInstanceOf[Schema[Any]] - val codec = SchemaCodec(Some(name), schema) + val codec = QueryCodec.query(name)(schema).asInstanceOf[HttpCodec.Query[Any]] CliRepr( - HttpCodec.Query(codec), - CliEndpoint(url = HttpOptions.Query(codec) :: Nil), + codec, + CliEndpoint(url = HttpOptions.Query(codec.codec.recordFields.head._2, name) :: Nil), ) } diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala index 1cb6016f4..64951c253 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/OptionsGen.scala @@ -11,6 +11,7 @@ import zio.http.codec.HttpCodec.SchemaCodec import zio.http.codec._ import zio.http.endpoint.cli.AuxGen._ import zio.http.endpoint.cli.CliRepr._ +import zio.http.internal.StringSchemaCodec.PrimitiveCodec /** * Constructs a Gen[Options[CliRequest], CliEndpoint] @@ -33,10 +34,10 @@ object OptionsGen { .optionsFromTextCodec(textCodec)(name) .map(value => textCodec.encode(value)) - def encodeOptions[A](name: String, codec: SchemaCodec[A]): Options[String] = + def encodeOptions[A](name: String, codec: PrimitiveCodec[A], schema: Schema[A]): Options[String] = HttpOptions - .optionsFromSchema(codec)(name) - .map(value => codec.stringCodec.encode(value)) + .optionsFromSchema(schema)(name) + .map(value => codec.encode(value)) lazy val anyBodyOption: Gen[Any, CliReprOf[Options[Retriever]]] = Gen @@ -80,10 +81,10 @@ object OptionsGen { .alphaNumericStringBounded(1, 30) .zip(anyStandardType) .map { case (name, schema) => - val codec = SchemaCodec(Some(name), schema) + val codec = QueryCodec.query(name)(schema).asInstanceOf[HttpCodec.Query[Any]] CliRepr( - encodeOptions(name, codec), - CliEndpoint(url = HttpOptions.Query(codec) :: Nil), + encodeOptions(name, codec.codec.recordFields.head._2, schema.asInstanceOf[Schema[Any]]), + CliEndpoint(url = HttpOptions.Query(codec.codec.recordFields.head._2, name) :: Nil), ) }, ) diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/NettyBody.scala b/zio-http/jvm/src/main/scala/zio/http/netty/NettyBody.scala index c5ba490ef..e66416e71 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/NettyBody.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/NettyBody.scala @@ -161,7 +161,7 @@ object NettyBody extends BodyEncoding { lazy val loop: ZChannel[Any, Any, Any, Any, E, Chunk[A], Unit] = ZChannel.unwrap( queue.take - .flatMap(_.done) + .flatMap(_.exit) .fold( maybeError => ZChannel.fromZIO(queue.shutdown) *> diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/model/Conversions.scala b/zio-http/jvm/src/main/scala/zio/http/netty/model/Conversions.scala index 374c9ba27..2c92326f4 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/model/Conversions.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/model/Conversions.scala @@ -17,6 +17,9 @@ package zio.http.netty.model import scala.collection.AbstractIterator +import scala.jdk.CollectionConverters.CollectionHasAsScala + +import zio.Chunk import zio.http.Server.Config.CompressionOptions import zio.http._ @@ -58,10 +61,10 @@ private[netty] object Conversions { def headersToNetty(headers: Headers): HttpHeaders = headers match { - case Headers.FromIterable(_) => encodeHeaderListToNetty(headers) - case Headers.Native(value, _, _, _) => value.asInstanceOf[HttpHeaders] - case Headers.Concat(_, _) => encodeHeaderListToNetty(headers) - case Headers.Empty => new DefaultHttpHeaders() + case Headers.FromIterable(_) => encodeHeaderListToNetty(headers) + case Headers.Native(value, _, _, _, _) => value.asInstanceOf[HttpHeaders] + case Headers.Concat(_, _) => encodeHeaderListToNetty(headers) + case Headers.Empty => new DefaultHttpHeaders() } def urlToNetty(url: URL): String = { @@ -89,6 +92,7 @@ private[netty] object Conversions { (headers: HttpHeaders) => nettyHeadersIterator(headers), // NOTE: Netty's headers.get is case-insensitive (headers: HttpHeaders, key: CharSequence) => headers.get(key), + (headers: HttpHeaders, key: CharSequence) => Chunk.fromJavaIterable(headers.getAll(key)), (headers: HttpHeaders, key: CharSequence) => headers.contains(key), ) diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/EndpointSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/EndpointSpec.scala index a2c6334e6..3a093669b 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/EndpointSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/EndpointSpec.scala @@ -18,6 +18,8 @@ package zio.http.endpoint import java.time.Instant +import scala.math.BigDecimal.javaBigDecimal2bigDecimal + import zio._ import zio.test._ diff --git a/zio-http/shared/src/main/scala/zio/http/Header.scala b/zio-http/shared/src/main/scala/zio/http/Header.scala index f5515e142..6a59192f0 100644 --- a/zio-http/shared/src/main/scala/zio/http/Header.scala +++ b/zio-http/shared/src/main/scala/zio/http/Header.scala @@ -31,8 +31,12 @@ import scala.util.{Either, Failure, Success, Try} import zio.Config.Secret import zio._ -import zio.http.codec.RichTextCodec -import zio.http.internal.DateEncoding +import zio.schema.Schema +import zio.schema.codec.DecodeError.ReadError + +import zio.http.Header.HeaderTypeBase.Typed +import zio.http.codec.{HttpCodecError, RichTextCodec} +import zio.http.internal.{DateEncoding, ErrorConstructor, StringSchemaCodec} sealed trait Header { type Self <: Header @@ -50,21 +54,150 @@ sealed trait Header { object Header { - sealed trait HeaderType { + sealed trait HeaderTypeBase { + type HeaderValue + + def names: Chunk[String] + + def fromHeaders(headers: Headers): Either[String, HeaderValue] + + private[http] def fromHeadersUnsafe(headers: Headers): HeaderValue + + def toHeaders(value: HeaderValue): Headers + } + + object HeaderTypeBase { + type Typed[HV] = HeaderTypeBase { type HeaderValue = HV } + } + + sealed trait SchemaHeaderType extends HeaderTypeBase { + def schema: Schema[HeaderValue] + + def optional: HeaderTypeBase.Typed[Option[HeaderValue]] + } + + object SchemaHeaderType { + type Typed[H] = SchemaHeaderType { type HeaderValue = H } + + def apply[H](implicit schema0: Schema[H]): SchemaHeaderType.Typed[H] = { + new SchemaHeaderType { + type HeaderValue = H + val schema: Schema[H] = schema0 + val codec: StringSchemaCodec[H, Headers] = StringSchemaCodec.fromSchema( + schema, + (h: Headers, k: String, v: String) => h.addHeader(k, v), + (h: Headers, kvs: Iterable[(String, String)]) => h.addHeaders(kvs), + (h: Headers, k: String) => h.contains(k), + (h: Headers, k: String) => h.getUnsafe(k), + (h: Headers, k: String) => h.rawHeaders(k), + (h: Headers, k: String) => h.rawHeaders(k).size, + ErrorConstructor( + param => HttpCodecError.MissingHeader(param), + params => HttpCodecError.MissingHeaders(params), + validationErrors => HttpCodecError.InvalidEntity.wrap(validationErrors), + (param, value) => HttpCodecError.DecodingErrorHeader(param, value), + (param, expected, actual) => HttpCodecError.InvalidHeaderCount(param, expected, actual), + ), + isKebabCase = true, + null, + ) + + override def names: Chunk[String] = + codec.recordFields.map(_._1.fieldName) + + override def optional: SchemaHeaderType.Typed[Option[H]] = + apply(schema.optional) + + override def fromHeaders(headers: Headers): Either[String, H] = + try Right(codec.decode(headers)) + catch { + case NonFatal(e) => Left(e.getMessage) + } + + private[http] override def fromHeadersUnsafe(headers: Headers): H = + codec.decode(headers) + + override def toHeaders(value: H): Headers = + codec.encode(value, Headers.empty) + } + } + + def apply[H](name: String)(implicit schema0: Schema[H]): SchemaHeaderType.Typed[H] = { + new SchemaHeaderType { + type HeaderValue = H + val schema: Schema[H] = schema0 + val codec: StringSchemaCodec[H, Headers] = StringSchemaCodec.fromSchema( + schema, + (h: Headers, k: String, v: String) => h.addHeader(k, v), + (h: Headers, kvs: Iterable[(String, String)]) => h.addHeaders(kvs), + (h: Headers, k: String) => h.contains(k), + (h: Headers, k: String) => h.getUnsafe(k), + (h: Headers, k: String) => h.rawHeaders(k), + (h: Headers, k: String) => h.rawHeaders(k).size, + ErrorConstructor( + header => HttpCodecError.MissingHeader(header), + headers => HttpCodecError.MissingHeaders(headers), + validationErrors => HttpCodecError.InvalidEntity.wrap(validationErrors), + (header, value) => HttpCodecError.DecodingErrorHeader(header, value), + (header, expected, actual) => HttpCodecError.InvalidHeaderCount(header, expected, actual), + ), + isKebabCase = true, + name, + ) + + override def names: Chunk[String] = + codec.recordFields.map(_._1.fieldName) + + override def optional: SchemaHeaderType.Typed[Option[H]] = + apply(name)(schema.optional) + + override def fromHeaders(headers: Headers): Either[String, H] = + try Right(codec.decode(headers)) + catch { + case NonFatal(e) => Left(e.getMessage) + } + + private[http] override def fromHeadersUnsafe(headers: Headers): H = + codec.decode(headers) + + override def toHeaders(value: H): Headers = + codec.encode(value, Headers.empty) + } + } + } + + sealed trait HeaderType extends HeaderTypeBase { type HeaderValue <: Header + def names: Chunk[String] = Chunk.single(name) + def name: String def parse(value: String): Either[String, HeaderValue] def render(value: HeaderValue): String + + def fromHeaders(headers: Headers): Either[String, HeaderValue] = + headers.getUnsafe(name) match { + case null => Left(s"Header $name not found") + case value => parse(value) + } + + def fromHeadersUnsafe(headers: Headers): HeaderValue = + fromHeaders(headers).fold( + e => throw HttpCodecError.DecodingErrorHeader(name, ReadError(Cause.empty, e)), + identity, + ) + + def toHeaders(value: HeaderValue): Headers = + Headers.FromIterable(Iterable(value)) + } object HeaderType { type Typed[HV] = HeaderType { type HeaderValue = HV } } - // @deprecated("Use Schema based header codecs instead", "3.1.0") final case class Custom(customName: CharSequence, value: CharSequence) extends Header { override type Self = Custom override def self: Self = this diff --git a/zio-http/shared/src/main/scala/zio/http/Headers.scala b/zio-http/shared/src/main/scala/zio/http/Headers.scala index 3df48e573..08c24ea01 100644 --- a/zio-http/shared/src/main/scala/zio/http/Headers.scala +++ b/zio-http/shared/src/main/scala/zio/http/Headers.scala @@ -99,6 +99,7 @@ object Headers { value: T, iterate: T => Iterator[Header], unsafeGet: (T, CharSequence) => String, + getAll: (T, CharSequence) => Chunk[String], contains: (T, CharSequence) => Boolean, ) extends Headers { override def contains(key: CharSequence): Boolean = contains(value, key) @@ -106,6 +107,9 @@ object Headers { override def iterator: Iterator[Header] = iterate(value) override private[http] def getUnsafe(key: CharSequence): String = unsafeGet(value, key) + + override def rawHeaders(name: CharSequence): Chunk[String] = getAll(value, name) + } private[zio] final case class Concat(first: Headers, second: Headers) extends Headers { diff --git a/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala b/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala index ce21ab126..2bafeeaa4 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/HeaderCodecs.scala @@ -24,7 +24,7 @@ import zio.stacktracer.TracingImplicits.disableAutoTrace import zio.schema._ -import zio.http.Header.HeaderType +import zio.http.Header.{HeaderType, SchemaHeaderType} import zio.http._ private[codec] trait HeaderCodecs { @@ -41,17 +41,17 @@ private[codec] trait HeaderCodecs { case TextCodec.BooleanCodec => Schema[Boolean] case TextCodec.UUIDCodec => Schema[UUID] } - HttpCodec.HeaderCustom(name, schema.asInstanceOf[Schema[A]]) + HttpCodec.Header(SchemaHeaderType(name)(schema.asInstanceOf[Schema[A]])) } def header(headerType: HeaderType): HeaderCodec[headerType.HeaderValue] = HttpCodec.Header(headerType) def headerAs[A](name: String)(implicit schema: Schema[A]): HeaderCodec[A] = - HttpCodec.HeaderCustom(name, schema) + HttpCodec.Header(SchemaHeaderType(name)) def headers[A](implicit schema: Schema[A]): HeaderCodec[A] = - HttpCodec.HeaderCustom(schema) + HttpCodec.Header(SchemaHeaderType("headers")) @deprecated("Use Schema based headerAs instead", "3.1.0") def name[A](name: String)(implicit codec: TextCodec[A]): HeaderCodec[A] = diff --git a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala index 673dcc242..e2f9cf6b2 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodec.scala @@ -20,7 +20,7 @@ import scala.annotation.tailrec import scala.reflect.ClassTag import scala.util.Try -import zio._ +import zio.{http, _} import zio.stream.{ZPipeline, ZStream} @@ -29,12 +29,13 @@ import zio.schema.codec.DecodeError import zio.schema.validation.{Validation, ValidationError} import zio.http.Header.Accept.MediaTypeWithQFactor -import zio.http.Header.HeaderType +import zio.http.Header.{HeaderType, HeaderTypeBase, SchemaHeaderType} import zio.http._ import zio.http.codec.HttpCodec.SchemaCodec.camelToKebab import zio.http.codec.HttpCodec.{Annotated, Metadata} import zio.http.codec.StringCodec.StringCodec import zio.http.codec.internal._ +import zio.http.internal.StringSchemaCodec /** * A [[zio.http.codec.HttpCodec]] represents a codec for a part of an HTTP @@ -341,13 +342,12 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with private[http] sealed trait AtomTag private[http] object AtomTag { - case object Status extends AtomTag - case object Path extends AtomTag - case object Content extends AtomTag - case object Query extends AtomTag - case object Header extends AtomTag - case object HeaderCustom extends AtomTag - case object Method extends AtomTag + case object Status extends AtomTag + case object Path extends AtomTag + case object Content extends AtomTag + case object Query extends AtomTag + case object Header extends AtomTag + case object Method extends AtomTag } def empty: HttpCodec[Any, Unit] = @@ -2268,57 +2268,36 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with def index(index: Int): ContentStream[A] = copy(index = index) } - private[http] final case class Query[A, Out]( - codec: SchemaCodec[A], + private[http] final case class Query[A]( + codec: StringSchemaCodec[A, QueryParams], index: Int = 0, - ) extends Atom[HttpCodecType.Query, Out] { + ) extends Atom[HttpCodecType.Query, A] { self => - def erase: Query[Any, Any] = self.asInstanceOf[Query[Any, Any]] + def erase: Query[Any] = self.asInstanceOf[Query[Any]] - def index(index: Int): Query[A, Out] = copy(index = index) - - def isCollection: Boolean = codec.isCollection - - def isOptional: Boolean = codec.isOptional - - def isOptionalSchema: Boolean = codec.isOptionalSchema - - def isPrimitive: Boolean = codec.isPrimitive - - def isRecord: Boolean = codec.isRecord - - def nameUnsafe: String = codec.name.get + def index(index: Int): Query[A] = copy(index = index) /** * Returns a new codec, where the value produced by this one is optional. */ - override def optional: HttpCodec[HttpCodecType.Query, Option[Out]] = - if (isOptionalSchema) { - throw new IllegalArgumentException("Query is already optional") - } else { - Annotated(Query(codec.optional, index), Metadata.Optional()) - } + override def optional: HttpCodec[HttpCodecType.Query, Option[A]] = + Annotated(Query(codec.optional, index), Metadata.Optional()) def tag: AtomTag = AtomTag.Query } - object Query { - def apply[A](name: String, schema: Schema[A]): Query[A, A] = Query(SchemaCodec(Some(name), schema)) - def apply[A](schema: Schema[A]): Query[A, A] = Query(SchemaCodec(None, schema)) - } - - final case class SchemaCodec[A](name: Option[String], schema: Schema[A], kebabCase: Boolean = false) { + private[http] final case class SchemaCodec[A](name: Option[String], schema: Schema[A], kebabCase: Boolean = false) { - def erasedSchema: Schema[Any] = schema.asInstanceOf[Schema[Any]] + private[http] def erasedSchema: Schema[Any] = schema.asInstanceOf[Schema[Any]] - val isCollection: Boolean = schema match { + private[http] val isCollection: Boolean = schema match { case _: Schema.Collection[_, _] => true case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Collection[_, _]] => true case _ => false } - val isOptional: Boolean = schema match { + private[http] val isOptional: Boolean = schema match { case _: Schema.Optional[_] => true case record: Schema.Record[_] => @@ -2329,30 +2308,30 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with false } - val isOptionalSchema: Boolean = + private[http] val isOptionalSchema: Boolean = schema match { case _: Schema.Optional[_] => true case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Optional[_]] => true case _ => false } - val isPrimitive: Boolean = schema match { + private[http] val isPrimitive: Boolean = schema match { case _: Schema.Primitive[_] => true case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Primitive[_]] => true case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Primitive[_]] => true case _ => false } - val isRecord: Boolean = schema match { + private[http] val isRecord: Boolean = schema match { case _: Schema.Record[_] => true case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Record[_]] => true case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Record[_]] => true case _ => false } - def optional: SchemaCodec[Option[A]] = copy(schema = schema.optional) + private[http] def optional: SchemaCodec[Option[A]] = copy(schema = schema.optional) - val recordFields: Chunk[(Schema.Field[_, _], SchemaCodec[Any])] = { + private[http] val recordFields: Chunk[(Schema.Field[_, _], SchemaCodec[Any])] = { val fields = schema match { case record: Schema.Record[A] => record.fields @@ -2371,6 +2350,7 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with case _: Schema.Map[_, _] => throw new IllegalArgumentException("Maps are not supported") case _: Schema.NonEmptyMap[_, _] => throw new IllegalArgumentException("Maps are not supported") } + // only right for headers not for query parameters val codec = SchemaCodec(Some(if (!kebabCase) field.name else camelToKebab(field.name)), elementSchema) (field, codec.asInstanceOf[SchemaCodec[Any]]) case field => @@ -2382,7 +2362,7 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with } } - val recordSchema: Schema.Record[Any] = schema match { + private[http] val recordSchema: Schema.Record[Any] = schema match { case record: Schema.Record[_] => record.asInstanceOf[Schema.Record[Any]] case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Record[_]] => @@ -2390,7 +2370,7 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with case _ => null } - val stringCodec: StringCodec[Any] = + private[http] val stringCodec: StringCodec[Any] = stringCodecForSchema(schema.asInstanceOf[Schema[Any]]) private def stringCodecForSchema(s: Schema[_]): StringCodec[Any] = { @@ -2442,14 +2422,14 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with case f => f } - def validate(value: Any): Chunk[ValidationError] = + private[http] def validate(value: Any): Chunk[ValidationError] = schema.asInstanceOf[Schema[_]] match { case Schema.Optional(schema: Schema[Any], _) => schema.validate(value)(schema) case schema: Schema[_] => schema.asInstanceOf[Schema[Any]].validate(value)(schema.asInstanceOf[Schema[Any]]) } - val defaultValue: A = + val defaultValue: A = if (schema.isInstanceOf[Schema.Collection[_, _]]) { Try(schema.asInstanceOf[Schema.Collection[A, _]].empty).fold( _ => null.asInstanceOf[A], @@ -2473,7 +2453,7 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with } - object SchemaCodec { + private[http] object SchemaCodec { private def camelToKebab(s: String): String = if (s.isEmpty) "" else if (s.head.isUpper) s.head.toLower.toString + camelToKebab(s.tail) @@ -2494,34 +2474,7 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with def index(index: Int): Method[A] = copy(index = index) } - private[http] final case class HeaderCustom[A](codec: SchemaCodec[A], index: Int = 0) - extends Atom[HttpCodecType.Header, A] { - self => - def erase: HeaderCustom[Any] = self.asInstanceOf[HeaderCustom[Any]] - - override def optional: HttpCodec[HttpCodecType.Header, Option[A]] = - if (codec.isOptionalSchema) { - throw new IllegalArgumentException("Header is already optional") - } else { - Annotated( - HeaderCustom(codec.optional, index), - Metadata.Optional(), - ) - } - - def tag: AtomTag = AtomTag.HeaderCustom - - def index(index: Int): HeaderCustom[A] = copy(index = index) - } - - object HeaderCustom { - def apply[A](name: String, schema: Schema[A]): HeaderCustom[A] = - HeaderCustom(SchemaCodec(Some(name), schema, kebabCase = true)) - def apply[A](schema: Schema[A]): HeaderCustom[A] = - HeaderCustom(SchemaCodec(None, schema, kebabCase = true)) - } - - private[http] final case class Header[A](headerType: HeaderType.Typed[A], index: Int = 0) + private[http] final case class Header[A](headerType: HeaderTypeBase.Typed[A], index: Int = 0) extends Atom[HttpCodecType.Header, A] { self => def erase: Header[Any] = self.asInstanceOf[Header[Any]] @@ -2529,6 +2482,18 @@ object HttpCodec extends ContentCodecs with HeaderCodecs with MethodCodecs with def tag: AtomTag = AtomTag.Header def index(index: Int): Header[A] = copy(index = index) + + override def optional: HttpCodec[HttpCodecType.Header, Option[A]] = { + headerType match { + case headerType if headerType.isInstanceOf[SchemaHeaderType] => + Annotated( + Header(headerType.asInstanceOf[SchemaHeaderType.Typed[A]].optional, index), + Metadata.Optional(), + ) + case _ => + super.optional + } + } } private[http] final case class Annotated[AtomTypes, Value]( diff --git a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala index 3df1973ab..bfbacea8d 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/HttpCodecError.scala @@ -52,8 +52,8 @@ object HttpCodecError { final case class MalformedHeader(headerName: String, textCodec: TextCodec[_]) extends HttpCodecError { def message = s"Malformed header $headerName failed to decode using $textCodec" } - final case class MalformedCustomHeader(headerName: String, cause: DecodeError) extends HttpCodecError { - def message = s"Malformed custom header $headerName could not be decoded: $cause" + final case class DecodingErrorHeader(headerName: String, cause: DecodeError) extends HttpCodecError { + def message = s"Malformed header $headerName could not be decoded: $cause" } final case class MalformedTypedHeader(headerName: String) extends HttpCodecError { def message = s"Malformed header $headerName" @@ -83,6 +83,9 @@ object HttpCodecError { final case class InvalidQueryParamCount(name: String, expected: Int, actual: Int) extends HttpCodecError { def message = s"Invalid query parameter count for $name: expected $expected but found $actual." } + final case class InvalidHeaderCount(name: String, expected: Int, actual: Int) extends HttpCodecError { + def message = s"Invalid query parameter count for $name: expected $expected but found $actual." + } final case class CustomError(name: String, message: String) extends HttpCodecError final case class UnsupportedContentType(contentType: String) extends HttpCodecError { @@ -102,6 +105,9 @@ object HttpCodecError { def isMissingDataOnly(cause: Cause[Any]): Boolean = !cause.isFailure && cause.defects.forall(e => - e.isInstanceOf[HttpCodecError.MissingHeader] || e.isInstanceOf[HttpCodecError.MissingQueryParam], + e.isInstanceOf[HttpCodecError.MissingHeader] + || e.isInstanceOf[HttpCodecError.MissingQueryParam] + || e.isInstanceOf[HttpCodecError.MissingQueryParams] + || e.isInstanceOf[HttpCodecError.MissingHeaders], ) } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala b/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala index 4f98ec8e4..3b28c8edc 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/QueryCodecs.scala @@ -15,65 +15,57 @@ */ package zio.http.codec -import scala.annotation.tailrec - import zio.stacktracer.TracingImplicits.disableAutoTrace import zio.schema.Schema -import zio.schema.annotation.simpleEnum -private[codec] trait QueryCodecs { +import zio.http.QueryParams +import zio.http.internal.{ErrorConstructor, StringSchemaCodec} - def query[A](name: String)(implicit schema: Schema[A]): QueryCodec[A] = - schema match { - case c: Schema.Collection[_, _] if !supportedCollection(c) => - throw new IllegalArgumentException(s"Collection schema $c is not supported for query codecs") - case enum0: Schema.Enum[_] if !enum0.annotations.exists(_.isInstanceOf[simpleEnum]) => - throw new IllegalArgumentException(s"Enum schema $enum0 is not supported. All cases must be objects.") - case record: Schema.Record[A] if record.fields.size != 1 => - throw new IllegalArgumentException("Use queryAll[A] for records with more than one field") - case record: Schema.Record[A] if !supportedElementSchema(record.fields.head.schema.asInstanceOf[Schema[Any]]) => - throw new IllegalArgumentException( - s"Only primitive types and simple enums can be used in single field records, but got ${record.fields.head.schema}", - ) - case other => - HttpCodec.Query(name, other) - } +private[codec] trait QueryCodecs { - private def supportedCollection(schema: Schema.Collection[_, _]): Boolean = schema match { - case Schema.Map(_, _, _) => - false - case Schema.NonEmptyMap(_, _, _) => - false - case Schema.Sequence(elementSchema, _, _, _, _) => - supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]]) - case Schema.NonEmptySequence(elementSchema, _, _, _, _) => - supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]]) - case Schema.Set(elementSchema, _) => - supportedElementSchema(elementSchema.asInstanceOf[Schema[Any]]) + def query[A](name: String)(implicit schema: Schema[A]): QueryCodec[A] = { + val codec = StringSchemaCodec.fromSchema[A, QueryParams]( + schema, + (qp, k, v) => qp.addQueryParam(k, v), + (qp, kvs) => qp.addQueryParams(kvs), + (qp, k) => qp.hasQueryParam(k), + (qp, k) => qp.unsafeQueryParam(k), + (qp, k) => qp.getAll(k), + (qp, k) => qp.valueCount(k), + ErrorConstructor( + param => HttpCodecError.MissingQueryParam(param), + params => HttpCodecError.MissingQueryParams(params), + validationErrors => HttpCodecError.InvalidEntity.wrap(validationErrors), + (param, value) => HttpCodecError.MalformedQueryParam(param, value), + (param, expected, actual) => HttpCodecError.InvalidQueryParamCount(param, expected, actual), + ), + isKebabCase = false, + name, + ) + HttpCodec.Query(codec) } - @tailrec - private def supportedElementSchema(elementSchema: Schema[Any]): Boolean = elementSchema match { - case Schema.Lazy(schema0) => supportedElementSchema(schema0()) - case _ => - elementSchema.isInstanceOf[Schema.Primitive[_]] || - elementSchema.isInstanceOf[Schema.Enum[_]] && elementSchema.annotations.exists(_.isInstanceOf[simpleEnum]) || - elementSchema.isInstanceOf[Schema.Record[_]] && elementSchema.asInstanceOf[Schema.Record[_]].fields.size == 1 + def queryAll[A](implicit schema: Schema[A]): QueryCodec[A] = { + val codec = StringSchemaCodec.fromSchema[A, QueryParams]( + schema, + (qp, k, v) => qp.addQueryParam(k, v), + (qp, kvs) => qp.addQueryParams(kvs), + (qp, k) => qp.hasQueryParam(k), + (qp, k) => qp.unsafeQueryParam(k), + (qp, k) => qp.getAll(k), + (qp, k) => qp.valueCount(k), + ErrorConstructor( + param => HttpCodecError.MissingQueryParam(param), + params => HttpCodecError.MissingQueryParams(params), + validationErrors => HttpCodecError.InvalidEntity.wrap(validationErrors), + (param, value) => HttpCodecError.MalformedQueryParam(param, value), + (param, expected, actual) => HttpCodecError.InvalidQueryParamCount(param, expected, actual), + ), + isKebabCase = false, + null, + ) + HttpCodec.Query(codec) } - def queryAll[A](implicit schema: Schema[A]): QueryCodec[A] = - schema match { - case _: Schema.Primitive[A] => - throw new IllegalArgumentException("Use query[A](name: String) for primitive types") - case record: Schema.Record[A] => - HttpCodec.Query(record) - case Schema.Optional(s, _) if s.isInstanceOf[Schema.Record[_]] => - HttpCodec.Query(schema) - case _ => - throw new IllegalArgumentException( - "Only case classes can be used with queryAll. Maybe you wanted to use query[A](name: String)?", - ) - } - } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala b/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala index af296b3cf..4ae3b295b 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/internal/AtomizedCodecs.scala @@ -25,18 +25,16 @@ import zio.http.codec._ private[http] final case class AtomizedCodecs( method: Chunk[SimpleCodec[zio.http.Method, _]], path: Chunk[PathCodec[_]], - query: Chunk[Query[_, _]], + query: Chunk[Query[_]], header: Chunk[Header[_]], - headerCustom: Chunk[HeaderCustom[_]], content: Chunk[BodyCodec[_]], status: Chunk[SimpleCodec[zio.http.Status, _]], ) { self => def append(atom: Atom[_, _]): AtomizedCodecs = atom match { case path0: Path[_] => self.copy(path = path :+ path0.pathCodec) case method0: Method[_] => self.copy(method = method :+ method0.codec) - case query0: Query[_, _] => self.copy(query = query :+ query0) + case query0: Query[_] => self.copy(query = query :+ query0) case header0: Header[_] => self.copy(header = header :+ header0) - case header0: HeaderCustom[_] => self.copy(headerCustom = headerCustom :+ header0) case status0: Status[_] => self.copy(status = status :+ status0.codec) case content0: Content[_] => self.copy(content = content :+ BodyCodec.Single(content0.codec, content0.name)) @@ -50,7 +48,6 @@ private[http] final case class AtomizedCodecs( path = Array.ofDim(path.length), query = Array.ofDim(query.length), header = Array.ofDim(header.length), - headerCustom = Array.ofDim(headerCustom.length), content = Array.ofDim(content.length), status = Array.ofDim(status.length), ) @@ -62,7 +59,6 @@ private[http] final case class AtomizedCodecs( path = path.materialize, query = query.materialize, header = header.materialize, - headerCustom = headerCustom.materialize, content = content.materialize, status = status.materialize, ) @@ -75,7 +71,6 @@ private[http] object AtomizedCodecs { path = Chunk.empty, query = Chunk.empty, header = Chunk.empty, - headerCustom = Chunk.empty, content = Chunk.empty, status = Chunk.empty, ) diff --git a/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala b/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala index e00ba2066..4d796cc21 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala @@ -169,7 +169,6 @@ private[codec] object EncoderDecoder { decodeStatus(status, inputsBuilder.status) decodeMethod(method, inputsBuilder.method) decodeHeaders(headers, inputsBuilder.header) - decodeCustomHeaders(headers, inputsBuilder.headerCustom) decodeBody(config, body, inputsBuilder.content).as(constructor(inputsBuilder)) } @@ -182,7 +181,7 @@ private[codec] object EncoderDecoder { val query = encodeQuery(config, inputs.query) val status = encodeStatus(inputs.status) val method = encodeMethod(inputs.method) - val headers = encodeHeaders(inputs.header) ++ encodeCustomHeaders(inputs.headerCustom) + val headers = encodeHeaders(inputs.header) def contentTypeHeaders = encodeContentType(inputs.content, outputTypes) val body = encodeBody(config, inputs.content, outputTypes) @@ -216,298 +215,19 @@ private[codec] object EncoderDecoder { ) private def decodeQuery(config: CodecConfig, queryParams: QueryParams, inputs: Array[Any]): Unit = - genericDecode[QueryParams, HttpCodec.Query[_, _]]( + genericDecode[QueryParams, HttpCodec.Query[_]]( queryParams, flattened.query, inputs, - (codec, queryParams) => { - val query = codec.erase - val optional = query.isOptionalSchema - val hasDefault = query.codec.defaultValue != null && query.isOptional - val default = query.codec.defaultValue - if (codec.isPrimitive) { - val name = query.nameUnsafe - val hasParam = queryParams.hasQueryParam(name) - if ( - (!hasParam || (queryParams - .unsafeQueryParam(name) == "" && !emptyStringIsValue(codec.codec.schema))) && hasDefault - ) - default - else if (!hasParam) - throw HttpCodecError.MissingQueryParam(name) - else if (queryParams.valueCount(name) != 1) - throw HttpCodecError.InvalidQueryParamCount(name, 1, queryParams.valueCount(name)) - else { - val decoded = - codec.codec.stringCodec.decode(queryParams.unsafeQueryParam(name)) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) - case Right(value) => value - } - val validationErrors = codec.codec.erasedSchema.validate(decoded)(codec.codec.erasedSchema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - else decoded - } - - } else if (codec.isCollection) { - val name = query.nameUnsafe - val hasParam = queryParams.hasQueryParam(name) - - if (!hasParam) { - if (query.codec.defaultValue != null) query.codec.defaultValue - else throw HttpCodecError.MissingQueryParam(name) - } else { - val decoded = queryParams.queryParams(name).map { value => - query.codec.stringCodec.decode(value) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) - case Right(value) => value - } - } - if (optional) - Some( - createAndValidateCollection( - query.codec.schema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Collection[_, _]], - decoded, - ), - ) - else createAndValidateCollection(query.codec.schema.asInstanceOf[Schema.Collection[_, _]], decoded) - } - } else { - val recordSchema = query.codec.recordSchema - val fields = query.codec.recordFields - val hasAllParams = fields.forall { case (field, codec) => - queryParams.hasQueryParam(field.fieldName) || field.optional || codec.isOptional - } - if (!hasAllParams && hasDefault) default - else if (!hasAllParams) throw HttpCodecError.MissingQueryParams { - fields.collect { - case (field, codec) - if !(queryParams.hasQueryParam(field.fieldName) || field.optional || codec.isOptional) => - field.fieldName - } - } - else { - val decoded = fields.map { - case (field, codec) if field.schema.isInstanceOf[Schema.Collection[_, _]] => - val schema = field.schema.asInstanceOf[Schema.Collection[_, _]] - if (!queryParams.hasQueryParam(field.fieldName)) { - if (field.defaultValue.isDefined) field.defaultValue.get - else throw HttpCodecError.MissingQueryParam(field.fieldName) - } else { - val values = queryParams.queryParams(field.fieldName) - val decoded = - values.map(decodeAndUnwrap(field, codec, _, HttpCodecError.MalformedQueryParam.apply)) - createAndValidateCollection(schema, decoded) - - } - case (field, codec) => - val value = queryParams.queryParamOrElse(field.fieldName, null) - val decoded = { - if (value == null || (value == "" && !emptyStringIsValue(codec.schema))) codec.defaultValue - else decodeAndUnwrap(field, codec, value, HttpCodecError.MalformedQueryParam.apply) - } - validateDecoded(codec, decoded) - } - if (optional) { - val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) - constructed match { - case Left(value) => - throw HttpCodecError.MalformedQueryParam( - s"${recordSchema.id}", - DecodeError.ReadError(Cause.empty, value), - ) - case Right(value) => - recordSchema.validate(value)(recordSchema) match { - case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) - case _ => Some(value) - } - } - } else { - val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) - constructed match { - case Left(value) => - throw HttpCodecError.MalformedQueryParam( - s"${recordSchema.id}", - DecodeError.ReadError(Cause.empty, value), - ) - case Right(value) => - recordSchema.validate(value)(recordSchema) match { - case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) - case _ => value - } - } - } - } - } - }, + (codec, queryParams) => codec.erase.codec.decode(queryParams), ) - private def createAndValidateCollection(schema: Schema.Collection[_, _], decoded: Chunk[Any]) = { - val collection = schema.fromChunk.asInstanceOf[Chunk[Any] => Any](decoded) - val erasedSchema = schema.asInstanceOf[Schema[Any]] - val validationErrors = erasedSchema.validate(collection)(erasedSchema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - collection - } - - @tailrec - private def emptyStringIsValue(schema: Schema[_]): Boolean = { - schema match { - case value: Schema.Optional[_] => - val innerSchema = value.schema - emptyStringIsValue(innerSchema) - case _ => - schema.asInstanceOf[Schema.Primitive[_]].standardType match { - case StandardType.UnitType => true - case StandardType.StringType => true - case StandardType.BinaryType => true - case StandardType.CharType => true - case _ => false - } - } - } - - private def decodeCustomHeaders(headers: Headers, inputs: Array[Any]): Unit = - genericDecode[Headers, HttpCodec.HeaderCustom[_]]( - headers, - flattened.headerCustom, - inputs, - (header, headers) => { - val optional = header.codec.isOptionalSchema - if (header.codec.isPrimitive) { - val schema = header.erase.codec.schema - val name = header.codec.name.get - val value = headers.getUnsafe(name) - if (value ne null) { - val decoded = header.codec.stringCodec.decode(value) match { - case Left(error) => throw HttpCodecError.MalformedCustomHeader(name, error) - case Right(value) => value - } - val validationErrors = schema.validate(decoded)(schema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - else decoded - } else { - if (optional) None - else throw HttpCodecError.MissingHeader(name) - } - } else if (header.codec.isCollection) { - val name = header.codec.name.get - val values = headers.rawHeaders(name) - val decoded = values.map { value => - header.codec.stringCodec.decode(value) match { - case Left(error) => throw HttpCodecError.MalformedCustomHeader(name, error) - case Right(value) => value - } - } - if (optional) - Some( - createAndValidateCollection( - header.codec.schema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Collection[_, _]], - decoded, - ), - ) - else createAndValidateCollection(header.codec.schema.asInstanceOf[Schema.Collection[_, _]], decoded) - } else { - val recordSchema = header.codec.recordSchema - val fields = header.codec.recordFields - val hasAllParams = fields.forall { case (field, codec) => - headers.contains(field.fieldName) || field.optional || codec.isOptional - } - if (!hasAllParams) { - if (header.codec.defaultValue != null && header.codec.isOptional) header.codec.defaultValue - else - throw HttpCodecError.MissingHeaders { - fields.collect { - case (field, codec) if !(headers.contains(field.fieldName) || field.optional || codec.isOptional) => - field.fieldName - } - } - } else { - val decoded = fields.map { - case (field, codec) if field.schema.isInstanceOf[Schema.Collection[_, _]] => - if (!headers.contains(codec.name.get)) { - if (codec.defaultValue != null) codec.defaultValue - else throw HttpCodecError.MissingHeader(codec.name.get) - } else { - val schema = field.schema.asInstanceOf[Schema.Collection[_, _]] - val values = headers.rawHeaders(codec.name.get) - val decoded = - values.map(decodeAndUnwrap(field, codec, _, HttpCodecError.MalformedCustomHeader.apply)) - createAndValidateCollection(schema, decoded) - } - case (field, codec) => - val value = headers.getUnsafe(codec.name.get) - val decoded = - if (value == null || (value == "" && !emptyStringIsValue(codec.schema))) codec.defaultValue - else decodeAndUnwrap(field, codec, value, HttpCodecError.MalformedCustomHeader.apply) - validateDecoded(codec, decoded) - } - if (optional) { - val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) - constructed match { - case Left(value) => - throw HttpCodecError.MalformedCustomHeader( - s"${recordSchema.id}", - DecodeError.ReadError(Cause.empty, value), - ) - case Right(value) => - recordSchema.validate(value)(recordSchema) match { - case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) - case _ => Some(value) - } - } - } else { - val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) - constructed match { - case Left(value) => - throw HttpCodecError.MalformedCustomHeader( - s"${recordSchema.id}", - DecodeError.ReadError(Cause.empty, value), - ) - case Right(value) => - recordSchema.validate(value)(recordSchema) match { - case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) - case _ => value - } - } - } - } - } - }, - ) - - private def validateDecoded(codec: HttpCodec.SchemaCodec[Any], decoded: Any) = { - val validationErrors = codec.schema.validate(decoded)(codec.schema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - decoded - } - - private def decodeAndUnwrap( - field: Schema.Field[_, _], - codec: HttpCodec.SchemaCodec[Any], - value: String, - ex: (String, DecodeError) => HttpCodecError, - ) = { - codec.stringCodec.decode(value) match { - case Left(error) => throw ex(codec.name.get, error) - case Right(value) => value - } - } - private def decodeHeaders(headers: Headers, inputs: Array[Any]): Unit = genericDecode[Headers, HttpCodec.Header[_]]( headers, flattened.header, inputs, - (codec, headers) => - headers.get(codec.headerType.name) match { - case Some(value) => - codec.erase.headerType - .parse(value) - .getOrElse(throw HttpCodecError.MalformedTypedHeader(codec.headerType.name)) - - case None => - throw HttpCodecError.MissingHeader(codec.headerType.name) - }, + (codec, headers) => codec.headerType.fromHeadersUnsafe(headers), ) private def decodeStatus(status: Status, inputs: Array[Any]): Unit = @@ -630,159 +350,19 @@ private[codec] object EncoderDecoder { ) private def encodeQuery(config: CodecConfig, inputs: Array[Any]): QueryParams = - genericEncode[QueryParams, HttpCodec.Query[_, _]]( + genericEncode[QueryParams, HttpCodec.Query[_]]( flattened.query, inputs, QueryParams.empty, - (codec, input, queryParams) => { - val query = codec.erase - val optional = query.isOptionalSchema - val stringCodec = codec.codec.stringCodec.asInstanceOf[StringCodec[Any]] - - if (query.isPrimitive) { - val schema = codec.codec.schema - val name = query.nameUnsafe - if (schema.isInstanceOf[Schema.Primitive[_]]) { - if (schema.asInstanceOf[Schema.Primitive[_]].standardType.isInstanceOf[StandardType.UnitType.type]) { - queryParams.addQueryParams(name, Chunk.empty[String]) - } else { - val encoded = stringCodec.encode(input) - queryParams.addQueryParams(name, Chunk(encoded)) - } - } else if (schema.isInstanceOf[Schema.Optional[_]]) { - val encoded = stringCodec.encode(input) - if (encoded.nonEmpty) queryParams.addQueryParams(name, Chunk(encoded)) else queryParams - } else { - throw new IllegalStateException( - "Only primitive schema is supported for query parameters of type Primitive", - ) - } - } else if (query.isCollection) { - val name = query.nameUnsafe - var in: Any = input - if (optional) { - in = input.asInstanceOf[Option[Any]].getOrElse(Chunk.empty) - } - val values = input.asInstanceOf[Iterable[Any]] - if (values.nonEmpty) { - queryParams.addQueryParams( - name, - Chunk.fromIterable(values.map { value => stringCodec.encode(value) }), - ) - } else queryParams - } else if (query.isRecord) { - val value = input match { - case None => null - case Some(value) => value - case value => value - } - if (value == null) queryParams - else { - val innerSchema = query.codec.recordSchema - val fieldValues = innerSchema.deconstruct(value)(Unsafe.unsafe) - var qp = queryParams - val fieldIt = query.codec.recordFields.iterator - val fieldValuesIt = fieldValues.iterator - while (fieldIt.hasNext) { - val (field, codec) = fieldIt.next() - val name = field.fieldName - val value = fieldValuesIt.next() match { - case Some(value) => value - case None => field.defaultValue - } - value match { - case values: Iterable[_] => - qp = qp.addQueryParams( - name, - Chunk.fromIterable(values.map { v => - codec.stringCodec.encode(v) - }), - ) - case _ => - val encoded = codec.stringCodec.encode(value) - qp = qp.addQueryParam(name, encoded) - } - } - qp - } - } else { - queryParams - } - }, + (codec, input, queryParams) => codec.erase.codec.encode(input, queryParams), ) - private def encodeCustomHeaders(inputs: Array[Any]): Headers = { - genericEncode[Headers, HttpCodec.HeaderCustom[_]]( - flattened.headerCustom, - inputs, - Headers.empty, - (codec, input, headers) => { - val optional = codec.codec.isOptionalSchema - val stringCodec = codec.erase.codec.stringCodec - if (codec.codec.isPrimitive) { - val name = codec.codec.name.get - val value = input - if (optional && value == None) headers - else { - val encoded = stringCodec.encode(value) - headers ++ Headers(name, encoded) - } - } else if (codec.codec.isCollection) { - val name = codec.codec.name.get - val values = input.asInstanceOf[Iterable[Any]] - if (values.nonEmpty) { - headers ++ Headers.FromIterable( - values.map { value => - Header.Custom(name, stringCodec.encode(value)) - }, - ) - } else headers - } else { - val recordSchema = codec.codec.recordSchema - val fields = codec.codec.recordFields - val value = input match { - case None => null - case Some(value) => value - case value => value - } - if (value == null) headers - else { - val fieldValues = recordSchema.deconstruct(value)(Unsafe.unsafe) - var hs = headers - val fieldIt = fields.iterator - val fieldValuesIt = fieldValues.iterator - while (fieldIt.hasNext) { - val (field, codec) = fieldIt.next() - val name = field.fieldName - val value = fieldValuesIt.next() match { - case Some(value) => value - case None => field.defaultValue - } - value match { - case values: Iterable[_] => - hs = hs ++ Headers.FromIterable( - values.map { v => - Header.Custom(name, codec.stringCodec.encode(v)) - }, - ) - case _ => - val encoded = codec.stringCodec.encode(value) - hs = hs ++ Headers(name, encoded) - } - } - hs - } - } - }, - ) - - } private def encodeHeaders(inputs: Array[Any]): Headers = genericEncode[Headers, HttpCodec.Header[_]]( flattened.header, inputs, Headers.empty, - (codec, input, headers) => headers ++ Headers(codec.headerType.name, codec.erase.headerType.render(input)), + (codec, input, headers) => headers ++ codec.erase.headerType.toHeaders(input), ) private def encodeStatus(inputs: Array[Any]): Option[Status] = diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala index a1b84d1cd..f86eba905 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala @@ -127,37 +127,33 @@ object HttpGen { def headersVariables(inAtoms: AtomizedMetaCodecs): Seq[HttpVariable] = inAtoms.header.collect { case mc @ MetaCodec(HttpCodec.Header(headerType, _), _) => HttpVariable( - headerType.name.capitalize, - mc.examples.values.headOption.map(e => headerType.render(e.asInstanceOf[headerType.HeaderValue])), + headerType.names.head.capitalize, + mc.examples.values.headOption.map(e => + headerType.toHeaders(e.asInstanceOf[headerType.HeaderValue]).head.renderedValue, + ), ) } def queryVariables(config: CodecConfig, inAtoms: AtomizedMetaCodecs): Seq[HttpVariable] = { - inAtoms.query.collect { - case mc @ MetaCodec(HttpCodec.Query(codec, _), _) if codec.isPrimitive => + inAtoms.query.collect { case mc @ MetaCodec(HttpCodec.Query(codec, _), _) => + val recordSchema = (codec.schema match { + case value if value.isInstanceOf[Schema.Optional[_]] => value.asInstanceOf[Schema.Optional[Any]].schema + case _ => codec.schema + }).asInstanceOf[Schema.Record[Any]] + val examples = mc.examples.values.headOption.map { ex => + recordSchema.deconstruct(ex)(Unsafe.unsafe) + } + codec.recordFields.zipWithIndex.map { case ((field, codec), index) => HttpVariable( - codec.name.get, - mc.examples.values.headOption.map((e: Any) => codec.stringCodec.encode(e)), - ) :: Nil - case mc @ MetaCodec(HttpCodec.Query(codec, _), _) if codec.isRecord => - val recordSchema = (codec.schema match { - case value if value.isInstanceOf[Schema.Optional[_]] => value.asInstanceOf[Schema.Optional[Any]].schema - case _ => codec.schema - }).asInstanceOf[Schema.Record[Any]] - val examples = mc.examples.values.headOption.map { ex => - recordSchema.deconstruct(ex)(Unsafe.unsafe) - } - codec.recordFields.zipWithIndex.map { case ((field, codec), index) => - HttpVariable( - field.name, - examples.map(values => { - val fieldValue = values(index) - .orElse(field.defaultValue) - .getOrElse(throw new Exception(s"No value or default value for field ${field.name}")) - codec.stringCodec.encode(fieldValue) - }), - ) - } + field.name, + examples.map(values => { + val fieldValue = values(index) + .orElse(field.defaultValue) + .getOrElse(throw new Exception(s"No value or default value for field ${field.name}")) + codec.encode(fieldValue) + }), + ) + } }.flatten } diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala index 9f3c47e38..c8d3abd9a 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala @@ -21,6 +21,7 @@ import zio.http.codec._ import zio.http.endpoint._ import zio.http.endpoint.openapi.JsonSchema.SchemaStyle import zio.http.endpoint.openapi.OpenAPI.{Path, PathItem} +import zio.http.internal.StringSchemaCodec object OpenAPIGen { private val PathWildcard = "pathWildcard" @@ -101,11 +102,10 @@ object OpenAPIGen { final case class AtomizedMetaCodecs( method: Chunk[MetaCodec[SimpleCodec[Method, _]]], path: Chunk[MetaCodec[SegmentCodec[_]]], - query: Chunk[MetaCodec[HttpCodec.Query[_, _]]], + query: Chunk[MetaCodec[HttpCodec.Query[_]]], header: Chunk[MetaCodec[HttpCodec.Header[_]]], content: Chunk[MetaCodec[HttpCodec.Atom[Content, _]]], status: Chunk[MetaCodec[HttpCodec.Status[_]]], - headerCustom: Chunk[MetaCodec[HttpCodec.HeaderCustom[_]]] = Chunk.empty, ) { def append(metaCodec: MetaCodec[_]): AtomizedMetaCodecs = metaCodec match { case MetaCodec(codec: HttpCodec.Method[_], annotations) => @@ -115,12 +115,10 @@ object OpenAPIGen { ) case MetaCodec(_: SegmentCodec[_], _) => copy(path = path :+ metaCodec.asInstanceOf[MetaCodec[SegmentCodec[_]]]) - case MetaCodec(_: HttpCodec.Query[_, _], _) => - copy(query = query :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Query[_, _]]]) + case MetaCodec(_: HttpCodec.Query[_], _) => + copy(query = query :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Query[_]]]) case MetaCodec(_: HttpCodec.Header[_], _) => copy(header = header :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Header[_]]]) - case MetaCodec(_: HttpCodec.HeaderCustom[_], _) => - copy(headerCustom = headerCustom :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.HeaderCustom[_]]]) case MetaCodec(_: HttpCodec.Status[_], _) => copy(status = status :+ metaCodec.asInstanceOf[MetaCodec[HttpCodec.Status[_]]]) case MetaCodec(_: HttpCodec.Content[_], _) => @@ -138,7 +136,6 @@ object OpenAPIGen { header ++ that.header, content ++ that.content, status ++ that.status, - headerCustom ++ that.headerCustom, ) def contentExamples: Map[String, OpenAPI.ReferenceOr.Or[OpenAPI.Example]] = @@ -176,7 +173,6 @@ object OpenAPIGen { header.materialize, content.materialize, status.materialize, - headerCustom.materialize, ) } @@ -188,7 +184,6 @@ object OpenAPIGen { header = Chunk.empty, content = Chunk.empty, status = Chunk.empty, - headerCustom = Chunk.empty, ) def flatten[R, A](codec: HttpCodec[R, A]): AtomizedMetaCodecs = { @@ -758,85 +753,41 @@ object OpenAPIGen { def parameters: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = queryParams ++ pathParams ++ headerParams - def queryParams: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = { - inAtoms.query.collect { - case mc @ MetaCodec(q @ HttpCodec.Query(codec, _), _) if codec.isPrimitive => + def queryParams: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = + inAtoms.query.collect { case mc @ MetaCodec(HttpCodec.Query(codec, _), _) => + val recordSchema = (codec.schema match { + case schema if schema.isInstanceOf[Schema.Optional[_]] => schema.asInstanceOf[Schema.Optional[_]].schema + case _ => codec.schema + }).asInstanceOf[Schema.Record[Any]] + val examples = mc.examples.map { case (exName, ex) => + exName -> recordSchema.deconstruct(ex)(Unsafe.unsafe) + } + codec.recordFields.zipWithIndex.map { case ((field, codec), index) => OpenAPI.ReferenceOr.Or( OpenAPI.Parameter.queryParameter( - name = q.nameUnsafe, + name = field.name, description = mc.docsOpt, schema = Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromZSchema(codec.schema))), deprecated = mc.deprecated, style = OpenAPI.Parameter.Style.Form, explode = false, allowReserved = false, - examples = mc.examples.map { case (name, value) => - name -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(value = Json.Str(value.toString))) - }, - required = mc.required && !q.isOptional, - ), - ) :: Nil - case mc @ MetaCodec(HttpCodec.Query(codec, _), _) if codec.isRecord => - val recordSchema = (codec.schema match { - case schema if schema.isInstanceOf[Schema.Optional[_]] => schema.asInstanceOf[Schema.Optional[_]].schema - case _ => codec.schema - }).asInstanceOf[Schema.Record[Any]] - val examples = mc.examples.map { case (exName, ex) => - exName -> recordSchema.deconstruct(ex)(Unsafe.unsafe) - } - codec.recordFields.zipWithIndex.map { case ((field, codec), index) => - OpenAPI.ReferenceOr.Or( - OpenAPI.Parameter.queryParameter( - name = field.name, - description = mc.docsOpt, - schema = Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromZSchema(codec.schema))), - deprecated = mc.deprecated, - style = OpenAPI.Parameter.Style.Form, - explode = false, - allowReserved = false, - examples = examples.map { case (exName, values) => - val fieldValue = values(index) - .orElse(field.defaultValue) - .getOrElse( - throw new Exception(s"No value or default value found for field ${exName}_${field.name}"), - ) - s"${exName}_${field.name}" -> OpenAPI.ReferenceOr.Or( - OpenAPI.Example(value = Json.Str(codec.stringCodec.encode(fieldValue))), + examples = examples.map { case (exName, values) => + val fieldValue = values(index) + .orElse(field.defaultValue) + .getOrElse( + throw new Exception(s"No value or default value found for field ${exName}_${field.name}"), ) - }, - required = mc.required, - ), - ) - - } - case mc @ MetaCodec(q @ HttpCodec.Query(codec, _), _) if codec.isCollection => - var required = false - val schema = codec.schema.asInstanceOf[Schema.Collection[_, _]] match { - case s: Schema.Sequence[_, _, _] => s.elementSchema - case _: Schema.Map[_, _] => throw new Exception("Map query parameters not supported") - case _: Schema.NonEmptyMap[_, _] => throw new Exception("Map query parameters not supported") - case s: Schema.NonEmptySequence[_, _, _] => - required = true - s.elementSchema - case s: Schema.Set[_] => s.elementSchema - } - OpenAPI.ReferenceOr.Or( - OpenAPI.Parameter.queryParameter( - name = q.nameUnsafe, - description = mc.docsOpt, - schema = Some(OpenAPI.ReferenceOr.Or(JsonSchema.fromZSchema(schema))), - deprecated = mc.deprecated, - style = OpenAPI.Parameter.Style.Form, - explode = false, - allowReserved = false, - examples = mc.examples.map { case (exName, value) => - exName -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(value = Json.Str(value.toString))) + s"${exName}_${field.name}" -> OpenAPI.ReferenceOr.Or( + OpenAPI.Example(value = Json.Str(codec.encode(fieldValue))), + ) }, - required = required, + required = mc.required && !StringSchemaCodec.isOptional(field.schema), ), - ) :: Nil - } - }.flatten.toSet + ) + + } + }.flatten.toSet def pathParams: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = inAtoms.path.collect { @@ -861,35 +812,19 @@ object OpenAPIGen { .map { case mc @ MetaCodec(codec, _) => OpenAPI.ReferenceOr.Or( OpenAPI.Parameter.headerParameter( - name = mc.name.getOrElse(codec.headerType.name), + name = mc.name.getOrElse(codec.headerType.names.head), description = mc.docsOpt, definition = Some(OpenAPI.ReferenceOr.Or(JsonSchema.String().nullable(!mc.required))), deprecated = mc.deprecated, - examples = mc.examples.map { case (name, value) => - name -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(codec.headerType.render(value).toJsonAST.toOption.get)) - }, - required = mc.required, - ), - ) - } - .toSet ++ inAtoms.headerCustom - .asInstanceOf[Chunk[MetaCodec[HttpCodec.HeaderCustom[Any]]]] - // todo must handle collection and record - .map { case mc @ MetaCodec(codec, _) => - OpenAPI.ReferenceOr.Or( - OpenAPI.Parameter.headerParameter( - name = codec.codec.name.getOrElse(throw new Exception("Header parameter must have a name")), - description = mc.docsOpt, - definition = Some(OpenAPI.ReferenceOr.Or(JsonSchema.String().nullable(!mc.required))), - deprecated = mc.deprecated, - examples = mc.examples.map { case (name, value) => - name -> OpenAPI.ReferenceOr - .Or(OpenAPI.Example(codec.codec.stringCodec.encode(value).toJsonAST.toOption.get)) - }, + examples = Map.empty, +// mc.examples.map { case (name, value) => +// name -> OpenAPI.ReferenceOr.Or(OpenAPI.Example(codec.headerType.render(value).toJsonAST.toOption.get)) +// }, required = mc.required, ), ) } + .toSet def genDiscriminator(schema: Schema[_]): Option[OpenAPI.Discriminator] = { schema match { @@ -1155,7 +1090,8 @@ object OpenAPIGen { private def headersFrom(codec: AtomizedMetaCodecs) = { codec.header.map { case mc @ MetaCodec(codec, _) => - codec.headerType.name -> OpenAPI.ReferenceOr.Or( + // todo use all headers + codec.headerType.names.head -> OpenAPI.ReferenceOr.Or( OpenAPI.Header( description = mc.docsOpt, required = true, diff --git a/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala b/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala index 1a9841c32..2fb9b2c44 100644 --- a/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala +++ b/zio-http/shared/src/main/scala/zio/http/internal/HeaderGetters.scala @@ -67,7 +67,7 @@ trait HeaderGetters { self => /** Gets the raw unparsed header value */ final def rawHeader(name: CharSequence): Option[String] = headers.get(name) - final def rawHeaders(name: CharSequence): Chunk[String] = + def rawHeaders(name: CharSequence): Chunk[String] = Chunk.fromIterator( headers.iterator .filter(header => CharSequenceExtensions.equals(header.headerNameAsCharSequence, name, CaseMode.Insensitive)) diff --git a/zio-http/shared/src/main/scala/zio/http/internal/HeaderModifier.scala b/zio-http/shared/src/main/scala/zio/http/internal/HeaderModifier.scala index 255358d8c..629f497f0 100644 --- a/zio-http/shared/src/main/scala/zio/http/internal/HeaderModifier.scala +++ b/zio-http/shared/src/main/scala/zio/http/internal/HeaderModifier.scala @@ -39,6 +39,9 @@ trait HeaderModifier[+A] { self => final def addHeaders(headers: Headers): A = updateHeaders(_ ++ headers) + final def addHeaders(headers: Iterable[(CharSequence, CharSequence)]): A = + addHeaders(Headers.fromIterable(headers.map { case (k, v) => Header.Custom(k, v) })) + final def removeHeader(headerType: HeaderType): A = removeHeader(headerType.name) final def removeHeader(name: String): A = removeHeaders(Set(name)) diff --git a/zio-http/shared/src/main/scala/zio/http/internal/QueryModifier.scala b/zio-http/shared/src/main/scala/zio/http/internal/QueryModifier.scala index 77996c462..11506c571 100644 --- a/zio-http/shared/src/main/scala/zio/http/internal/QueryModifier.scala +++ b/zio-http/shared/src/main/scala/zio/http/internal/QueryModifier.scala @@ -45,6 +45,11 @@ trait QueryModifier[+A] { self: QueryOps[A] with A => def addQueryParams(values: String): A = updateQueryParams(params => params ++ QueryParams.decode(values)) + def addQueryParams(queryParams: Iterable[(String, String)]): A = + updateQueryParams(params => + params ++ QueryParams(queryParams.groupBy(_._1).view.mapValues(Chunk.fromIterable(_).map(_._2)).toMap), + ) + /** * Removes the specified key from the query parameters. */ diff --git a/zio-http/shared/src/main/scala/zio/http/internal/StringSchemaCodec.scala b/zio-http/shared/src/main/scala/zio/http/internal/StringSchemaCodec.scala new file mode 100644 index 000000000..9aaf31fc2 --- /dev/null +++ b/zio-http/shared/src/main/scala/zio/http/internal/StringSchemaCodec.scala @@ -0,0 +1,658 @@ +package zio.http.internal + +import java.time._ +import java.util.{Currency, UUID} + +import scala.annotation.tailrec +import scala.util.Try + +import zio.{Cause, Chunk, Unsafe} + +import zio.schema.codec.DecodeError +import zio.schema.validation.{Validation, ValidationError} +import zio.schema.{Schema, StandardType, TypeId} + +import zio.http.codec.HttpCodecError +import zio.http.internal.StringSchemaCodec.{PrimitiveCodec, decodeAndUnwrap, emptyStringIsValue, validateDecoded} + +private[http] final case class ErrorConstructor( + missing: String => HttpCodecError, + missingAll: Chunk[String] => HttpCodecError, + invalid: Chunk[ValidationError] => HttpCodecError, + malformed: (String, DecodeError) => HttpCodecError, + invalidCount: (String, Int, Int) => HttpCodecError, +) + +private[http] final case class StringSchemaCodec[A, Target]( + private[http] val schema: Schema[A], + private[http] val add: (Target, String, String) => Target, + private[http] val addAll: (Target, Iterable[(String, String)]) => Target, + private[http] val contains: (Target, String) => Boolean, + private[http] val unsafeGet: (Target, String) => String, + private[http] val getAll: (Target, String) => Chunk[String], + private[http] val count: (Target, String) => Int, + private[http] val error: ErrorConstructor, + private[http] val kebabCase: Boolean, + private[http] val defaultValue: A, + private[http] val isOptional: Boolean = false, + private[http] val isOptionalSchema: Boolean = false, +) { + private[http] val recordFields: Chunk[(Schema.Field[_, _], PrimitiveCodec[Any])] = { + val fields = schema match { + case record: Schema.Record[A] => + record.fields + case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Record[_]] => + s.schema.asInstanceOf[Schema.Record[A]].fields + case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Record[_]] => + s.schema.asInstanceOf[Schema.Record[A]].fields + case _ => Chunk.empty + } + fields.map(StringSchemaCodec.unlazyField).map { + case field if field.schema.isInstanceOf[Schema.Collection[_, _]] => + val elementSchema = field.schema.asInstanceOf[Schema.Collection[_, _]] match { + case s: Schema.NonEmptySequence[_, _, _] => s.elementSchema + case s: Schema.Sequence[_, _, _] => s.elementSchema + case s: Schema.Set[_] => s.elementSchema + case _: Schema.Map[_, _] => throw new IllegalArgumentException("Maps are not supported") + case _: Schema.NonEmptyMap[_, _] => throw new IllegalArgumentException("Maps are not supported") + } + val codec = PrimitiveCodec(elementSchema).asInstanceOf[PrimitiveCodec[Any]] + (StringSchemaCodec.mapFieldName(field, kebabCase), codec) + case field => + val codec = + PrimitiveCodec(field.annotations.foldLeft(field.schema)(_.annotate(_))).asInstanceOf[PrimitiveCodec[Any]] + (StringSchemaCodec.mapFieldName(field, kebabCase), codec) + } + } + + private[http] val recordSchema: Schema.Record[Any] = schema match { + case record: Schema.Record[_] => + record.asInstanceOf[Schema.Record[Any]] + case s: Schema.Optional[_] if s.schema.isInstanceOf[Schema.Record[_]] => + s.schema.asInstanceOf[Schema.Record[Any]] + case _ => null + } + + private def createAndValidateCollection(schema: Schema.Collection[_, _], decoded: Chunk[Any]) = { + val collection = schema.fromChunk.asInstanceOf[Chunk[Any] => Any](decoded) + val erasedSchema = schema.asInstanceOf[Schema[Any]] + val validationErrors = erasedSchema.validate(collection)(erasedSchema) + if (validationErrors.nonEmpty) throw error.invalid(validationErrors) + collection + } + + private[http] def decode(target: Target): A = { + val optional = isOptionalSchema + val hasDefault = defaultValue != null && isOptional + val default = defaultValue + val hasAllParams = recordFields.forall { case (field, codec) => + contains(target, field.fieldName) || field.optional || codec.isOptional + } + if (!hasAllParams && hasDefault) default + else if (!hasAllParams) { + throw error.missingAll { + recordFields.collect { + case (field, codec) if !(contains(target, field.fieldName) || field.optional || codec.isOptional) => + field.fieldName + } + } + } else { + val decoded = recordFields.map { + case (field, codec) if field.schema.isInstanceOf[Schema.Collection[_, _]] => + val schema = field.schema.asInstanceOf[Schema.Collection[_, _]] + if (!contains(target, field.fieldName)) { + if (field.defaultValue.isDefined) field.defaultValue.get + else throw error.missing(field.fieldName) + } else { + val values = getAll(target, field.fieldName) + val decoded = + values.map(decodeAndUnwrap(field, codec, _, error.malformed)) + createAndValidateCollection(schema, decoded) + + } + case (field, codec) => + val count0 = count(target, field.fieldName) + if (count0 > 1) throw error.invalidCount(field.fieldName, 1, count0) + val value = unsafeGet(target, field.fieldName) + val decoded = { + if (value == null || (value == "" && !emptyStringIsValue(codec.schema) && codec.isOptional)) + codec.defaultValue + else decodeAndUnwrap(field, codec, value, error.malformed.apply) + } + validateDecoded(codec, decoded, error) + } + if (optional) { + val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) + constructed match { + case Left(value) => + throw error.malformed( + s"${recordSchema.id}", + DecodeError.ReadError(Cause.empty, value), + ) + case Right(value) => + recordSchema.validate(value)(recordSchema) match { + case errors if errors.nonEmpty => throw error.invalid(errors) + case _ => Some(value).asInstanceOf[A] + } + } + } else { + val constructed = recordSchema.construct(decoded)(Unsafe.unsafe) + constructed match { + case Left(value) => + throw error.malformed( + s"${recordSchema.id}", + DecodeError.ReadError(Cause.empty, value), + ) + case Right(value) => + recordSchema.validate(value)(recordSchema) match { + case errors if errors.nonEmpty => throw error.invalid(errors) + case _ => value.asInstanceOf[A] + } + } + } + } + + } + + private[http] def encode(input: A, target: Target): Target = { + val fields = recordFields + val value = input.asInstanceOf[Any] match { + case None => null + case it: Iterable[_] if it.isEmpty => null + case Some(value) => value + case value => value + } + if (value == null) target + else { + val fieldValues = recordSchema.deconstruct(value)(Unsafe.unsafe) + var target0 = target + val fieldIt = fields.iterator + val fieldValuesIt = fieldValues.iterator + while (fieldIt.hasNext) { + val (field, codec) = fieldIt.next() + val name = field.fieldName + val value = fieldValuesIt.next() match { + case Some(value) => value + case None => field.defaultValue + } + value match { + case values: Iterable[_] => + target0 = addAll(target0, values.map { v => (name, codec.encode(v)) }) + case _ => + val encoded = codec.encode(value) + target0 = add(target0, name, encoded) + } + } + target0 + } + } + + private[http] def optional: StringSchemaCodec[Option[A], Target] = + StringSchemaCodec.fromSchema( + schema.optional, + add, + addAll, + contains, + unsafeGet, + getAll, + count, + error, + kebabCase, + null, + ) +} + +object StringSchemaCodec { + private[http] def unlazyField(field: Schema.Field[_, _]): Schema.Field[_, _] = field match { + case f if f.schema.isInstanceOf[Schema.Lazy[_]] => + Schema.Field( + f.name, + f.schema.asInstanceOf[Schema.Lazy[_]].schema.asInstanceOf[Schema[Any]], + f.annotations, + f.validation.asInstanceOf[Validation[Any]], + f.get.asInstanceOf[Any => Any], + f.set.asInstanceOf[(Any, Any) => Any], + ) + case f => f + } + private[http] def defaultValue[A](schema: Schema[A]): A = + if (schema.isInstanceOf[Schema.Collection[_, _]]) { + Try(schema.asInstanceOf[Schema.Collection[A, _]].empty).fold( + _ => null.asInstanceOf[A], + identity, + ) + } else { + schema.defaultValue match { + case Right(value) => value + case Left(_) => + schema match { + case _: Schema.Optional[_] => None.asInstanceOf[A] + case collection: Schema.Collection[A, _] => + Try(collection.empty).fold( + _ => null.asInstanceOf[A], + identity, + ) + case _ => null.asInstanceOf[A] + } + } + } + + private[http] def isOptional(schema: Schema[_]): Boolean = schema match { + case _: Schema.Optional[_] => + true + case record: Schema.Record[_] => + record.fields.forall(_.optional) || record.defaultValue.isRight + case d: Schema.Collection[_, _] => + val bool = Try(d.empty).isSuccess || d.defaultValue.isRight + bool + case _ => + false + } + + private[http] def isOptionalSchema(schema: Schema[_]): Boolean = + schema match { + case _: Schema.Optional[_] => true + case s: Schema.Transform[_, _, _] if s.schema.isInstanceOf[Schema.Optional[_]] => true + case _ => false + } + + private[http] final case class PrimitiveCodec[A]( + private[http] val schema: Schema[A], + ) { + + val defaultValue: A = + StringSchemaCodec.defaultValue(schema) + + private[http] val isOptional: Boolean = + StringSchemaCodec.isOptional(schema) + + private[http] val isOptionalSchema: Boolean = + StringSchemaCodec.isOptionalSchema(schema) + + private[http] val encode: A => String = + PrimitiveCodec.primitiveSchemaEncoder(schema) + + private[http] def optional = copy(schema.optional) + + private[http] val decode: String => A = + PrimitiveCodec.primitiveSchemaDecoder(schema) + + private[http] def validate(value: A) = + schema.validate(value)(schema) + + } + + object PrimitiveCodec { + + private[http] def primitiveSchemaDecoder[A](schema: Schema[A]): String => A = schema match { + case Schema.Optional(schema, _) => + primitiveSchemaDecoder(schema).andThen(Some(_)).asInstanceOf[String => A] + case Schema.Transform(schema, f, _, _, _) => + primitiveSchemaDecoder(schema).andThen { + f(_) match { + case Left(value) => throw new IllegalArgumentException(value) + case Right(value) => value + } + }.asInstanceOf[String => A] + case Schema.Primitive(standardType, _) => + parsePrimitive(standardType.asInstanceOf[StandardType[Any]]).asInstanceOf[String => A] + case Schema.Lazy(schema0) => + primitiveSchemaDecoder(schema0()).asInstanceOf[String => A] + case _ => throw new IllegalArgumentException(s"Unsupported schema $schema") + } + + private[http] def primitiveSchemaEncoder[A](schema: Schema[A]): A => String = schema match { + case Schema.Optional(schema, _) => + val innerEncoder: Any => String = primitiveSchemaEncoder(schema.asInstanceOf[Schema[Any]]) + (a: A) => if (a.isInstanceOf[None.type]) null else innerEncoder(a.asInstanceOf[Some[Any]].get) + case Schema.Transform(schema, f, _, _, _) => + val innerEncoder: Any => String = primitiveSchemaEncoder(schema.asInstanceOf[Schema[Any]]) + (a: A) => + f.asInstanceOf[Any => Either[String, Any]](a.asInstanceOf[Any]) match { + case Left(value) => throw new IllegalArgumentException(value) + case Right(value) => innerEncoder(value) + } + case Schema.Lazy(schema0) => + primitiveSchemaEncoder(schema0()).asInstanceOf[A => String] + case Schema.Primitive(_, _) => + (a: A) => a.toString + case _ => + throw new IllegalArgumentException(s"Unsupported schema $schema") + } + } + + private def decodeAndUnwrap( + field: Schema.Field[_, _], + codec: PrimitiveCodec[Any], + value: String, + ex: (String, DecodeError) => HttpCodecError, + ) = + try codec.decode(value) + catch { + case err: DecodeError => throw ex(field.fieldName, err) + } + + private def validateDecoded(codec: PrimitiveCodec[Any], decoded: Any, error: ErrorConstructor) = { + val validationErrors = codec.schema.validate(decoded)(codec.schema) + if (validationErrors.nonEmpty) throw error.invalid(validationErrors) + decoded + } + + @tailrec + private def emptyStringIsValue(schema: Schema[_]): Boolean = { + schema match { + case value: Schema.Optional[_] => + val innerSchema = value.schema + emptyStringIsValue(innerSchema) + case _ => + schema.asInstanceOf[Schema.Primitive[_]].standardType match { + case StandardType.UnitType => true + case StandardType.StringType => true + case StandardType.BinaryType => true + case StandardType.CharType => true + case _ => false + } + } + } + private[http] def mapFieldName(field: Schema.Field[_, _], kebabCase: Boolean): Schema.Field[_, _] = { + Schema.Field( + if (!kebabCase) field.fieldName else camelToKebab(field.fieldName), + field.annotations.foldLeft(field.schema)(_ annotate _).asInstanceOf[Schema[Any]], + field.annotations, + field.validation.asInstanceOf[Validation[Any]], + field.get.asInstanceOf[Any => Any], + field.set.asInstanceOf[(Any, Any) => Any], + ) + } + + @tailrec + private[http] def fromSchema[A, Target]( + schema: Schema[A], + add: (Target, String, String) => Target, + addAll: (Target, Iterable[(String, String)]) => Target, + contains: (Target, String) => Boolean, + unsafeGet: (Target, String) => String, + unsafeGetAll: (Target, String) => Chunk[String], + count: (Target, String) => Int, + error: ErrorConstructor, + isKebabCase: Boolean, + name: String, + ): StringSchemaCodec[A, Target] = { + val defaultValue = StringSchemaCodec.defaultValue(schema) + val isOptional = StringSchemaCodec.isOptional(schema) + val isOptionalSchema = StringSchemaCodec.isOptionalSchema(schema) + + def stringSchemaCodec(schema: Schema[Any]): StringSchemaCodec[A, Target] = + StringSchemaCodec( + schema.asInstanceOf[Schema[A]], + add, + addAll, + contains, + unsafeGet, + unsafeGetAll, + count, + error, + isKebabCase, + defaultValue, + isOptional, + isOptionalSchema, + ) + schema match { + case s @ Schema.Primitive(_, _) => + stringSchemaCodec(recordSchema(s.asInstanceOf[Schema[Any]], name)) + case s @ Schema.Optional(schema, _) => + schema match { + case _: Schema.Collection[_, _] | _: Schema.Primitive[_] => + stringSchemaCodec(recordSchema(s.asInstanceOf[Schema[Any]], name)) + case s if s.isInstanceOf[Schema.Record[_]] => stringSchemaCodec(schema) + case _ => throw new IllegalArgumentException(s"Unsupported schema $s") + } + case s @ Schema.Transform(schema, _, _, _, _) => + schema match { + case _: Schema.Collection[_, _] | _: Schema.Primitive[_] => + stringSchemaCodec(recordSchema(s.asInstanceOf[Schema[Any]], name)) + case _: Schema.Record[_] => stringSchemaCodec(s.asInstanceOf[Schema[Any]]) + case _ => throw new IllegalArgumentException(s"Unsupported schema $s") + } + case Schema.Lazy(schema0) => + fromSchema( + schema0().asInstanceOf[Schema[A]], + add, + addAll, + contains, + unsafeGet, + unsafeGetAll, + count, + error, + isKebabCase, + name, + ) + case _: Schema.Collection[_, _] => + stringSchemaCodec(recordSchema(schema.asInstanceOf[Schema[Any]], name)) + case s: Schema.Record[_] => + stringSchemaCodec(s.asInstanceOf[Schema[Any]]) + case _ => + throw new IllegalArgumentException(s"Unsupported schema $schema") + + } + } + + private def recordSchema[A](s: Schema[A], name: String): Schema[A] = Schema.CaseClass1[A, A]( + TypeId.Structural, + Schema.Field(name, s, Chunk.empty, Validation.succeed, identity, (_, v) => v), + identity, + ) + + private def parsePrimitive(standardType: StandardType[_]): String => Any = + standardType match { + case StandardType.UnitType => + val result = "" + (_: String) => result + case StandardType.StringType => + (s: String) => s + case StandardType.BoolType => + (s: String) => + s.toLowerCase match { + case "true" | "on" | "yes" | "1" => true + case "false" | "off" | "no" | "0" => false + case _ => throw DecodeError.ReadError(Cause.fail(new Exception("Invalid boolean value")), s) + } + case StandardType.ByteType => + (s: String) => + try { + s.toByte + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.ShortType => + (s: String) => + try { + s.toShort + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.IntType => + (s: String) => + try { + s.toInt + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.LongType => + (s: String) => + try { + s.toLong + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.FloatType => + (s: String) => + try { + s.toFloat + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.DoubleType => + (s: String) => + try { + s.toDouble + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.BinaryType => + val result = DecodeError.UnsupportedSchema(Schema.Primitive(standardType), "TextCodec") + (_: String) => throw result + case StandardType.CharType => + (s: String) => s.charAt(0) + case StandardType.UUIDType => + (s: String) => + try { + UUID.fromString(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.BigDecimalType => + (s: String) => + try { + BigDecimal(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.BigIntegerType => + (s: String) => + try { + BigInt(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.DayOfWeekType => + (s: String) => + try { + DayOfWeek.valueOf(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.MonthType => + (s: String) => + try { + Month.valueOf(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.MonthDayType => + (s: String) => + try { + MonthDay.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.PeriodType => + (s: String) => + try { + Period.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.YearType => + (s: String) => + try { + Year.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.YearMonthType => + (s: String) => + try { + YearMonth.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.ZoneIdType => + (s: String) => + try { + ZoneId.of(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.ZoneOffsetType => + (s: String) => + try { + ZoneOffset.of(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.DurationType => + (s: String) => + try { + java.time.Duration.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.InstantType => + (s: String) => + try { + Instant.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.LocalDateType => + (s: String) => + try { + LocalDate.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.LocalTimeType => + (s: String) => + try { + LocalTime.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.LocalDateTimeType => + (s: String) => + try { + LocalDateTime.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.OffsetTimeType => + (s: String) => + try { + OffsetTime.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.OffsetDateTimeType => + (s: String) => + try { + OffsetDateTime.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.ZonedDateTimeType => + (s: String) => + try { + ZonedDateTime.parse(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + case StandardType.CurrencyType => + (s: String) => + try { + Currency.getInstance(s) + } catch { + case e: Exception => throw DecodeError.ReadError(Cause.fail(e), e.getMessage) + } + } + + private def camelToKebab(s: String): String = + if (s.isEmpty) "" + else if (s.head.isUpper) s.head.toLower.toString + camelToKebab(s.tail) + else if (s.contains('-')) s + else + s.foldLeft("") { (acc, c) => + if (c.isUpper) acc + "-" + c.toLower + else acc + c + } +}