From 185da863c94051565039a2e8b94f00e295be8797 Mon Sep 17 00:00:00 2001 From: Avinder Bahra Date: Wed, 26 Jul 2023 08:01:34 +0100 Subject: [PATCH] fix KeyConditionExpression rendering for names --- .../src/it/scala/zio/dynamodb/LiveSpec.scala | 78 ++++++++++++++++ .../zio/dynamodb/KeyConditionExpression.scala | 86 +++++++++--------- .../zio/dynamodb/AliasMapRenderSpec.scala | 91 ++++++++++++------- 3 files changed, 176 insertions(+), 79 deletions(-) diff --git a/dynamodb/src/it/scala/zio/dynamodb/LiveSpec.scala b/dynamodb/src/it/scala/zio/dynamodb/LiveSpec.scala index cc45b899f..d0eb54432 100644 --- a/dynamodb/src/it/scala/zio/dynamodb/LiveSpec.scala +++ b/dynamodb/src/it/scala/zio/dynamodb/LiveSpec.scala @@ -8,6 +8,8 @@ import software.amazon.awssdk.services.dynamodb.model.{ DynamoDbException, Idemp import zio.dynamodb.UpdateExpression.Action.SetAction import zio.dynamodb.UpdateExpression.SetOperand import zio.dynamodb.PartitionKeyExpression.PartitionKey +import zio.dynamodb.KeyConditionExpression.partitionKey +import zio.dynamodb.KeyConditionExpression.sortKey import zio.dynamodb.SortKeyExpression.SortKey import zio.aws.{ dynamodb, netty } import zio._ @@ -116,6 +118,12 @@ object LiveSpec extends ZIOSpecDefault { AttributeDefinition.attrDefnString(name) ) + def sortKeyStringTableWithKeywords(tableName: String) = + createTable(tableName, KeySchema("and", "source"), BillingMode.PayPerRequest)( + AttributeDefinition.attrDefnString("and"), + AttributeDefinition.attrDefnString("source") + ) + private def managedTable(tableDefinition: String => CreateTable) = ZIO .acquireRelease( @@ -150,6 +158,17 @@ object LiveSpec extends ZIOSpecDefault { } } + def withKeywordsTable( + f: String => ZIO[DynamoDBExecutor, Throwable, TestResult] + ) = + ZIO.scoped { + managedTable(sortKeyStringTableWithKeywords).flatMap { table => + for { + result <- f(table.value) + } yield result + } + } + def withDefaultAndNumberTables( f: (String, String) => ZIO[DynamoDBExecutor, Throwable, TestResult] ) = @@ -179,8 +198,67 @@ object LiveSpec extends ZIOSpecDefault { val (id, num, ttl) = ProjectionExpression.accessors[ExpressionAttrNames] } + final case class ExpressionAttrNames2(and: String, source: String, ttl: Option[Long]) + object ExpressionAttrNames2 { + implicit val schema: Schema.CaseClass3[String, String, Option[Long], ExpressionAttrNames2] = + DeriveSchema.gen[ExpressionAttrNames2] + val (and, source, ttl) = ProjectionExpression.accessors[ExpressionAttrNames2] + } + + val debugSuite = suite("debug")( + test("queryAll should handle keywords in primary key names using high level API") { + withKeywordsTable { tableName => + val query = DynamoDBQuery + .queryAll[ExpressionAttrNames2](tableName) + .whereKey(ExpressionAttrNames2.and === "and1" && ExpressionAttrNames2.source === "source1") + .filter(ExpressionAttrNames2.ttl.notExists) + query.execute.flatMap(_.runDrain).exit.map { result => + assert(result)(succeeds(isUnit)) + } + } + }, + test("queryAll should handle keywords in primary key names using low level API") { + withKeywordsTable { tableName => + val query = DynamoDBQuery + .queryAll[ExpressionAttrNames2](tableName) + .whereKey(partitionKey("and") === "and1" && sortKey("source") === "source1") + .filter(ExpressionAttrNames2.ttl.notExists) + query.execute.flatMap(_.runDrain).exit.map { result => + assert(result)(succeeds(isUnit)) + } + } + } + ) + .provideSomeLayerShared[TestEnvironment]( + testLayer.orDie + ) @@ nondeterministic + val mainSuite: Spec[TestEnvironment, Any] = suite("live test")( + suite("key words in Key Condition Expressions")( + test("queryAll should handle keywords in primary key name using high level API") { + withKeywordsTable { tableName => + val query = DynamoDBQuery + .queryAll[ExpressionAttrNames2](tableName) + .whereKey(ExpressionAttrNames2.and === "and1" && ExpressionAttrNames2.source === "source1") + .filter(ExpressionAttrNames2.ttl.notExists) + query.execute.flatMap(_.runDrain).exit.map { result => + assert(result)(succeeds(isUnit)) + } + } + }, + test("queryAll should handle keywords in primary key name using low level API") { + withKeywordsTable { tableName => + val query = DynamoDBQuery + .queryAll[ExpressionAttrNames2](tableName) + .whereKey(partitionKey("and") === "and1" && sortKey("source") === "source1") + .filter(ExpressionAttrNames2.ttl.notExists) + query.execute.flatMap(_.runDrain).exit.map { result => + assert(result)(succeeds(isUnit)) + } + } + } + ), suite("keywords in expression attribute names")( suite("using high level api")( test("scanAll should handle keyword") { diff --git a/dynamodb/src/main/scala/zio/dynamodb/KeyConditionExpression.scala b/dynamodb/src/main/scala/zio/dynamodb/KeyConditionExpression.scala index ac91437a3..ff40d4862 100644 --- a/dynamodb/src/main/scala/zio/dynamodb/KeyConditionExpression.scala +++ b/dynamodb/src/main/scala/zio/dynamodb/KeyConditionExpression.scala @@ -29,9 +29,13 @@ sealed trait KeyConditionExpression extends Renderable { self => } object KeyConditionExpression { + + def getOrInsert[From](primaryKeyName: String): AliasMapRender[String] = + AliasMapRender.getOrInsert(ProjectionExpression.MapElement[From, String](Root, primaryKeyName)) private[dynamodb] final case class And(left: PartitionKeyExpression, right: SortKeyExpression) extends KeyConditionExpression - def partitionKey(key: String): PartitionKey = PartitionKey(key) + def partitionKey(key: String): PartitionKey = PartitionKey(key) + def sortKey(key: String): SortKey = SortKey(key) /** * Create a KeyConditionExpression from a ConditionExpression @@ -156,7 +160,10 @@ sealed trait PartitionKeyExpression extends KeyConditionExpression { self => override def render: AliasMapRender[String] = self match { case PartitionKeyExpression.Equals(left, right) => - AliasMapRender.getOrInsert(right).map(v => s"${left.keyName} = $v") + for { + v <- AliasMapRender.getOrInsert(right) + keyName <- KeyConditionExpression.getOrInsert(left.keyName) + } yield s"${keyName} = $v" } } object PartitionKeyExpression { @@ -171,55 +178,46 @@ sealed trait SortKeyExpression { self => def render: AliasMapRender[String] = self match { case SortKeyExpression.Equals(left, right) => - AliasMapRender - .getOrInsert(right) - .map { v => - s"${left.keyName} = $v" - } + for { + v <- AliasMapRender.getOrInsert(right) + keyName <- KeyConditionExpression.getOrInsert(left.keyName) + } yield s"${keyName} = $v" case SortKeyExpression.LessThan(left, right) => - AliasMapRender - .getOrInsert(right) - .map { v => - s"${left.keyName} < $v" - } + for { + v <- AliasMapRender.getOrInsert(right) + keyName <- KeyConditionExpression.getOrInsert(left.keyName) + } yield s"${keyName} < $v" case SortKeyExpression.NotEqual(left, right) => - AliasMapRender - .getOrInsert(right) - .map { v => - s"${left.keyName} <> $v" - } + for { + v <- AliasMapRender.getOrInsert(right) + keyName <- KeyConditionExpression.getOrInsert(left.keyName) + } yield s"${keyName} <> $v" case SortKeyExpression.GreaterThan(left, right) => - AliasMapRender - .getOrInsert(right) - .map { v => - s"${left.keyName} > $v" - } + for { + v <- AliasMapRender.getOrInsert(right) + keyName <- KeyConditionExpression.getOrInsert(left.keyName) + } yield s"${keyName} > $v" case SortKeyExpression.LessThanOrEqual(left, right) => - AliasMapRender - .getOrInsert(right) - .map { v => - s"${left.keyName} <= $v" - } + for { + v <- AliasMapRender.getOrInsert(right) + keyName <- KeyConditionExpression.getOrInsert(left.keyName) + } yield s"${keyName} <= $v" case SortKeyExpression.GreaterThanOrEqual(left, right) => - AliasMapRender - .getOrInsert(right) - .map { v => - s"${left.keyName} >= $v" - } + for { + v <- AliasMapRender.getOrInsert(right) + keyName <- KeyConditionExpression.getOrInsert(left.keyName) + } yield s"${keyName} >= $v" case SortKeyExpression.Between(left, min, max) => - AliasMapRender - .getOrInsert(min) - .flatMap(min => - AliasMapRender.getOrInsert(max).map { max => - s"${left.keyName} BETWEEN $min AND $max" - } - ) + for { + min2 <- AliasMapRender.getOrInsert(min) + max2 <- AliasMapRender.getOrInsert(max) + keyName <- KeyConditionExpression.getOrInsert(left.keyName) + } yield s"${keyName} BETWEEN $min2 AND $max2" case SortKeyExpression.BeginsWith(left, value) => - AliasMapRender - .getOrInsert(value) - .map { v => - s"begins_with(${left.keyName}, $v)" - } + for { + v <- AliasMapRender.getOrInsert(value) + keyName <- KeyConditionExpression.getOrInsert(left.keyName) + } yield s"begins_with(${keyName}, $v)" } } diff --git a/dynamodb/src/test/scala/zio/dynamodb/AliasMapRenderSpec.scala b/dynamodb/src/test/scala/zio/dynamodb/AliasMapRenderSpec.scala index e2f1776ce..0f32a793d 100644 --- a/dynamodb/src/test/scala/zio/dynamodb/AliasMapRenderSpec.scala +++ b/dynamodb/src/test/scala/zio/dynamodb/AliasMapRenderSpec.scala @@ -322,127 +322,148 @@ object AliasMapRenderSpec extends ZIOSpecDefault { suite("SortKeyExpression")( test("Equals") { val map = Map( - avKey(one) -> ":v0" + avKey(one) -> ":v0", + pathSegment(Root, "num") -> "#n1", + fullPath($("num")) -> "#n1" ) val (aliasMap, expression) = SortKeyExpression.Equals(SortKeyExpression.SortKey("num"), one).render.execute - assert(aliasMap)(equalTo(AliasMap(map, 1))) && - assert(expression)(equalTo("num = :v0")) + assert(aliasMap)(equalTo(AliasMap(map, 2))) && + assert(expression)(equalTo("#n1 = :v0")) }, test("LessThan") { val map = Map( - avKey(one) -> ":v0" + avKey(one) -> ":v0", + pathSegment(Root, "num") -> "#n1", + fullPath($("num")) -> "#n1" ) val (aliasMap, expression) = SortKeyExpression.LessThan(SortKeyExpression.SortKey("num"), one).render.execute - assert(aliasMap)(equalTo(AliasMap(map, 1))) && - assert(expression)(equalTo("num < :v0")) + assert(aliasMap)(equalTo(AliasMap(map, 2))) && + assert(expression)(equalTo("#n1 < :v0")) }, test("NotEqual") { val map = Map( - avKey(one) -> ":v0" + avKey(one) -> ":v0", + pathSegment(Root, "num") -> "#n1", + fullPath($("num")) -> "#n1" ) val (aliasMap, expression) = SortKeyExpression.NotEqual(SortKeyExpression.SortKey("num"), one).render.execute - assert(aliasMap)(equalTo(AliasMap(map, 1))) && - assert(expression)(equalTo("num <> :v0")) + assert(aliasMap)(equalTo(AliasMap(map, 2))) && + assert(expression)(equalTo("#n1 <> :v0")) }, test("GreaterThan") { val map = Map( - avKey(one) -> ":v0" + avKey(one) -> ":v0", + pathSegment(Root, "num") -> "#n1", + fullPath($("num")) -> "#n1" ) val (aliasMap, expression) = SortKeyExpression.GreaterThan(SortKeyExpression.SortKey("num"), one).render.execute - assert(aliasMap)(equalTo(AliasMap(map, 1))) && - assert(expression)(equalTo("num > :v0")) + assert(aliasMap)(equalTo(AliasMap(map, 2))) && + assert(expression)(equalTo("#n1 > :v0")) }, test("LessThanOrEqual") { val map = Map( - avKey(one) -> ":v0" + avKey(one) -> ":v0", + pathSegment(Root, "num") -> "#n1", + fullPath($("num")) -> "#n1" ) val (aliasMap, expression) = SortKeyExpression.LessThanOrEqual(SortKeyExpression.SortKey("num"), one).render.execute - assert(aliasMap)(equalTo(AliasMap(map, 1))) && - assert(expression)(equalTo("num <= :v0")) + assert(aliasMap)(equalTo(AliasMap(map, 2))) && + assert(expression)(equalTo("#n1 <= :v0")) }, test("GreaterThanOrEqual") { val map = Map( - avKey(one) -> ":v0" + avKey(one) -> ":v0", + pathSegment(Root, "num") -> "#n1", + fullPath($("num")) -> "#n1" ) val (aliasMap, expression) = SortKeyExpression.GreaterThanOrEqual(SortKeyExpression.SortKey("num"), one).render.execute - assert(aliasMap)(equalTo(AliasMap(map, 1))) && - assert(expression)(equalTo("num >= :v0")) + assert(aliasMap)(equalTo(AliasMap(map, 2))) && + assert(expression)(equalTo("#n1 >= :v0")) }, test("Between") { val map = Map( - avKey(one) -> ":v0", - avKey(two) -> ":v1" + avKey(one) -> ":v0", + avKey(two) -> ":v1", + pathSegment(Root, "num") -> "#n2", + fullPath($("num")) -> "#n2" ) val (aliasMap, expression) = SortKeyExpression.Between(SortKeyExpression.SortKey("num"), one, two).render.execute - assert(aliasMap)(equalTo(AliasMap(map, 2))) && - assert(expression)(equalTo("num BETWEEN :v0 AND :v1")) + assert(aliasMap)(equalTo(AliasMap(map, 3))) && + assert(expression)(equalTo("#n2 BETWEEN :v0 AND :v1")) }, test("BeginsWith") { val map = Map( - avKey(name) -> ":v0" + avKey(name) -> ":v0", + pathSegment(Root, "num") -> "#n1", + fullPath($("num")) -> "#n1" ) val (aliasMap, expression) = SortKeyExpression.BeginsWith(SortKeyExpression.SortKey("num"), name).render.execute - assert(aliasMap)(equalTo(AliasMap(map, 1))) && - assert(expression)(equalTo("begins_with(num, :v0)")) + assert(aliasMap)(equalTo(AliasMap(map, 2))) && + assert(expression)(equalTo("begins_with(#n1, :v0)")) } ), suite("PartitionKeyExpression")( test("Equals") { val map = Map( - avKey(one) -> ":v0" + avKey(one) -> ":v0", + pathSegment(Root, "num") -> "#n1", + fullPath($("num")) -> "#n1" ) val (aliasMap, expression) = PartitionKeyExpression .Equals(PartitionKeyExpression.PartitionKey("num"), one) .render .execute - - assert(aliasMap)(equalTo(AliasMap(map, 1))) && - assert(expression)(equalTo("num = :v0")) + assert(aliasMap)(equalTo(AliasMap(map, 2))) && + assert(expression)(equalTo("#n1 = :v0")) } ), test("And") { val map = Map( - avKey(two) -> ":v0", - avKey(one) -> ":v1", - avKey(three) -> ":v2" + avKey(two) -> ":v0", + avKey(one) -> ":v2", + avKey(three) -> ":v3", + pathSegment(Root, "num") -> "#n1", + fullPath($("num")) -> "#n1", + pathSegment(Root, "num2") -> "#n4", + fullPath($("num2")) -> "#n4" ) val (aliasMap, expression) = KeyConditionExpression .And( PartitionKeyExpression .Equals(PartitionKeyExpression.PartitionKey("num"), two), - SortKeyExpression.Between(SortKeyExpression.SortKey("num"), one, three) + SortKeyExpression.Between(SortKeyExpression.SortKey("num2"), one, three) ) .render .execute - assert(aliasMap)(equalTo(AliasMap(map, 3))) && - assert(expression)(equalTo("num = :v0 AND num BETWEEN :v1 AND :v2")) + assert(aliasMap)(equalTo(AliasMap(map, 5))) && + assert(expression)(equalTo("#n1 = :v0 AND #n4 BETWEEN :v2 AND :v3")) } ), suite("AttributeValueType")(