Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

codegen: Support enums in paths #3889

Merged
merged 4 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ object BasicGenerator {
JsonSerdeLib.Circe
}

val EndpointDefs(endpointsByTag, queryParamRefs, jsonParamRefs, enumsDefinedOnEndpointParams) =
val EndpointDefs(endpointsByTag, queryOrPathParamRefs, jsonParamRefs, enumsDefinedOnEndpointParams) =
endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames, targetScala3, normalisedJsonLib)
val GeneratedClassDefinitions(classDefns, jsonSerdes, schemas) =
classGenerator
.classDefs(
doc = doc,
targetScala3 = targetScala3,
queryParamRefs = queryParamRefs,
queryOrPathParamRefs = queryOrPathParamRefs,
jsonSerdeLib = normalisedJsonLib,
jsonParamRefs = jsonParamRefs,
fullModelPath = s"$packagePath.$objName",
Expand Down Expand Up @@ -146,15 +146,18 @@ object BasicGenerator {
"""
|case class CommaSeparatedValues[T](values: List[T])
|case class ExplodedValues[T](values: List[T])
|trait QueryParamSupport[T] {
|trait ExtraParamSupport[T] {
| def decode(s: String): sttp.tapir.DecodeResult[T]
| def encode(t: T): String
|}
|implicit def makeQueryCodecFromSupport[T](implicit support: QueryParamSupport[T]): sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] = {
|implicit def makePathCodecFromSupport[T](implicit support: ExtraParamSupport[T]): sttp.tapir.Codec[String, T, sttp.tapir.CodecFormat.TextPlain] = {
| sttp.tapir.Codec.string.mapDecode(support.decode)(support.encode)
|}
|implicit def makeQueryCodecFromSupport[T](implicit support: ExtraParamSupport[T]): sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] = {
| sttp.tapir.Codec.listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(support.decode)(support.encode)
|}
|implicit def makeQueryOptCodecFromSupport[T](implicit support: QueryParamSupport[T]): sttp.tapir.Codec[List[String], Option[T], sttp.tapir.CodecFormat.TextPlain] = {
|implicit def makeQueryOptCodecFromSupport[T](implicit support: ExtraParamSupport[T]): sttp.tapir.Codec[List[String], Option[T], sttp.tapir.CodecFormat.TextPlain] = {
| sttp.tapir.Codec.listHeadOption[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(maybeV => DecodeResult.sequence(maybeV.toSeq.map(support.decode)).map(_.headOption))(_.map(support.encode))
|}
Expand All @@ -169,10 +172,6 @@ object BasicGenerator {
| case Some(values) => DecodeResult.sequence(values.split(',').toSeq.map(e => support.rawDecode(List(e)))).map(r => Some(CommaSeparatedValues(r.toList)))
| }(_.map(_.values.map(support.encode).mkString(",")))
|}
|implicit def makeExplodedQuerySeqCodecFromSupport[T](implicit support: QueryParamSupport[T]): sttp.tapir.Codec[List[String], ExplodedValues[T], sttp.tapir.CodecFormat.TextPlain] = {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we now provide a Codec[String, T, sttp.tapir.CodecFormat.TextPlain], from which we can derive a Codec[List[String], List[T], sttp.tapir.CodecFormat.TextPlain], the final implicit derivation is provided by makeExplodedQuerySeqCodecFromListSeq and this is now redundant (it also causes explicit resolution conflicts in scala 3 if left in, although scala 2 seems happy enough with the ambiguity for some reason).

| sttp.tapir.Codec.list[String, String, sttp.tapir.CodecFormat.TextPlain]
| .mapDecode(values => DecodeResult.sequence(values.map(support.decode)).map(s => ExplodedValues(s.toList)))(_.values.map(support.encode))
|}
|implicit def makeExplodedQuerySeqCodecFromListSeq[T](implicit support: sttp.tapir.Codec[List[String], List[T], sttp.tapir.CodecFormat.TextPlain]): sttp.tapir.Codec[List[String], ExplodedValues[T], sttp.tapir.CodecFormat.TextPlain] = {
| support.mapDecode(l => DecodeResult.Value(ExplodedValues(l)))(_.values)
|}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ClassDefinitionGenerator {
def classDefs(
doc: OpenapiDocument,
targetScala3: Boolean = false,
queryParamRefs: Set[String] = Set.empty,
queryOrPathParamRefs: Set[String] = Set.empty,
jsonSerdeLib: JsonSerdeLib.JsonSerdeLib = JsonSerdeLib.Circe,
jsonParamRefs: Set[String] = Set.empty,
fullModelPath: String = "",
Expand All @@ -25,10 +25,10 @@ class ClassDefinitionGenerator {
val allSchemas: Map[String, OpenapiSchemaType] = doc.components.toSeq.flatMap(_.schemas).toMap
val allOneOfSchemas = allSchemas.collect { case (name, oneOf: OpenapiSchemaOneOf) => name -> oneOf }.toSeq
val adtInheritanceMap: Map[String, Seq[String]] = mkMapParentsByChild(allOneOfSchemas)
val generatesQueryParamEnums = enumsDefinedOnEndpointParams ||
val generatesQueryOrPathParamEnums = enumsDefinedOnEndpointParams ||
allSchemas
.collect { case (name, _: OpenapiSchemaEnum) => name }
.exists(queryParamRefs.contains)
.exists(queryOrPathParamRefs.contains)

def fetchJsonParamRefs(initialSet: Set[String], toCheck: Seq[OpenapiSchemaType]): Set[String] = toCheck match {
case Nil => initialSet
Expand All @@ -41,7 +41,7 @@ class ClassDefinitionGenerator {
)

val adtTypes = adtInheritanceMap.flatMap(_._2).toSeq.distinct.map(name => s"sealed trait $name").mkString("", "\n", "\n")
val enumQuerySerdeHelper = if (!generatesQueryParamEnums) "" else enumQuerySerdeHelperDefn(targetScala3)
val enumSerdeHelper = if (!generatesQueryOrPathParamEnums) "" else enumSerdeHelperDefn(targetScala3)
val schemas = SchemaGenerator.generateSchemas(doc, allSchemas, fullModelPath, jsonSerdeLib, maxSchemasPerFile)
val jsonSerdes = JsonSerdeGenerator.serdeDefs(
doc,
Expand All @@ -58,13 +58,13 @@ class ClassDefinitionGenerator {
case (name, obj: OpenapiSchemaObject) =>
generateClass(allSchemas, name, obj, allTransitiveJsonParamRefs, adtInheritanceMap, jsonSerdeLib, targetScala3)
case (name, obj: OpenapiSchemaEnum) =>
EnumGenerator.generateEnum(name, obj, targetScala3, queryParamRefs, jsonSerdeLib, allTransitiveJsonParamRefs)
EnumGenerator.generateEnum(name, obj, targetScala3, queryOrPathParamRefs, jsonSerdeLib, allTransitiveJsonParamRefs)
case (name, OpenapiSchemaMap(valueSchema, _)) => generateMap(name, valueSchema)
case (_, _: OpenapiSchemaOneOf) => Nil
case (n, x) => throw new NotImplementedError(s"Only objects, enums and maps supported! (for $n found ${x})")
})
.map(_.mkString("\n"))
val helpers = (enumQuerySerdeHelper + adtTypes).linesIterator
val helpers = (enumSerdeHelper + adtTypes).linesIterator
.filterNot(_.forall(_.isWhitespace))
.mkString("\n")
// Json serdes & schemas live in separate files from the class defns
Expand Down Expand Up @@ -97,14 +97,14 @@ class ClassDefinitionGenerator {
.groupBy(_._1)
.mapValues(_.map(_._2))

private def enumQuerySerdeHelperDefn(targetScala3: Boolean): String = {
private def enumSerdeHelperDefn(targetScala3: Boolean): String = {
if (targetScala3)
"""
|def enumMap[E: enumextensions.EnumMirror]: Map[String, E] =
| Map.from(
| for e <- enumextensions.EnumMirror[E].values yield e.name.toUpperCase -> e
| )
|case class EnumQueryParamSupport[T: enumextensions.EnumMirror](eMap: Map[String, T]) extends QueryParamSupport[T] {
|case class EnumExtraParamSupport[T: enumextensions.EnumMirror](eMap: Map[String, T]) extends ExtraParamSupport[T] {
| // Case-insensitive mapping
| def decode(s: String): sttp.tapir.DecodeResult[T] =
| scala.util
Expand All @@ -121,12 +121,12 @@ class ClassDefinitionGenerator {
| )
| def encode(t: T): String = t.name
|}
|def queryCodecSupport[T: enumextensions.EnumMirror]: QueryParamSupport[T] =
| EnumQueryParamSupport(enumMap[T](using enumextensions.EnumMirror[T]))
|def extraCodecSupport[T: enumextensions.EnumMirror]: ExtraParamSupport[T] =
| EnumExtraParamSupport(enumMap[T](using enumextensions.EnumMirror[T]))
|""".stripMargin
else
"""
|case class EnumQueryParamSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]) extends QueryParamSupport[T] {
|case class EnumExtraParamSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]) extends ExtraParamSupport[T] {
| // Case-insensitive mapping
| def decode(s: String): sttp.tapir.DecodeResult[T] =
| scala.util.Try(T.upperCaseNameValuesToMap(s.toUpperCase))
Expand All @@ -142,8 +142,8 @@ class ClassDefinitionGenerator {
| )
| def encode(t: T): String = t.entryName
|}
|def queryCodecSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]): QueryParamSupport[T] =
| EnumQueryParamSupport(enumName, T)
|def extraCodecSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]): ExtraParamSupport[T] =
| EnumExtraParamSupport(enumName, T)
|""".stripMargin
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ case class GeneratedEndpoints(
}
case class EndpointDefs(
endpointDecls: Map[Option[String], String],
queryParamRefs: Set[String],
queryOrPathParamRefs: Set[String],
jsonParamRefs: Set[String],
enumsDefinedOnEndpointParams: Boolean
)
Expand All @@ -58,7 +58,7 @@ class EndpointGenerator {
jsonSerdeLib: JsonSerdeLib
): EndpointDefs = {
val components = Option(doc.components).flatten
val GeneratedEndpoints(endpointsByFile, queryParamRefs, jsonParamRefs, definesEnumQueryParam) =
val GeneratedEndpoints(endpointsByFile, queryOrPathParamRefs, jsonParamRefs, definesEnumQueryParam) =
doc.paths
.map(generatedEndpoints(components, useHeadTagForObjectNames, targetScala3, jsonSerdeLib))
.foldLeft(GeneratedEndpoints(Nil, Set.empty, Set.empty, false))(_ merge _)
Expand All @@ -77,7 +77,7 @@ class EndpointGenerator {
|$allEP
|""".stripMargin
}.toMap
EndpointDefs(endpointDecls, queryParamRefs, jsonParamRefs, definesEnumQueryParam)
EndpointDefs(endpointDecls, queryOrPathParamRefs, jsonParamRefs, definesEnumQueryParam)
}

private[codegen] def generatedEndpoints(
Expand Down Expand Up @@ -119,8 +119,8 @@ class EndpointGenerator {
|""".stripMargin.linesIterator.filterNot(_.trim.isEmpty).mkString("\n")

val maybeTargetFileName = if (useHeadTagForObjectNames) m.tags.flatMap(_.headOption) else None
val queryParamRefs = m.resolvedParameters
.collect { case queryParam: OpenapiParameter if queryParam.in == "query" => queryParam.schema }
val queryOrPathParamRefs = m.resolvedParameters
.collect { case queryParam: OpenapiParameter if queryParam.in == "query" || queryParam.in == "path" => queryParam.schema }
.collect { case ref: OpenapiSchemaRef if ref.isSchema => ref.stripped }
.toSet
val jsonParamRefs = (m.requestBody.toSeq.flatMap(_.content.map(c => (c.contentType, c.schema))) ++
Expand All @@ -143,7 +143,7 @@ class EndpointGenerator {
.toSet
(
(maybeTargetFileName, GeneratedEndpoint(name, definition, maybeLocalEnums)),
(queryParamRefs, jsonParamRefs),
(queryOrPathParamRefs, jsonParamRefs),
maybeLocalEnums.isDefined
)
}
Expand Down Expand Up @@ -215,12 +215,12 @@ class EndpointGenerator {
)(implicit location: Location): (String, Option[String]) = {
def getEnumParamDefn(param: OpenapiParameter, e: OpenapiSchemaEnum, isArray: Boolean) = {
val enumName = endpointName.capitalize + strippedToCamelCase(param.name).capitalize
val queryParamRefs = if (param.in == "query") Set(enumName) else Set.empty[String]
val enumParamRefs = if (param.in == "query" || param.in == "path") Set(enumName) else Set.empty[String]
val enumDefn = EnumGenerator.generateEnum(
enumName,
e,
targetScala3,
queryParamRefs,
enumParamRefs,
jsonSerdeLib,
Set.empty
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ object EnumGenerator {
val maybeCompanion =
if (queryParamRefs contains name) {
def helperImpls =
s""" given enumCodecSupport${name.capitalize}: QueryParamSupport[$name] =
| queryCodecSupport[$name]""".stripMargin
s""" given enumCodecSupport${name.capitalize}: ExtraParamSupport[$name] =
| extraCodecSupport[$name]""".stripMargin
s"""
|object $name {
|$helperImpls
Expand Down Expand Up @@ -52,8 +52,8 @@ object EnumGenerator {
val maybeQueryCodecDefn =
if (queryParamRefs contains name) {
s"""
| implicit val enumCodecSupport${name.capitalize}: QueryParamSupport[$name] =
| queryCodecSupport[$name]("${name}", ${name})""".stripMargin
| implicit val enumCodecSupport${name.capitalize}: ExtraParamSupport[$name] =
| extraCodecSupport[$name]("${name}", ${name})""".stripMargin
} else ""
s"""
|sealed trait $name extends enumeratum.EnumEntry
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {
.classDefs(doc, true, jsonParamRefs = Set("Test"))
.map(concatted)
val resWithQueryParamCodec = gen
.classDefs(doc, true, queryParamRefs = Set("Test"), jsonParamRefs = Set("Test"))
.classDefs(doc, true, queryOrPathParamRefs = Set("Test"), jsonParamRefs = Set("Test"))
.map(concatted)
// can't just check whether these compile, because our tests only run on scala 2.12 - so instead just eyeball it...
res shouldBe Some("""enum Test derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec {
Expand All @@ -304,7 +304,7 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {
| Map.from(
| for e <- enumextensions.EnumMirror[E].values yield e.name.toUpperCase -> e
| )
|case class EnumQueryParamSupport[T: enumextensions.EnumMirror](eMap: Map[String, T]) extends QueryParamSupport[T] {
|case class EnumExtraParamSupport[T: enumextensions.EnumMirror](eMap: Map[String, T]) extends ExtraParamSupport[T] {
| // Case-insensitive mapping
| def decode(s: String): sttp.tapir.DecodeResult[T] =
| scala.util
Expand All @@ -321,11 +321,11 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {
| )
| def encode(t: T): String = t.name
|}
|def queryCodecSupport[T: enumextensions.EnumMirror]: QueryParamSupport[T] =
| EnumQueryParamSupport(enumMap[T](using enumextensions.EnumMirror[T]))
|def extraCodecSupport[T: enumextensions.EnumMirror]: ExtraParamSupport[T] =
| EnumExtraParamSupport(enumMap[T](using enumextensions.EnumMirror[T]))
|object Test {
| given enumCodecSupportTest: QueryParamSupport[Test] =
| queryCodecSupport[Test]
| given enumCodecSupportTest: ExtraParamSupport[Test] =
| extraCodecSupport[Test]
|}
|enum Test derives org.latestbit.circe.adt.codec.JsonTaggedAdt.PureCodec, enumextensions.EnumMirror {
| case enum1, enum2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,15 @@ object TestHelpers {
| $ref: '#/components/schemas/Test'
| post:
| responses: {}
| /pathTest/{test2}:
| parameters:
| - name: test2
| in: path
| required: true
| schema:
| $ref: '#/components/schemas/Test2'
| post:
| responses: {}
|
|components:
| schemas:
Expand All @@ -535,6 +544,12 @@ object TestHelpers {
| enum:
| - paperback
| - hardback
| Test2:
| title: Test
| type: string
| enum:
| - paperback
| - hardback
|""".stripMargin

val enumQueryParamDocs = OpenapiDocument(
Expand All @@ -555,7 +570,24 @@ object TestHelpers {
)
),
parameters = Seq(
Resolved(OpenapiParameter("test", "query", None, None, OpenapiSchemaRef("#/components/schemas/Test")))
Resolved(OpenapiParameter("test", "query", Some(false), None, OpenapiSchemaRef("#/components/schemas/Test")))
)
),
OpenapiPath(
"/pathTest/{test2}",
Seq(
OpenapiPathMethod(
methodType = "post",
parameters = Seq(),
responses = Seq(),
requestBody = None,
summary = None,
tags = None,
operationId = None
)
),
parameters = Seq(
Resolved(OpenapiParameter("test2", "path", Some(true), None, OpenapiSchemaRef("#/components/schemas/Test2")))
)
)
),
Expand All @@ -566,6 +598,11 @@ object TestHelpers {
"string",
Seq(OpenapiSchemaConstantString("paperback"), OpenapiSchemaConstantString("hardback")),
false
),
"Test2" -> OpenapiSchemaEnum(
"string",
Seq(OpenapiSchemaConstantString("paperback"), OpenapiSchemaConstantString("hardback")),
false
)
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,12 @@ class ModelParserSpec extends AnyFlatSpec with Matchers with Checkers {
res shouldBe Right(
OpenapiSchemaEnum("string", Seq(OpenapiSchemaConstantString("paperback"), OpenapiSchemaConstantString("hardback")), false)
)
parser
.parse(TestHelpers.enumQueryParamYaml)
.leftMap(err => err: Error)
.flatMap(_.as[OpenapiDocument]) shouldBe Right(
TestHelpers.enumQueryParamDocs
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test was previously missing

}

it should "parse endpoint with defaults" in {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@ object TapirGeneratedEndpoints {

case class CommaSeparatedValues[T](values: List[T])
case class ExplodedValues[T](values: List[T])
trait QueryParamSupport[T] {
trait ExtraParamSupport[T] {
def decode(s: String): sttp.tapir.DecodeResult[T]
def encode(t: T): String
}
implicit def makeQueryCodecFromSupport[T](implicit support: QueryParamSupport[T]): sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] = {
implicit def makePathCodecFromSupport[T](implicit support: ExtraParamSupport[T]): sttp.tapir.Codec[String, T, sttp.tapir.CodecFormat.TextPlain] = {
sttp.tapir.Codec.string.mapDecode(support.decode)(support.encode)
}
implicit def makeQueryCodecFromSupport[T](implicit support: ExtraParamSupport[T]): sttp.tapir.Codec[List[String], T, sttp.tapir.CodecFormat.TextPlain] = {
sttp.tapir.Codec.listHead[String, String, sttp.tapir.CodecFormat.TextPlain]
.mapDecode(support.decode)(support.encode)
}
implicit def makeQueryOptCodecFromSupport[T](implicit support: QueryParamSupport[T]): sttp.tapir.Codec[List[String], Option[T], sttp.tapir.CodecFormat.TextPlain] = {
implicit def makeQueryOptCodecFromSupport[T](implicit support: ExtraParamSupport[T]): sttp.tapir.Codec[List[String], Option[T], sttp.tapir.CodecFormat.TextPlain] = {
sttp.tapir.Codec.listHeadOption[String, String, sttp.tapir.CodecFormat.TextPlain]
.mapDecode(maybeV => DecodeResult.sequence(maybeV.toSeq.map(support.decode)).map(_.headOption))(_.map(support.encode))
}
Expand All @@ -38,10 +41,6 @@ object TapirGeneratedEndpoints {
case Some(values) => DecodeResult.sequence(values.split(',').toSeq.map(e => support.rawDecode(List(e)))).map(r => Some(CommaSeparatedValues(r.toList)))
}(_.map(_.values.map(support.encode).mkString(",")))
}
implicit def makeExplodedQuerySeqCodecFromSupport[T](implicit support: QueryParamSupport[T]): sttp.tapir.Codec[List[String], ExplodedValues[T], sttp.tapir.CodecFormat.TextPlain] = {
sttp.tapir.Codec.list[String, String, sttp.tapir.CodecFormat.TextPlain]
.mapDecode(values => DecodeResult.sequence(values.map(support.decode)).map(s => ExplodedValues(s.toList)))(_.values.map(support.encode))
}
implicit def makeExplodedQuerySeqCodecFromListSeq[T](implicit support: sttp.tapir.Codec[List[String], List[T], sttp.tapir.CodecFormat.TextPlain]): sttp.tapir.Codec[List[String], ExplodedValues[T], sttp.tapir.CodecFormat.TextPlain] = {
support.mapDecode(l => DecodeResult.Value(ExplodedValues(l)))(_.values)
}
Expand Down
Loading
Loading