Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get with narrow for sum types #494

Merged
merged 22 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ object TypeSafeApiMappingSpec extends DynamoDBLocalSpec {
_ <- put[InvoiceWithDiscriminatorName](invoiceTable, InvoiceWithDiscriminatorName.Unpaid("1")).execute
invoice <- // invoice is of type InvoiceWithDiscriminatorName
get(invoiceTable)(
((InvoiceWithDiscriminatorName.unpaid) >>> (InvoiceWithDiscriminatorName.Unpaid.id)).partitionKey === "1"
(InvoiceWithDiscriminatorName.unpaid >>> InvoiceWithDiscriminatorName.Unpaid.id).partitionKey === "1"
).execute.absolve
} yield assertTrue(invoice == InvoiceWithDiscriminatorName.Unpaid("1"))
}
Expand All @@ -81,7 +81,7 @@ object TypeSafeApiMappingSpec extends DynamoDBLocalSpec {
for {
_ <- put[InvoiceWithDiscriminatorName](invoiceTable, InvoiceWithDiscriminatorName.Unpaid("1")).execute
invoice <- get(invoiceTable)( // invoice is of type InvoiceWithDiscriminatorName.Unpaid
(InvoiceWithDiscriminatorName.Unpaid.id).partitionKey === "1"
InvoiceWithDiscriminatorName.Unpaid.id.partitionKey === "1"
).execute.absolve
} yield assertTrue(invoice == InvoiceWithDiscriminatorName.Unpaid("1"))
}
Expand Down
132 changes: 132 additions & 0 deletions dynamodb/src/it/scala/zio/dynamodb/TypeSafeApiNarrowSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package zio.dynamodb

import zio.dynamodb.DynamoDBQuery.{ getItem, put }
import zio.Scope
import zio.test.Spec
import zio.test.Assertion.fails
import zio.test.{ assert, assertTrue }
import zio.test.TestEnvironment
import zio.test.Assertion.{ equalTo, isLeft, isRight }
import zio.test.TestAspect
import zio.schema.Schema
import zio.schema.DeriveSchema
import zio.schema.annotation.discriminatorName
import zio.dynamodb.DynamoDBQuery.getWithNarrow
import zio.dynamodb.DynamoDBError.ItemError

object TypeSafeApiNarrowSpec extends DynamoDBLocalSpec {

object dynamo {
@discriminatorName("invoiceType")
sealed trait Invoice {
def id: String
}
object Invoice {
final case class Unrelated(id: Int)
object Unrelated {
implicit val schema: Schema.CaseClass1[Int, Unrelated] = DeriveSchema.gen[Unrelated]
val id = ProjectionExpression.accessors[Unrelated]
}
final case class Unpaid(id: String) extends Invoice
object Unpaid {
implicit val schema: Schema.CaseClass1[String, Unpaid] = DeriveSchema.gen[Unpaid]
val id = ProjectionExpression.accessors[Unpaid]
}
final case class Paid(id: String, amount: Int) extends Invoice
object Paid {
implicit val schema: Schema.CaseClass2[String, Int, Paid] = DeriveSchema.gen[Paid]
val (id, amount) = ProjectionExpression.accessors[Paid]
}
implicit val schema: Schema.Enum2[Unpaid, Paid, Invoice] =
DeriveSchema.gen[Invoice]
val (unpaid, paid) = ProjectionExpression.accessors[Invoice]
}

}

override def spec: Spec[Environment with TestEnvironment with Scope, Any] =
suite("TypeSafeApiNarrowSpec")(
topLevelSumTypeNarrowSuite,
narrowSuite
) @@ TestAspect.nondeterministic

val topLevelSumTypeNarrowSuite = suite("for top level Invoice sum type with @discriminatorName annotation")(
test("getWithNarrow succeeds in narrowing an Unpaid Invoice instance to Unpaid") {
withSingleIdKeyTable { invoiceTable =>
val keyCond: KeyConditionExpr.PartitionKeyEquals[dynamo.Invoice.Unpaid] =
dynamo.Invoice.Unpaid.id.partitionKey === "1"
for {
_ <- put[dynamo.Invoice](invoiceTable, dynamo.Invoice.Unpaid("1")).execute
item <- getItem(invoiceTable, PrimaryKey("id" -> "1")).execute

unpaid <- getWithNarrow[dynamo.Invoice, dynamo.Invoice.Unpaid](invoiceTable)(keyCond).execute.absolve
} yield {
val unpaid2: dynamo.Invoice.Unpaid = unpaid
val ensureDiscriminatorPresent = item == Some(Item("id" -> "1", "invoiceType" -> "Unpaid"))
assertTrue(unpaid2 == dynamo.Invoice.Unpaid("1") && ensureDiscriminatorPresent)
}
}
},
test("getWithNarrow succeeds in narrowing an Paid Invoice instance to Paid") {
withSingleIdKeyTable { invoiceTable =>
val keyCond: KeyConditionExpr.PartitionKeyEquals[dynamo.Invoice.Paid] =
dynamo.Invoice.Paid.id.partitionKey === "1"
for {
_ <- put[dynamo.Invoice](invoiceTable, dynamo.Invoice.Paid("1", 42)).execute
item <- getItem(invoiceTable, PrimaryKey("id" -> "1")).execute

paid <- getWithNarrow[dynamo.Invoice, dynamo.Invoice.Paid](invoiceTable)(keyCond).execute.absolve
} yield {
val paid2: dynamo.Invoice.Paid = paid
val ensureDiscriminatorPresent = item == Some(Item("id" -> "1", "invoiceType" -> "Paid", "amount" -> 42))
assertTrue(paid2 == dynamo.Invoice.Paid("1", 42) && ensureDiscriminatorPresent)
}
}
},
test("getWithNarrow fails in narrowing an Unpaid Invoice instance to Paid") {
withSingleIdKeyTable { invoiceTable =>
val keyCond: KeyConditionExpr.PartitionKeyEquals[dynamo.Invoice.Paid] =
dynamo.Invoice.Paid.id.partitionKey === "1"
for {
_ <- put[dynamo.Invoice](invoiceTable, dynamo.Invoice.Unpaid("1")).execute
exit <- getWithNarrow[dynamo.Invoice, dynamo.Invoice.Paid](invoiceTable)(keyCond).execute.absolve.exit
} yield assert(exit)(
fails(equalTo(ItemError.DecodingError("failed to narrow - found type Unpaid but expected type Paid")))
)
}
},
test("getWithNarrow fails in narrowing an Paid Invoice instance to Unpaid") {
withSingleIdKeyTable { invoiceTable =>
val keyCond: KeyConditionExpr.PartitionKeyEquals[dynamo.Invoice.Unpaid] =
dynamo.Invoice.Unpaid.id.partitionKey === "1"
for {
_ <- put[dynamo.Invoice](invoiceTable, dynamo.Invoice.Paid("1", 42)).execute
exit <- getWithNarrow[dynamo.Invoice, dynamo.Invoice.Unpaid](invoiceTable)(keyCond).execute.absolve.exit
} yield assert(exit)(
fails(equalTo(ItemError.DecodingError("failed to narrow - found type Paid but expected type Unpaid")))
)
}
}
)

val narrowSuite = suite("narrow suite")(
test("narrow Paid instance to Paid for success and failure") {
val invoice: dynamo.Invoice = dynamo.Invoice.Paid("1", 1)
val valid = DynamoDBQuery.narrow[dynamo.Invoice, dynamo.Invoice.Paid](invoice)
val invalid = DynamoDBQuery.narrow[dynamo.Invoice, dynamo.Invoice.Unpaid](invoice)

assert(valid)(isRight) && assert(invalid)(
isLeft(equalTo("failed to narrow - found type Paid but expected type Unpaid"))
)
},
test("narrow Unpaid instance to Unpaid for success and failure") {
val invoice: dynamo.Invoice = dynamo.Invoice.Unpaid("1")
val valid = DynamoDBQuery.narrow[dynamo.Invoice, dynamo.Invoice.Unpaid](invoice)
val invalid = DynamoDBQuery.narrow[dynamo.Invoice, dynamo.Invoice.Paid](invoice)

assert(valid)(isRight) && assert(invalid)(
isLeft(equalTo("failed to narrow - found type Unpaid but expected type Paid"))
)
}
)
}
51 changes: 50 additions & 1 deletion dynamodb/src/main/scala/zio/dynamodb/DynamoDBQuery.scala
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,56 @@ object DynamoDBQuery {
): DynamoDBQuery[From, Either[ItemError, From]] =
get(tableName, primaryKeyExpr.asAttrMap, ProjectionExpression.projectionsFromSchema[From])

/**
* It is common practice to save top level sum types to DynamoDB and often we want to retrieve them back as the subtype.
* `getWithNarrow` does a `get` with a safe narrow operation from type `From` to `To`.
* If the narrow fails it returns a Decoding error with details of the cast failure in the message.
* Requires implicit schemas in scope which ensure that `From` is an enum (sealed trait) and `To` is a record (case class) subtype.
*/
def getWithNarrow[From: Schema.Enum, To <: From: Schema.Record](tableName: String)(
primaryKeyExpr: KeyConditionExpr.PrimaryKeyExpr[To]
): DynamoDBQuery[From, Either[ItemError, To]] = {

def getWithNarrowedKeyCondExpr[From: Schema.Enum, To <: From](tableName: String)(
primaryKeyExpr: KeyConditionExpr.PrimaryKeyExpr[To]
): DynamoDBQuery[From, Either[ItemError, From]] =
get(tableName, primaryKeyExpr.asAttrMap, ProjectionExpression.projectionsFromSchema[From])

getWithNarrowedKeyCondExpr[From, To](tableName)(primaryKeyExpr).map {
case Right(found) =>
narrow[From, To](found).left.map(DynamoDBError.ItemError.DecodingError.apply)

case Left(error) => Left(error)
}
}

// Safely narrows `a: From` to subtype type `To` and requires that there are implicit schemas in scope which
// ensure that `From` is an enum (sealed trait) and `To` is a record (case class) subtype.
private[dynamodb] def narrow[From: Schema.Enum, To <: From: Schema.Record](
a: From
): Either[String, To] = {
val fromEnumSchema: Schema.Enum[From] = implicitly[Schema.Enum[From]]
val toSchema: Schema.Record[To] = implicitly[Schema.Record[To]]
val o: Option[Schema.Case[From, _]] = fromEnumSchema.caseOf(a)

o match {
case Some(c @ Schema.Case(_, Schema.Lazy(s), _, _, _, _)) =>
s() == toSchema match {
case true => Right(a.asInstanceOf[To])
case _ =>
Left(s"failed to narrow - found type ${c.id} but expected type ${toSchema.id.name}")
}
case Some(c) =>
c.schema == toSchema match {
case true => Right(a.asInstanceOf[To])
case _ => Left(s"failed to narrow - found type ${c.id} but expected type ${toSchema.id.name}")
}
case None =>
// this should never happen as we have a type level proof
Left(s"failed to narrow - argument is not a subtype of ${fromEnumSchema.id.name}")
}
}

private def get[A: Schema](
tableName: String,
key: PrimaryKey,
Expand Down Expand Up @@ -546,7 +596,6 @@ object DynamoDBQuery {
/**
* when executed will return a Tuple of {{{Either[String,(Chunk[A], LastEvaluatedKey)]}}}
*/

def scanSome[A: Schema](
tableName: String,
limit: Int
Expand Down
Loading