Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue when VersionedResult can lose first row of the data #79

Merged
merged 3 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package services.mssql
import models.{ArcaneSchema, ArcaneType, Field}
import services.base.{CanAdd, SchemaProvider}
import services.mssql.MsSqlConnection.{DATE_PARTITION_KEY, UPSERT_MERGE_KEY, VersionedBatch, toArcaneType}
import services.mssql.base.QueryResult
import services.mssql.base.{CanPeekHead, QueryResult}
import services.mssql.query.{LazyQueryResult, QueryRunner, ScalarQueryResult}

import com.microsoft.sqlserver.jdbc.SQLServerDriver
Expand Down Expand Up @@ -90,7 +90,7 @@ class MsSqlConnection(val connectionOptions: ConnectionOptions) extends AutoClos
* @param arcaneSchema The schema for the data produced by Arcane.
* @return A future containing the result of the backfill.
*/
def backfill(arcaneSchema: ArcaneSchema)(using queryRunner: QueryRunner): Future[QueryResult[LazyQueryResult.OutputType]] =
def backfill(arcaneSchema: ArcaneSchema)(using queryRunner: QueryRunner[LazyQueryResult.OutputType, LazyQueryResult]): Future[QueryResult[LazyQueryResult.OutputType]] =
for query <- QueryProvider.getBackfillQuery(this)
result <- queryRunner.executeQuery(query, connection, LazyQueryResult.apply)
yield result
Expand All @@ -101,14 +101,16 @@ class MsSqlConnection(val connectionOptions: ConnectionOptions) extends AutoClos
* @param lookBackInterval The look back interval for the query.
* @return A future containing the changes in the database since the given version and the latest observed version.
*/
def getChanges(maybeLatestVersion: Option[Long], lookBackInterval: Duration)(using queryRunner: QueryRunner): Future[VersionedBatch] =
def getChanges(maybeLatestVersion: Option[Long], lookBackInterval: Duration)
(using queryRunner: QueryRunner[LazyQueryResult.OutputType, LazyQueryResult],
versionQueryRunner: QueryRunner[Option[Long], ScalarQueryResult[Long]]): Future[VersionedBatch] =
val query = QueryProvider.getChangeTrackingVersionQuery(connectionOptions.databaseName, maybeLatestVersion, lookBackInterval)

for versionResult <- queryRunner.executeQuery(query, connection, (st, rs) => ScalarQueryResult.apply(st, rs, readChangeTrackingVersion))
for versionResult <- versionQueryRunner.executeQuery(query, connection, (st, rs) => ScalarQueryResult.apply(st, rs, readChangeTrackingVersion))
version = versionResult.read.getOrElse(Long.MaxValue)
changesQuery <- QueryProvider.getChangesQuery(this, version - 1)
result <- queryRunner.executeQuery(changesQuery, connection, LazyQueryResult.apply)
yield (result, version)
yield MsSqlConnection.ensureHead((result, maybeLatestVersion.getOrElse(0)))

private def readChangeTrackingVersion(resultSet: ResultSet): Option[Long] =
resultSet.getMetaData.getColumnType(1) match
Expand Down Expand Up @@ -216,8 +218,15 @@ object MsSqlConnection:
/**
* Represents a versioned batch of data.
*/
type VersionedBatch = (QueryResult[LazyQueryResult.OutputType], Long)
type VersionedBatch = (QueryResult[LazyQueryResult.OutputType] & CanPeekHead[LazyQueryResult.OutputType], Long)

/**
* Ensures that the head of the result (if any) saved and cannot be lost
* This is required to let the head function work properly.
*/
private def ensureHead(result: VersionedBatch): VersionedBatch =
val (queryResult, version) = result
(queryResult.peekHead, version)

object QueryProvider:
private implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package services.mssql.base
/**
* Represents the result of a query to a SQL database.
*/
trait QueryResult[Output] {
trait QueryResult[Output]:

/**
* The output type of the query result.
Expand All @@ -18,4 +18,15 @@ trait QueryResult[Output] {
*/
def read: OutputType

}
/**
* Represents a query result that can peek the head of the result.
*
* @tparam Output The type of the output of the query.
*/
trait CanPeekHead[Output]:
/**
* Peeks the head of the result of the SQL query mapped to an output type.
*
* @return The head of the result of the query.
*/
def peekHead: QueryResult[Output] & CanPeekHead[Output]
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package services.mssql.query

import models.{DataCell, DataRow}
import services.mssql.MsSqlConnection.toArcaneType
import services.mssql.base.{QueryResult, ResultSetOwner}
import services.mssql.base.{CanPeekHead, QueryResult, ResultSetOwner}

import java.sql.{ResultSet, Statement}
import scala.annotation.tailrec
Expand All @@ -16,7 +16,8 @@ import scala.util.{Failure, Success, Try}
* @param statement The statement used to execute the query.
* @param resultSet The result set of the query.
*/
class LazyQueryResult(protected val statement: Statement, resultSet: ResultSet) extends QueryResult[LazyList[DataRow]] with ResultSetOwner:
class LazyQueryResult(protected val statement: Statement, resultSet: ResultSet, eagerHead: List[DataRow]) extends QueryResult[LazyList[DataRow]]
with ResultSetOwner with CanPeekHead[LazyList[DataRow]]:

/**
* Reads the result of the query.
Expand All @@ -25,7 +26,7 @@ class LazyQueryResult(protected val statement: Statement, resultSet: ResultSet)
*/
override def read: this.OutputType =
val columns = resultSet.getMetaData.getColumnCount
LazyList.continually(resultSet)
eagerHead.to(LazyList) #::: LazyList.continually(resultSet)
.takeWhile(_.next())
.map(row => {
toDataRow(row, columns, List.empty) match {
Expand All @@ -34,6 +35,14 @@ class LazyQueryResult(protected val statement: Statement, resultSet: ResultSet)
}
})

/**
* Peeks the head of the result of the SQL query mapped to an output type.
*
* @return The head of the result of the query.
*/
def peekHead: QueryResult[this.OutputType] & CanPeekHead[this.OutputType] =
new LazyQueryResult(statement, resultSet, read.headOption.toList)

@tailrec
private def toDataRow(row: ResultSet, columns: Int, acc: DataRow): Try[DataRow] =
if columns == 0 then Success(acc)
Expand Down Expand Up @@ -62,5 +71,5 @@ object LazyQueryResult {
* @param resultSet The result set of the query.
* @return The new [[LazyQueryResult]] object.
*/
def apply(statement: Statement, resultSet: ResultSet): LazyQueryResult = new LazyQueryResult(statement, resultSet)
def apply(statement: Statement, resultSet: ResultSet): LazyQueryResult = new LazyQueryResult(statement, resultSet, List.empty)
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ import scala.concurrent.{Future, blocking}
* cursor.
*
*/
class QueryRunner:
class QueryRunner[Output, QueryResultType <: QueryResult[Output]]:

/**
* A factory for creating a QueryResult object from a statement and a result set.
*
* @tparam Output The type of the output of the query.
*/
private type ResultFactory[Output] = (Statement, ResultSet) => QueryResult[Output]
private type ResultFactory = (Statement, ResultSet) => QueryResultType

private implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global

Expand All @@ -33,7 +33,7 @@ class QueryRunner:
* @param connection The connection to execute the query on.
* @return The result of the query.
*/
def executeQuery[Result](query: MsSqlQuery, connection: Connection, resultFactory: ResultFactory[Result]): Future[QueryResult[Result]] =
def executeQuery(query: MsSqlQuery, connection: Connection, resultFactory: ResultFactory): Future[QueryResultType] =
Future {
val statement = connection.createStatement()
val resultSet = blocking {
Expand All @@ -44,4 +44,5 @@ class QueryRunner:


object QueryRunner:
def apply(): QueryRunner = new QueryRunner()
def apply[O, T <: QueryResult[O]](): QueryRunner[O, T] = new QueryRunner[O, T]()

Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class ScalarQueryResult[Result](val statement: Statement, resultSet: ResultSet,
None
case _ => None


/**
* Companion object for [[LazyQueryResult]].
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import models.DataCell
import services.mssql.MsSqlConnection
import services.mssql.MsSqlConnection.VersionedBatch
import services.mssql.query.LazyQueryResult.OutputType
import services.mssql.query.QueryRunner
import services.mssql.query.{LazyQueryResult, QueryRunner, ScalarQueryResult}
import services.streaming.base.{HasVersion, StreamLifetimeService}

import zio.stream.ZStream
Expand Down Expand Up @@ -36,7 +36,8 @@ given HasVersion[VersionedBatch] with


class StreamGraphBuilder(msSqlConnection: MsSqlConnection) {
implicit val queryRunner: QueryRunner = QueryRunner()
implicit val dataQueryRunner: QueryRunner[LazyQueryResult.OutputType, LazyQueryResult] = QueryRunner()
implicit val versionQueryRunner: QueryRunner[Option[Long], ScalarQueryResult[Long]] = QueryRunner()

/**
* Builds a stream that reads the changes from the database.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package services.connectors.mssql

import models.ArcaneType.{IntType, LongType, StringType}
import models.Field
import services.mssql.query.QueryRunner
import services.mssql.query.{LazyQueryResult, QueryRunner, ScalarQueryResult}
import services.mssql.{ConnectionOptions, MsSqlConnection, QueryProvider}

import com.microsoft.sqlserver.jdbc.SQLServerDriver
Expand All @@ -21,36 +21,40 @@ import scala.language.postfixOps
case class TestConnectionInfo(connectionOptions: ConnectionOptions, connection: Connection)

class MsSqlConnectorsTests extends flatspec.AsyncFlatSpec with Matchers:
implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global
private implicit val queryRunner: QueryRunner = QueryRunner()
private implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global
private implicit val dataQueryRunner: QueryRunner[LazyQueryResult.OutputType, LazyQueryResult] = QueryRunner()
private implicit val versionQueryRunner: QueryRunner[Option[Long], ScalarQueryResult[Long]] = QueryRunner()

val connectionUrl = "jdbc:sqlserver://localhost;encrypt=true;trustServerCertificate=true;username=sa;password=tMIxN11yGZgMC"

def createDb(): TestConnectionInfo =
def createDb(tableName: String): TestConnectionInfo =
val dr = new SQLServerDriver()
val con = dr.connect(connectionUrl, new Properties())
val query = "IF NOT EXISTS (SELECT * FROM sys.databases WHERE name = 'arcane') BEGIN CREATE DATABASE arcane; alter database Arcane set CHANGE_TRACKING = ON (CHANGE_RETENTION = 2 DAYS, AUTO_CLEANUP = ON); END;"
val statement = con.createStatement()
statement.execute(query)
createTable(con)
createTable(tableName, con)
TestConnectionInfo(
ConnectionOptions(
connectionUrl,
"arcane",
"dbo",
"MsSqlConnectorsTests",
tableName,
Some("format(getdate(), 'yyyyMM')")), con)

def createTable(con: Connection): Unit =
val query = "use arcane; drop table if exists dbo.MsSqlConnectorsTests; create table dbo.MsSqlConnectorsTests(x int not null, y int)"
def createTable(tableName: String, con: Connection): Unit =
val query = s"use arcane; drop table if exists dbo.$tableName; create table dbo.$tableName (x int not null, y int)"
val statement = con.createStatement()
statement.executeUpdate(query)

val createPKCmd = "use arcane; alter table dbo.MsSqlConnectorsTests add constraint pk_MsSqlConnectorsTests primary key(x);"
val createPKCmd = s"use arcane; alter table dbo.$tableName add constraint pk_$tableName primary key(x);"
statement.executeUpdate(createPKCmd)

val enableCtCmd = "use arcane; alter table dbo.MsSqlConnectorsTests enable change_tracking;"
val enableCtCmd = s"use arcane; alter table dbo.$tableName enable change_tracking;"
statement.executeUpdate(enableCtCmd)

def insertData(con: Connection): Unit =
val statement = con.createStatement()
for i <- 1 to 10 do
val insertCmd = s"use arcane; insert into dbo.MsSqlConnectorsTests values($i, ${i+1})"
statement.execute(insertCmd)
Expand All @@ -71,7 +75,12 @@ class MsSqlConnectorsTests extends flatspec.AsyncFlatSpec with Matchers:


def withDatabase(test: TestConnectionInfo => Future[Assertion]): Future[Assertion] =
val conn = createDb()
val conn = createDb("MsSqlConnectorsTests")
insertData(conn.connection)
test(conn)

def withFreshTable(tableName: String)(test: TestConnectionInfo => Future[Assertion]): Future[Assertion] =
val conn = createDb(tableName)
test(conn)

"QueryProvider" should "generate columns query" in withDatabase { dbInfo =>
Expand Down Expand Up @@ -155,6 +164,6 @@ class MsSqlConnectorsTests extends flatspec.AsyncFlatSpec with Matchers:
result <- connection.getChanges(None, Duration.ofDays(1))
(_, latestVersion) = result
yield {
latestVersion should be > 0L
latestVersion should be >= 0L
}
}
Loading
Loading