Skip to content

Commit

Permalink
SqlDelight client cr updates and fix tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
mpawliszyn committed Oct 22, 2024
1 parent 9958704 commit 0b7b5b6
Show file tree
Hide file tree
Showing 20 changed files with 393 additions and 327 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ internal class EmbeddedBackfillRun<B : Backfill>(
}

override fun scanRemaining() {
do {
while (!finishedScanning()) {
singleScan()
} while (!finishedScanning())
}
}

override fun finishedScanning() = scanProgress.all { it.value.done }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package app.cash.backfila.client.sqldelight.plugin

import app.cash.sqldelight.gradle.SqlDelightExtension
import java.io.Serializable
import org.gradle.api.NamedDomainObjectContainer
import org.gradle.api.Plugin
import org.gradle.api.Project
import org.gradle.api.file.Directory
import org.gradle.api.provider.Property
import org.gradle.api.provider.Provider
import org.jetbrains.kotlin.gradle.dsl.kotlinExtension

class BackfilaSqlDelightGradlePlugin : Plugin<Project> {
Expand All @@ -16,14 +18,8 @@ class BackfilaSqlDelightGradlePlugin : Plugin<Project> {
check(!backfill.name.matches(Regex("\\s"))) { "Backfill `name` is not allowed to contain whitespace." }
val baseSqlDirectory = target.layout.buildDirectory.dir("backfilaSqlDelight/${backfill.name}/sql")
val baseKotlinDirectory = target.layout.buildDirectory.dir("backfilaSqlDelight/${backfill.name}/kotlin")
val packageProvider = backfill.backfill.map {
val fullDatabaseName = it.database
fullDatabaseName.substring(0, fullDatabaseName.lastIndexOf('.'))
}
val databaseClassNameProvider = backfill.backfill.map {
val fullDatabaseName = it.database
fullDatabaseName.substring(fullDatabaseName.lastIndexOf('.') + 1)
}
val packageProvider = backfill.backfill.databaseProvider().packageName()
val databaseClassNameProvider = backfill.backfill.databaseProvider().className()

val sqlTask = target.tasks.register(
"generateBackfilaRecordSourceSql${backfill.name.replaceFirstChar { it.uppercase() }}",
Expand All @@ -33,7 +29,7 @@ class BackfilaSqlDelightGradlePlugin : Plugin<Project> {
it.sqlDirectory.set(baseSqlDirectory.map { baseDir -> baseDir.dir(packageProvider.get()) })
}

val sqlDelightExtension = target.extensions.findByType(app.cash.sqldelight.gradle.SqlDelightExtension::class.java)
val sqlDelightExtension = target.extensions.findByType(SqlDelightExtension::class.java)
check(sqlDelightExtension != null) {
"The Backfila gradle plugin requires the SqlDelight gradle plugin to function."
}
Expand Down Expand Up @@ -66,6 +62,18 @@ class BackfilaSqlDelightGradlePlugin : Plugin<Project> {
}
}

private fun Property<SqlDelightRecordSource>.databaseProvider(): Provider<String> = map {
it.database
}

private fun Provider<String>.packageName(): Provider<String> = map {
it.substring(0, it.lastIndexOf('.'))
}

private fun Provider<String>.className(): Provider<String> = map {
it.substring(it.lastIndexOf('.') + 1)
}

abstract class BackfilaSqlDelightExtension {
abstract val backfills: NamedDomainObjectContainer<SqlDelightRecordSourceEntry>

Expand All @@ -75,6 +83,7 @@ abstract class BackfilaSqlDelightExtension {
tableName: String,
keyName: String,
keyType: String,
keyEncoder: String,
recordColumns: String,
recordType: String,
whereClause: String = "1 = 1",
Expand All @@ -87,6 +96,7 @@ abstract class BackfilaSqlDelightExtension {
tableName = tableName,
keyName = keyName,
keyType = keyType,
keyEncoder = keyEncoder,
recordColumns = recordColumns,
recordType = recordType,
whereClause = whereClause,
Expand All @@ -106,6 +116,7 @@ data class SqlDelightRecordSource(
val tableName: String,
val keyName: String, // TODO: Eventually also support compound keys.
val keyType: String, // TODO: Get this information directly from SQLDelight
val keyEncoder: String, // TODO: Automatically set this when it can.
val recordColumns: String,
val recordType: String, // TODO: Get this information directly from SQLDelight
val whereClause: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@ package app.cash.backfila.client.sqldelight.plugin
import com.squareup.kotlinpoet.ClassName
import com.squareup.kotlinpoet.FileSpec
import com.squareup.kotlinpoet.FunSpec
import com.squareup.kotlinpoet.KModifier.OVERRIDE
import com.squareup.kotlinpoet.KModifier.PRIVATE
import com.squareup.kotlinpoet.LONG
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
import com.squareup.kotlinpoet.PropertySpec
import com.squareup.kotlinpoet.TypeSpec
import org.gradle.api.DefaultTask
import org.gradle.api.file.DirectoryProperty
import org.gradle.api.provider.Property
Expand All @@ -26,71 +31,138 @@ abstract class GenerateBackfilaRecordSourceQueriesTask : DefaultTask() {
val backfillConfig = backfill.get()
val name = backfillConfig.name.replaceFirstChar { it.uppercase() }
val packageName = packageName.get()
val fileName = "${name}RecordSourceQueries"
val className = "${name}RecordSourceConfig"
val targetDirectory = kotlinDirectory.asFile.get()
targetDirectory.mkdirs()

val databaseClass = ClassName.bestGuess(backfillConfig.database)
// Find and specify types
val databaseType = ClassName.bestGuess(backfillConfig.database)
val queriesFunctionName = "${backfillConfig.name}Queries".replaceFirstChar { it.lowercase() }

val keyType = backfillConfig.keyType
val recordType = backfillConfig.recordType
val recordSourceQueriesType = ClassName("app.cash.backfila.client.sqldelight", "SqlDelightRecordSourceQueries")
val returnType = recordSourceQueriesType.parameterizedBy(ClassName.bestGuess(keyType), ClassName.bestGuess(recordType))
val keyType = ClassName.bestGuess(backfillConfig.keyType)
val keyEncoderType = ClassName("app.cash.backfila.client.sqldelight", "KeyEncoder")
.parameterizedBy(keyType)
val myKeyEncoderType = ClassName.bestGuess(backfillConfig.keyEncoder)

val poetFile = FileSpec.builder(packageName, fileName)
.addFunction(
FunSpec.builder("get${name}Queries")
.addParameter("database", databaseClass)
.returns(returnType)
.addStatement(
"""
| return %T.create(
| database.$queriesFunctionName.selectAbsoluteRange { min, max -> %T.MinMax(min, max) },
| { rangeStart: $keyType, rangeEnd: $keyType, scanSize: Long ->
| database.$queriesFunctionName.selectInitialMaxBound(rangeStart, rangeEnd, scanSize) {
| %T.NullKeyContainer(
| it,
| )
| }
| },
| { previousEndKey: $keyType, rangeEnd: $keyType, scanSize: Long ->
| database.$queriesFunctionName.selectNextMaxBound(
| previousEndKey,
| rangeEnd,
| scanSize,
| ) { %T.NullKeyContainer(it) }
| },
| { rangeStart: $keyType, rangeEnd: $keyType, offset: Long -> database.$queriesFunctionName.produceInitialBatchFromRange(rangeStart, rangeEnd, offset) },
| { previousEndKey: $keyType, rangeEnd: $keyType, offset: Long -> database.$queriesFunctionName.produceNextBatchFromRange(previousEndKey, rangeEnd, offset) },
| { rangeStart: $keyType, boundingMax: $keyType -> database.$queriesFunctionName.countInitialBatchMatches(rangeStart, boundingMax) },
| { previousEndKey: $keyType, boundingMax: $keyType -> database.$queriesFunctionName.countNextBatchMatches(previousEndKey, boundingMax) },
| { rangeStart: $keyType, rangeEnd: $keyType ->
| database.$queriesFunctionName.getInitialStartKeyAndScanCount(rangeStart, rangeEnd) { min, count ->
| %T.MinAndCount(
| min,
| count,
| )
| }
| },
| { previousEndKey: $keyType, rangeEnd: $keyType ->
| database.$queriesFunctionName.getNextStartKeyAndScanCount(
| previousEndKey,
| rangeEnd,
| ) { min, count -> %T.MinAndCount(min, count) }
| },
| { start: $keyType, end: $keyType -> database.$queriesFunctionName.getBatch(start, end) },
| )
|
""".trimMargin(),
recordSourceQueriesType,
recordSourceQueriesType,
recordSourceQueriesType,
recordSourceQueriesType,
recordSourceQueriesType,
recordSourceQueriesType,
)
.build(),
val parameterizedRecordType = ClassName("app.cash.backfila.client.sqldelight", "SqlDelightRecordSourceConfig")
.parameterizedBy(keyType, ClassName.bestGuess(backfillConfig.recordType))

val minMaxType = ClassName("app.cash.backfila.client.sqldelight", "MinMax")
.parameterizedBy(keyType)
val nullKeyContainerType = ClassName("app.cash.backfila.client.sqldelight", "NullKeyContainer")
.parameterizedBy(keyType)
val minAndCountType = ClassName("app.cash.backfila.client.sqldelight", "MinAndCount")
.parameterizedBy(keyType)

// Return query types
val queryType = ClassName("app.cash.sqldelight", "Query")
val minMaxQueryType = queryType.parameterizedBy(minMaxType)
val nullKeyContainerQueryType = queryType.parameterizedBy(nullKeyContainerType)
val minAndCountQueryType = queryType.parameterizedBy(minAndCountType)
val keyQueryType = queryType.parameterizedBy(keyType)
val longQueryType = queryType.parameterizedBy(LONG)
val recordQueryType = queryType.parameterizedBy(ClassName.bestGuess(backfillConfig.recordType))

// Generate the file.
val poetFile = FileSpec.builder(packageName, className)
.addType(
TypeSpec.classBuilder(ClassName(packageName, className))
.addSuperinterface(parameterizedRecordType)
.primaryConstructor(
FunSpec.constructorBuilder()
.addParameter("database", databaseType)
.build(),
).addProperty(
PropertySpec.builder("database", databaseType, PRIVATE)
.initializer("database")
.build(),
).addProperty(
PropertySpec.builder("keyEncoder", keyEncoderType, OVERRIDE)
.initializer("%T", myKeyEncoderType)
.build(),
).addFunction(
FunSpec.builder("selectAbsoluteRange")
.returns(minMaxQueryType)
.addStatement("return database.%L.selectAbsoluteRange { min, max -> %T(min, max) }", queriesFunctionName, minMaxType)
.addModifiers(OVERRIDE)
.build(),
).addFunction(
FunSpec.builder("selectInitialMaxBound")
.addParameter("rangeStart", keyType)
.addParameter("rangeEnd", keyType)
.addParameter("scanSize", LONG)
.returns(nullKeyContainerQueryType)
.addStatement("return database.%L.selectInitialMaxBound(rangeStart, rangeEnd, scanSize) { %T(it) }", queriesFunctionName, nullKeyContainerType)
.addModifiers(OVERRIDE)
.build(),
).addFunction(
FunSpec.builder("selectNextMaxBound")
.addParameter("previousEndKey", keyType)
.addParameter("rangeEnd", keyType)
.addParameter("scanSize", LONG)
.returns(nullKeyContainerQueryType)
.addStatement("return database.%L.selectNextMaxBound(previousEndKey, rangeEnd, scanSize) { %T(it) }", queriesFunctionName, nullKeyContainerType)
.addModifiers(OVERRIDE)
.build(),
).addFunction(
FunSpec.builder("produceInitialBatchFromRange")
.addParameter("rangeStart", keyType)
.addParameter("rangeEnd", keyType)
.addParameter("offset", LONG)
.returns(keyQueryType)
.addStatement("return database.%L.produceInitialBatchFromRange(rangeStart, rangeEnd, offset)", queriesFunctionName)
.addModifiers(OVERRIDE)
.build(),
).addFunction(
FunSpec.builder("produceNextBatchFromRange")
.addParameter("previousEndKey", keyType)
.addParameter("rangeEnd", keyType)
.addParameter("offset", LONG)
.returns(keyQueryType)
.addStatement("return database.%L.produceNextBatchFromRange(previousEndKey, rangeEnd, offset)", queriesFunctionName)
.addModifiers(OVERRIDE)
.build(),
).addFunction(
FunSpec.builder("countInitialBatchMatches")
.addParameter("rangeStart", keyType)
.addParameter("boundingMax", keyType)
.returns(longQueryType)
.addStatement("return database.%L.countInitialBatchMatches(rangeStart, boundingMax)", queriesFunctionName)
.addModifiers(OVERRIDE)
.build(),
).addFunction(
FunSpec.builder("countNextBatchMatches")
.addParameter("previousEndKey", keyType)
.addParameter("boundingMax", keyType)
.returns(longQueryType)
.addStatement("return database.%L.countNextBatchMatches(previousEndKey, boundingMax)", queriesFunctionName)
.addModifiers(OVERRIDE)
.build(),
).addFunction(
FunSpec.builder("getInitialStartKeyAndScanCount")
.addParameter("rangeStart", keyType)
.addParameter("rangeEnd", keyType)
.returns(minAndCountQueryType)
.addStatement("return database.%L.getInitialStartKeyAndScanCount(rangeStart, rangeEnd) { min, count -> %T(min, count) }", queriesFunctionName, minAndCountType)
.addModifiers(OVERRIDE)
.build(),
).addFunction(
FunSpec.builder("getNextStartKeyAndScanCount")
.addParameter("previousEndKey", keyType)
.addParameter("rangeEnd", keyType)
.returns(minAndCountQueryType)
.addStatement("return database.%L.getNextStartKeyAndScanCount(previousEndKey, rangeEnd) { min, count -> %T(min, count) }", queriesFunctionName, minAndCountType)
.addModifiers(OVERRIDE)
.build(),
).addFunction(
FunSpec.builder("getBatch")
.addParameter("start", keyType)
.addParameter("end", keyType)
.returns(recordQueryType)
.addStatement("return database.%L.getBatch(start, end)", queriesFunctionName)
.addModifiers(OVERRIDE)
.build(),
).build(),
).build()

poetFile.writeTo(targetDirectory)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ backfilaSqlDelight {
tableName = "hockeyPlayer",
keyName = "player_number",
keyType = "kotlin.Int",
keyEncoder = "app.cash.backfila.client.sqldelight.IntKeyEncoder",
recordColumns = "*",
recordType = "app.cash.backfila.client.sqldelight.hockeydata.HockeyPlayer"
)
Expand Down

This file was deleted.

Loading

0 comments on commit 0b7b5b6

Please sign in to comment.