Skip to content

Commit

Permalink
Fix issue when VersionedResult can lose first row of the data
Browse files Browse the repository at this point in the history
#61

This PR fixes the issue which can occur when the method `getLatestVersion` of the `HasVersion` typeclass peeks the first element of the `VersionedBatch`.

In this case the stateful `ResultSet` inside this class moves forward and the first row of the set is being lost.

To avoid this issue, the `CanPeekHead` trait was added. The implementation of this trait saves the first row of the `ResultSet` when the head of the list is requested.
  • Loading branch information
s-vitaliy committed Nov 8, 2024
1 parent 50f2fca commit 4dda70f
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package services.mssql
import models.{ArcaneSchema, ArcaneType, Field}
import services.base.{CanAdd, SchemaProvider}
import services.mssql.MsSqlConnection.{DATE_PARTITION_KEY, UPSERT_MERGE_KEY, VersionedBatch, toArcaneType}
import services.mssql.base.QueryResult
import services.mssql.base.{CanPeekHead, QueryResult}
import services.mssql.query.{LazyQueryResult, QueryRunner, ScalarQueryResult}

import com.microsoft.sqlserver.jdbc.SQLServerDriver
Expand Down Expand Up @@ -90,7 +90,7 @@ class MsSqlConnection(val connectionOptions: ConnectionOptions) extends AutoClos
* @param arcaneSchema The schema for the data produced by Arcane.
* @return A future containing the result of the backfill.
*/
def backfill(arcaneSchema: ArcaneSchema)(using queryRunner: QueryRunner): Future[QueryResult[LazyQueryResult.OutputType]] =
def backfill(arcaneSchema: ArcaneSchema)(using queryRunner: QueryRunner[LazyQueryResult.OutputType, LazyQueryResult]): Future[QueryResult[LazyQueryResult.OutputType]] =
for query <- QueryProvider.getBackfillQuery(this)
result <- queryRunner.executeQuery(query, connection, LazyQueryResult.apply)
yield result
Expand All @@ -101,14 +101,16 @@ class MsSqlConnection(val connectionOptions: ConnectionOptions) extends AutoClos
* @param lookBackInterval The look back interval for the query.
* @return A future containing the changes in the database since the given version and the latest observed version.
*/
def getChanges(maybeLatestVersion: Option[Long], lookBackInterval: Duration)(using queryRunner: QueryRunner): Future[VersionedBatch] =
def getChanges(maybeLatestVersion: Option[Long], lookBackInterval: Duration)
(using queryRunner: QueryRunner[LazyQueryResult.OutputType, LazyQueryResult],
versionQueryRunner: QueryRunner[Option[Long], ScalarQueryResult[Long]]): Future[VersionedBatch] =
val query = QueryProvider.getChangeTrackingVersionQuery(connectionOptions.databaseName, maybeLatestVersion, lookBackInterval)

for versionResult <- queryRunner.executeQuery(query, connection, (st, rs) => ScalarQueryResult.apply(st, rs, readChangeTrackingVersion))
for versionResult <- versionQueryRunner.executeQuery(query, connection, (st, rs) => ScalarQueryResult.apply(st, rs, readChangeTrackingVersion))
version = versionResult.read.getOrElse(Long.MaxValue)
changesQuery <- QueryProvider.getChangesQuery(this, version - 1)
result <- queryRunner.executeQuery(changesQuery, connection, LazyQueryResult.apply)
yield (result, version)
yield MsSqlConnection.ensureHead((result, maybeLatestVersion.getOrElse(0)))

private def readChangeTrackingVersion(resultSet: ResultSet): Option[Long] =
resultSet.getMetaData.getColumnType(1) match
Expand Down Expand Up @@ -216,8 +218,15 @@ object MsSqlConnection:
/**
* Represents a versioned batch of data.
*/
type VersionedBatch = (QueryResult[LazyQueryResult.OutputType], Long)
type VersionedBatch = (QueryResult[LazyQueryResult.OutputType] & CanPeekHead[LazyQueryResult.OutputType], Long)

/**
 * Makes sure the first row of the batch (if any) is retained by the query result, so it is
 * not lost when the head of the result is inspected before the batch is read in full.
 *
 * @param result The versioned batch to protect.
 * @return A batch with the same version, whose query result is the one returned by `peekHead`.
 */
private def ensureHead(result: VersionedBatch): VersionedBatch =
  result match
    case (rows, batchVersion) => (rows.peekHead, batchVersion)

object QueryProvider:
private implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package services.mssql.base
/**
* Represents the result of a query to a SQL database.
*/
trait QueryResult[Output] {
trait QueryResult[Output]:

/**
* The output type of the query result.
Expand All @@ -18,4 +18,15 @@ trait QueryResult[Output] {
*/
def read: OutputType

}
/**
 * Capability of a query result to have its head inspected without losing it.
 *
 * Intended to be mixed into a [[QueryResult]]: `peekHead` hands back a result that is
 * again both a `QueryResult` and a `CanPeekHead`.
 *
 * @tparam Output The type of the output of the query.
 */
trait CanPeekHead[Output]:
/**
 * Peeks the head of the result of the SQL query mapped to an output type.
 *
 * Note that the return value is a query result, not the head element itself.
 * NOTE(review): implementations are presumably expected to retain the peeked head so that
 * reading the returned result still yields it — confirm against each implementation.
 *
 * @return A query result that can be read and peeked again.
 */
def peekHead: QueryResult[Output] & CanPeekHead[Output]
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package services.mssql.query

import models.{DataCell, DataRow}
import services.mssql.MsSqlConnection.toArcaneType
import services.mssql.base.{QueryResult, ResultSetOwner}
import services.mssql.base.{CanPeekHead, QueryResult, ResultSetOwner}

import java.sql.{ResultSet, Statement}
import scala.annotation.tailrec
Expand All @@ -16,7 +16,8 @@ import scala.util.{Failure, Success, Try}
* @param statement The statement used to execute the query.
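* @param resultSet The result set of the query.
* @param eagerHead Rows that were already read from the result set; `read` yields them before the remaining rows.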
*/
class LazyQueryResult(protected val statement: Statement, resultSet: ResultSet) extends QueryResult[LazyList[DataRow]] with ResultSetOwner:
class LazyQueryResult(protected val statement: Statement, resultSet: ResultSet, eagerHead: List[DataRow]) extends QueryResult[LazyList[DataRow]]
with ResultSetOwner with CanPeekHead[LazyList[DataRow]]:

/**
* Reads the result of the query.
Expand All @@ -25,7 +26,7 @@ class LazyQueryResult(protected val statement: Statement, resultSet: ResultSet)
*/
override def read: this.OutputType =
val columns = resultSet.getMetaData.getColumnCount
LazyList.continually(resultSet)
eagerHead.to(LazyList) #::: LazyList.continually(resultSet)
.takeWhile(_.next())
.map(row => {
toDataRow(row, columns, List.empty) match {
Expand All @@ -34,6 +35,14 @@ class LazyQueryResult(protected val statement: Statement, resultSet: ResultSet)
}
})

/**
 * Reads the first row of the result (if there is one) and returns a new result that retains it.
 *
 * Taking the head may advance the underlying `ResultSet`, so the row is passed to the new
 * [[LazyQueryResult]] as its eager head; reading that instance yields the row first and
 * then continues with the rows still left in the result set.
 *
 * @return A new query result over the same statement and result set, with the head row kept.
 */
def peekHead: QueryResult[this.OutputType] & CanPeekHead[this.OutputType] =
  val bufferedHead = read.headOption.toList
  new LazyQueryResult(statement, resultSet, bufferedHead)

@tailrec
private def toDataRow(row: ResultSet, columns: Int, acc: DataRow): Try[DataRow] =
if columns == 0 then Success(acc)
Expand Down Expand Up @@ -62,5 +71,5 @@ object LazyQueryResult {
* @param resultSet The result set of the query.
* @return The new [[LazyQueryResult]] object.
*/
def apply(statement: Statement, resultSet: ResultSet): LazyQueryResult = new LazyQueryResult(statement, resultSet)
def apply(statement: Statement, resultSet: ResultSet): LazyQueryResult = new LazyQueryResult(statement, resultSet, List.empty)
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ import scala.concurrent.{Future, blocking}
* cursor.
*
*/
class QueryRunner:
class QueryRunner[Output, QueryResultType <: QueryResult[Output]]:

/**
* A factory for creating a query result from a statement and a result set.
*
* The produced type is the class-level `QueryResultType`, whose output type is `Output`.
*/
private type ResultFactory[Output] = (Statement, ResultSet) => QueryResult[Output]
private type ResultFactory = (Statement, ResultSet) => QueryResultType

private implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global

Expand All @@ -33,7 +33,7 @@ class QueryRunner:
* @param connection The connection to execute the query on.
* @return The result of the query.
*/
def executeQuery[Result](query: MsSqlQuery, connection: Connection, resultFactory: ResultFactory[Result]): Future[QueryResult[Result]] =
def executeQuery(query: MsSqlQuery, connection: Connection, resultFactory: ResultFactory): Future[QueryResultType] =
Future {
val statement = connection.createStatement()
val resultSet = blocking {
Expand All @@ -44,4 +44,5 @@ class QueryRunner:


object QueryRunner:
def apply(): QueryRunner = new QueryRunner()
def apply[O, T <: QueryResult[O]](): QueryRunner[O, T] = new QueryRunner[O, T]()

Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class ScalarQueryResult[Result](val statement: Statement, resultSet: ResultSet,
None
case _ => None


/**
* Companion object for [[ScalarQueryResult]].
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import models.DataCell
import services.mssql.MsSqlConnection
import services.mssql.MsSqlConnection.VersionedBatch
import services.mssql.query.LazyQueryResult.OutputType
import services.mssql.query.QueryRunner
import services.mssql.query.{LazyQueryResult, QueryRunner, ScalarQueryResult}
import services.streaming.base.{HasVersion, StreamLifetimeService}

import zio.stream.ZStream
Expand Down Expand Up @@ -36,7 +36,8 @@ given HasVersion[VersionedBatch] with


class StreamGraphBuilder(msSqlConnection: MsSqlConnection) {
implicit val queryRunner: QueryRunner = QueryRunner()
implicit val dataQueryRunner: QueryRunner[LazyQueryResult.OutputType, LazyQueryResult] = QueryRunner()
implicit val versionQueryRunner: QueryRunner[Option[Long], ScalarQueryResult[Long]] = QueryRunner()

/**
* Builds a stream that reads the changes from the database.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,37 +1,40 @@
package com.sneaksanddata.arcane.framework
package services.streaming

import services.mssql.query.QueryRunner
import services.mssql.query.{LazyQueryResult, QueryRunner, ScalarQueryResult}
import services.mssql.{ConnectionOptions, MsSqlConnection}
import services.streaming.base.StreamLifetimeService

import com.microsoft.sqlserver.jdbc.SQLServerDriver
import org.scalatest.*
import org.scalatest.matchers.must.Matchers
import org.scalatest.matchers.should.Matchers.*
import zio.stream.ZSink
import zio.{Runtime, Unsafe}

import java.sql.Connection
import java.util.Properties
import scala.List
import scala.concurrent.Future
import scala.language.postfixOps
import scala.util.Using

/**
 * Bundles what a test needs to talk to the test database.
 *
 * @param connectionOptions The options used to build the `MsSqlConnection` under test.
 * @param connection        The raw JDBC connection used by the fixture helpers.
 */
case class TestConnectionInfo(connectionOptions: ConnectionOptions, connection: Connection)

class StreamGraphBuilderTests extends flatspec.AsyncFlatSpec with Matchers:
implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global
private implicit val queryRunner: QueryRunner = QueryRunner()
private implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global
private implicit val dataQueryRunner: QueryRunner[LazyQueryResult.OutputType, LazyQueryResult] = QueryRunner()
private implicit val versionQueryRunner: QueryRunner[Option[Long], ScalarQueryResult[Long]] = QueryRunner()
private val connectionUrl = "jdbc:sqlserver://localhost;encrypt=true;trustServerCertificate=true;username=sa;password=tMIxN11yGZgMC;databaseName=arcane"
private val runtime = Runtime.default

def createDb(): TestConnectionInfo =
def createDb(tableName: String): TestConnectionInfo =
val dr = new SQLServerDriver()
val con = dr.connect(connectionUrl, new Properties())
val query = "IF NOT EXISTS (SELECT * FROM sys.databases WHERE name = 'arcane') BEGIN CREATE DATABASE arcane; alter database Arcane set CHANGE_TRACKING = ON (CHANGE_RETENTION = 2 DAYS, AUTO_CLEANUP = ON); END;"
val statement = con.createStatement()
statement.execute(query)
createTable(con)
createTable(con, tableName)
TestConnectionInfo(
ConnectionOptions(
connectionUrl,
Expand All @@ -40,24 +43,27 @@ class StreamGraphBuilderTests extends flatspec.AsyncFlatSpec with Matchers:
"StreamGraphBuilderTests",
Some("format(getdate(), 'yyyyMM')")), con)

def createTable(con: Connection): Unit =
val query = "use arcane; drop table if exists dbo.StreamGraphBuilderTests; create table dbo.StreamGraphBuilderTests(x int not null, y int)"
/**
 * Inserts ten rows `(i, i + 1)` for `i` from 0 to 9 into `dbo.<tableName>` in the `arcane` database.
 *
 * @param con       An open connection to the SQL Server instance.
 * @param tableName The unqualified name of the table to insert into.
 */
def insertData(con: Connection, tableName: String): Unit =
  val sql = s"use arcane; insert into dbo.$tableName values(?, ?)"
  // Using.resource closes the statement and rethrows any failure. A plain `Using(...)` returns a
  // Try, which this Unit method would discard, so a broken fixture would go unnoticed.
  Using.resource(con.prepareStatement(sql)) { insertStatement =>
    for i <- 0 to 9 do
      insertStatement.setInt(1, i)
      insertStatement.setInt(2, i + 1)
      insertStatement.addBatch()
      insertStatement.clearParameters()
    insertStatement.executeBatch()
  }

/**
 * Recreates `dbo.<tableName>` in the `arcane` database with a primary key on `x`
 * and change tracking enabled. The table is left empty.
 *
 * @param con       An open connection to the SQL Server instance.
 * @param tableName The unqualified name of the table to (re)create.
 */
def createTable(con: Connection, tableName: String): Unit =
  // Using.resource closes the statement even when one of the DDL commands fails.
  Using.resource(con.createStatement()) { statement =>
    // The created table must be the one that is dropped here and altered below, so every
    // command uses tableName instead of a fixed table name.
    val query = s"use arcane; drop table if exists dbo.$tableName; create table dbo.$tableName(x int not null, y int)"
    statement.executeUpdate(query)

    // The constraint name follows the table name so that two test tables cannot clash.
    val createPKCmd = s"use arcane; alter table dbo.$tableName add constraint pk_$tableName primary key(x);"
    statement.executeUpdate(createPKCmd)

    val enableCtCmd = s"use arcane; alter table dbo.$tableName enable change_tracking;"
    statement.executeUpdate(enableCtCmd)
  }


Expand All @@ -69,12 +75,18 @@ class StreamGraphBuilderTests extends flatspec.AsyncFlatSpec with Matchers:
statement.execute(query)


def withFreshDatabase(test: TestConnectionInfo => Future[Assertion]): Future[Assertion] =
/**
 * Fixture that runs a test against a newly created table populated by `insertData`.
 *
 * The database is removed first, then recreated together with the table, and the rows are
 * inserted before the test body is invoked.
 *
 * @param tableName The unqualified name of the table to create and populate.
 * @param test      The test body, given the connection information for the fresh table.
 * @return The assertion produced by the test body.
 */
def withFreshTable(tableName: String)(test: TestConnectionInfo => Future[Assertion]): Future[Assertion] =
  removeDb()
  val dbInfo = createDb(tableName)
  insertData(dbInfo.connection, tableName)
  test(dbInfo)

def withEmptyTable(tableName: String)(test: TestConnectionInfo => Future[Assertion]): Future[Assertion] =
removeDb()
val conn = createDb()
val conn = createDb(tableName)
test(conn)

"StreamGraph" should "not duplicate data on the first iteration" in withFreshDatabase { dbInfo =>
"StreamGraph" should "not duplicate data on the first iteration" in withFreshTable("StreamGraphBuilderTests") { dbInfo =>
val streamGraphBuilder = new StreamGraphBuilder(MsSqlConnection(dbInfo.connectionOptions))

val lifetime = TestStreamLifetimeService(3)
Expand All @@ -89,7 +101,7 @@ class StreamGraphBuilderTests extends flatspec.AsyncFlatSpec with Matchers:
}
}

"StreamGraph" should "be able to generate changes stream" in withFreshDatabase { dbInfo =>
"StreamGraph" should "be able to generate changes stream" in withFreshTable("StreamGraphBuilderTests") { dbInfo =>
val streamGraphBuilder = new StreamGraphBuilder(MsSqlConnection(dbInfo.connectionOptions))

val lifetime = TestStreamLifetimeService(3, counter => {
Expand Down

0 comments on commit 4dda70f

Please sign in to comment.