Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement QueryProvider for MsSqlConnection class #70

Merged
merged 4 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- Lists the columns of one table together with a 0/1 flag telling whether each column
-- belongs to the table's primary key. Brace-delimited placeholders are substituted by the
-- caller before execution. Because the where clause requires a PRIMARY KEY constraint row,
-- a table without a primary key yields no rows at all.
select
    c.COLUMN_NAME,
    case when kcu.CONSTRAINT_NAME is not null then 1 else 0 end as IsPrimaryKey
from
    [{dbName}].INFORMATION_SCHEMA.COLUMNS c
    left join [{dbName}].INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc on c.TABLE_SCHEMA = tc.TABLE_SCHEMA and c.TABLE_NAME = tc.TABLE_NAME
    -- Constraint names are only unique within a schema, so the key-usage rows are matched on
    -- constraint schema and table as well; matching on the name alone could pick up a
    -- same-named constraint of another schema and flag a non-key column or duplicate rows.
    left join [{dbName}].INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu
        on tc.CONSTRAINT_SCHEMA = kcu.CONSTRAINT_SCHEMA
        and tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME
        and tc.TABLE_NAME = kcu.TABLE_NAME
        and c.COLUMN_NAME = kcu.COLUMN_NAME
where
    tc.CONSTRAINT_TYPE = N'PRIMARY KEY'
    and tc.TABLE_NAME = N'{table}'
    and tc.TABLE_SCHEMA = N'{schema}'
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Change-tracking delta query template. Brace-delimited placeholders are substituted by the
-- caller before execution, so comments in this file must never contain a placeholder token.
-- The current change-tracking version is read once so every returned row carries the same value.
declare @currentVersion bigint = CHANGE_TRACKING_CURRENT_VERSION()

SELECT
{ChangeTrackingColumnsStatement},
@currentVersion AS 'ChangeTrackingVersion',
-- Lower-cased hexadecimal text of the SHA2_256 digest of the merge expression.
lower(convert(nvarchar(128), HashBytes('SHA2_256', {MERGE_EXPRESSION}),2)) as [{MERGE_KEY}]
FROM [{dbName}].[{schema}].[{tableName}] tq
-- RIGHT JOIN keeps every row of the change table (alias ct) even when no row of the tracked
-- table (alias tq) satisfies the match condition; presumably this is what lets deletes through.
RIGHT JOIN (SELECT ct.* FROM CHANGETABLE (CHANGES [{dbName}].[{schema}].[{tableName}], {lastId}) ct ) ct ON {ChangeTrackingMatchStatement}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
-- Change-tracking delta query template, date-partitioned variant: identical to the plain
-- delta template except for one extra output column holding the partition expression.
-- Brace-delimited placeholders are substituted by the caller before execution, so comments
-- in this file must never contain a placeholder token.
-- The current change-tracking version is read once so every returned row carries the same value.
declare @currentVersion bigint = CHANGE_TRACKING_CURRENT_VERSION()

SELECT
{ChangeTrackingColumnsStatement},
@currentVersion AS 'ChangeTrackingVersion',
-- Lower-cased hexadecimal text of the SHA2_256 digest of the merge expression.
lower(convert(nvarchar(128), HashBytes('SHA2_256', {MERGE_EXPRESSION}),2)) as [{MERGE_KEY}],
-- Caller-supplied partition expression, exposed under the date partition key column name.
{DATE_PARTITION_EXPRESSION} as [{DATE_PARTITION_KEY}]
FROM [{dbName}].[{schema}].[{tableName}] tq
-- RIGHT JOIN keeps every row of the change table (alias ct) even when no row of the tracked
-- table (alias tq) satisfies the match condition; presumably this is what lets deletes through.
RIGHT JOIN (SELECT ct.* FROM CHANGETABLE (CHANGES [{dbName}].[{schema}].[{tableName}], {lastId}) ct ) ct ON {ChangeTrackingMatchStatement}
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
package com.sneaksanddata.arcane.framework
package services.mssql

import MsSqlConnection.{DATE_PARTITION_KEY, UPSERT_MERGE_KEY}

import com.microsoft.sqlserver.jdbc.SQLServerDriver

import java.sql.ResultSet
import java.util.Properties
import scala.annotation.tailrec
import scala.concurrent.{Future, blocking}
import scala.io.Source
import scala.util.Using

/**
* Summary of a single table column, as a `(name, isPrimaryKey)` pair.
* The first element is the column name; the second element is `true` when the column is part of
* the table's primary key and `false` otherwise.
*/
type ColumnSummary = (String, Boolean)

/**
* The text of a query to be executed on a Microsoft SQL Server database.
* This is a plain alias of `String`, so it adds readability but no extra type safety.
*/
type MsSqlQuery = String

/**
 * Settings that identify one table of a Microsoft SQL Server database and how to reach it.
 *
 * @param connectionUrl       JDBC connection URL handed to the SQL Server driver.
 * @param databaseName        Name of the database that contains the table.
 * @param schemaName          Name of the schema that contains the table.
 * @param tableName           Name of the table itself.
 * @param partitionExpression Optional SQL expression used as the date partition value; when it is
 *                            defined, the date-partitioned variant of the changes query is used.
 */
case class ConnectionOptions(
  connectionUrl: String,
  databaseName: String,
  schemaName: String,
  tableName: String,
  partitionExpression: Option[String]
)

/**
 * Represents a connection to a Microsoft SQL Server database.
 *
 * The underlying JDBC connection is opened eagerly while the instance is constructed and stays
 * open until [[close]] is called, so instances should be released by their owner.
 *
 * @param connectionOptions The connection options for the database.
 */
class MsSqlConnection(val connectionOptions: ConnectionOptions) extends AutoCloseable:
  private val driver = new SQLServerDriver()
  private val connection = driver.connect(connectionOptions.connectionUrl, new Properties())
  // Futures created by this class run on the global execution context.
  private implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global

  /**
   * Gets the column summaries for the table in the database.
   *
   * The query text is built synchronously; running it and reading its rows happens inside the
   * returned future. Any failure raised while doing so fails the future.
   *
   * @return A future containing the column summaries for the table in the database.
   */
  def getColumnSummaries: Future[List[ColumnSummary]] =
    val query = QueryProvider.getColumnSummariesQuery(connectionOptions.schemaName, connectionOptions.tableName, connectionOptions.databaseName)
    Future {
      // Every JDBC call below may block, so the whole unit of work is marked as blocking
      // (not only the row loop) to let the execution context compensate for the parked thread.
      blocking {
        Using.Manager { use =>
          val statement = use(connection.createStatement())
          val resultSet = use(statement.executeQuery(query))
          readColumns(resultSet, List.empty)
        }.get // rethrows a captured failure so that the future fails with it
      }
    }

  /**
   * Closes the connection to the database.
   */
  override def close(): Unit = connection.close()

  /**
   * Reads the remaining rows of the result set into column summaries.
   *
   * @param resultSet Result set whose first column is the column name and whose second column
   *                  is 1 for primary key columns.
   * @param result    Accumulator holding the rows read so far, most recent first; callers pass
   *                  `List.empty`.
   * @return The column summaries in the order in which the rows were read.
   */
  @tailrec
  private def readColumns(resultSet: ResultSet, result: List[ColumnSummary]): List[ColumnSummary] =
    if resultSet.next() then
      // Prepending is O(1); appending on every row would make the whole read quadratic.
      readColumns(resultSet, (resultSet.getString(1), resultSet.getInt(2) == 1) :: result)
    else
      result.reverse


/**
 * Companion of the MsSqlConnection class: shared output column names and a factory method.
 */
object MsSqlConnection:
  /**
   * Name of the column that carries the key used to merge rows in the output table.
   */
  val UPSERT_MERGE_KEY: String = "ARCANE_MERGE_KEY"

  /**
   * Name of the column that carries the key used to partition the output table by date.
   */
  val DATE_PARTITION_KEY: String = "DATE_PARTITION_KEY"

  /**
   * Factory for Microsoft SQL Server connections.
   *
   * @param connectionOptions The connection options for the database.
   * @return A newly constructed connection built from the given options.
   */
  def apply(connectionOptions: ConnectionOptions): MsSqlConnection =
    new MsSqlConnection(connectionOptions)

/**
 * Builds the T-SQL texts used against a Microsoft SQL Server database from the SQL templates
 * shipped as classpath resources.
 *
 * Placeholder values (database, schema and table names, partition expression) are substituted
 * into the templates verbatim, without escaping, so they must come from trusted configuration.
 */
object QueryProvider:
  // Execution context used to transform the future returned by the connection.
  private implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global

  /**
   * Gets the schema query for the Microsoft SQL Server database.
   *
   * The query is the changes query rendered with `Long.MaxValue` as the last change tracking
   * version (presumably so that no change rows match and only the result shape is obtained).
   *
   * @param msSqlConnection The connection to the database.
   * @return A future containing the schema query for the Microsoft SQL Server database.
   */
  def getSchemaQuery(msSqlConnection: MsSqlConnection): Future[MsSqlQuery] =
    msSqlConnection.getColumnSummaries
      .map(columnSummaries => {
        val mergeExpression = QueryProvider.getMergeExpression(columnSummaries, "tq")
        // NOTE(review): the delta query templates of this change appear to define only the
        // aliases `tq` (tracked table) and `ct` (change table). The alias "sq" used below is not
        // defined there, and SYS_CHANGE_* columns are requested from "tq" - confirm the intended
        // aliases before this query is executed against a server.
        val columnExpression = QueryProvider.getChangeTrackingColumns(columnSummaries, "tq", "sq")
        val matchStatement = QueryProvider.getMatchStatement(columnSummaries, "sq", "tq", None)
        QueryProvider.getChangesQuery(
          msSqlConnection.connectionOptions,
          mergeExpression,
          columnExpression,
          matchStatement,
          Long.MaxValue)
      })

  /**
   * Gets the column summaries query for the Microsoft SQL Server database.
   *
   * @param schemaName   The name of the schema.
   * @param tableName    The name of the table.
   * @param databaseName The name of the database.
   * @return The column summaries query for the Microsoft SQL Server database.
   */
  def getColumnSummariesQuery(schemaName: String, tableName: String, databaseName: String): MsSqlQuery =
    readResource("get_column_summaries.sql")
      .replace("{dbName}", databaseName)
      .replace("{schema}", schemaName)
      .replace("{table}", tableName)

  /**
   * Reads a classpath resource as text with its lines joined by `\n`.
   * The underlying source is always closed, including when reading fails.
   */
  private def readResource(resourceName: String): String =
    Using.resource(Source.fromResource(resourceName))(_.getLines().mkString("\n"))

  /**
   * Builds the expression hashed into the merge key: every primary key column of `tableAlias`
   * cast to `nvarchar(128)`, concatenated with `'#'` separators.
   */
  private def getMergeExpression(cs: List[ColumnSummary], tableAlias: String): String =
    cs.collect { case (name, true) => s"cast($tableAlias.[$name] as nvarchar(128))" }
      .mkString(" + '#' + ")

  /**
   * Builds the join condition between the two aliases: equality on every primary key column
   * and, when partition columns are given, additionally either a delete operation on the source
   * side or equality on every partition column.
   */
  private def getMatchStatement(cs: List[ColumnSummary], sourceAlias: String, outputAlias: String, partitionColumns: Option[List[String]]): String =
    val mainMatch = cs
      .collect { case (name, true) => s"$outputAlias.[$name] = $sourceAlias.[$name]" }
      .mkString(" and ")

    partitionColumns match
      // An empty column list is treated like no partitioning; it would otherwise render `OR ()`.
      case Some(columns) if columns.nonEmpty =>
        val partitionMatch = columns
          .map(column => s"$outputAlias.[$column] = $sourceAlias.[$column]")
          .mkString(" and ")
        // T-SQL compares with a single `=`; `==` is not a valid operator there.
        s"$mainMatch and ($sourceAlias.SYS_CHANGE_OPERATION = 'D' OR ($partitionMatch))"
      case _ => mainMatch

  /**
   * Builds the select list of the changes query: the primary key columns and the
   * SYS_CHANGE_VERSION / SYS_CHANGE_OPERATION columns taken from `changesAlias`, followed by the
   * remaining columns taken from `tableAlias`, separated by a comma and a line break.
   */
  private def getChangeTrackingColumns(tableColumns: List[ColumnSummary], changesAlias: String, tableAlias: String): String =
    val changeTrackingColumnNames = Set("SYS_CHANGE_VERSION", "SYS_CHANGE_OPERATION")
    val primaryKeyColumns = tableColumns.collect { case (name, true) => s"$changesAlias.[$name]" }
    val additionalColumns = List(s"$changesAlias.SYS_CHANGE_VERSION", s"$changesAlias.SYS_CHANGE_OPERATION")
    // Table columns named like the change tracking columns are skipped to avoid emitting them twice.
    val nonPrimaryKeyColumns = tableColumns.collect {
      case (name, false) if !changeTrackingColumnNames.contains(name) => s"$tableAlias.[$name]"
    }
    (primaryKeyColumns ++ additionalColumns ++ nonPrimaryKeyColumns).mkString(",\n")

  /**
   * Renders the changes query from its template. The date-partitioned template is used when
   * the connection options define a partition expression, the plain one otherwise.
   */
  private def getChangesQuery(connectionOptions: ConnectionOptions,
                              mergeExpression: String,
                              columnStatement: String,
                              matchStatement: String,
                              changeTrackingId: Long): String =
    val templateName =
      if connectionOptions.partitionExpression.isDefined then "get_select_delta_query_date_partitioned.sql"
      else "get_select_delta_query.sql"

    readResource(templateName)
      .replace("{dbName}", connectionOptions.databaseName)
      .replace("{schema}", connectionOptions.schemaName)
      .replace("{tableName}", connectionOptions.tableName)
      .replace("{ChangeTrackingColumnsStatement}", columnStatement)
      .replace("{ChangeTrackingMatchStatement}", matchStatement)
      .replace("{MERGE_EXPRESSION}", mergeExpression)
      .replace("{MERGE_KEY}", UPSERT_MERGE_KEY)
      .replace("{DATE_PARTITION_EXPRESSION}", connectionOptions.partitionExpression.getOrElse(""))
      .replace("{DATE_PARTITION_KEY}", DATE_PARTITION_KEY)
      .replace("{lastId}", changeTrackingId.toString)
Original file line number Diff line number Diff line change
@@ -1,29 +1,54 @@
package com.sneaksanddata.arcane.framework
package services.connectors.mssql

import services.connectors.mssql.DbServer.{createDb, removeDb}
import services.mssql.{ConnectionOptions, MsSqlConnection, QueryProvider}

import com.microsoft.sqlserver.jdbc.SQLServerDriver
import org.scalatest.*
import org.scalatest.flatspec.FixtureAnyFlatSpec
import org.scalatest.matchers.must.Matchers
import org.scalatest.matchers.should.Matchers.*

import java.sql.Connection
import java.util.Properties
import scala.concurrent.Future
import scala.language.postfixOps

object DbServer:
private val connectionUrl = "jdbc:sqlserver://localhost;encrypt=true;trustServerCertificate=true;username=sa;password=tMIxN11yGZgMC"
case class TestConnectionInfo(connectionOptions: ConnectionOptions, connection: Connection)

def createDb(): Connection = {
class MsSqlConnectorsTests extends flatspec.AsyncFlatSpec with Matchers:
implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global
val connectionUrl = "jdbc:sqlserver://localhost;encrypt=true;trustServerCertificate=true;username=sa;password=tMIxN11yGZgMC"

def createDb(): TestConnectionInfo =
val dr = new SQLServerDriver()
val con = dr.connect(connectionUrl, new Properties())
val query = "IF NOT EXISTS (SELECT * FROM sys.databases WHERE name = 'arcane') BEGIN CREATE DATABASE arcane; alter database Arcane set CHANGE_TRACKING = ON (CHANGE_RETENTION = 2 DAYS, AUTO_CLEANUP = ON); END;"
val statement = con.createStatement()
statement.execute(query)
con
}
createTable(con)
TestConnectionInfo(
ConnectionOptions(
connectionUrl,
"arcane",
"dbo",
"MsSqlConnectorsTests",
Some("format(getdate(), 'yyyyMM')")), con)

/**
 * Recreates the `dbo.MsSqlConnectorsTests` table in the `arcane` database: drops it if present,
 * creates it with a primary key on `x`, enables change tracking on it and inserts ten rows.
 *
 * @param con An open connection to the test SQL Server instance; it is left open.
 */
def createTable(con: Connection): Unit =
  // Using.resource closes the statement even when one of the commands below throws.
  scala.util.Using.resource(con.createStatement()) { statement =>
    val query = "use arcane; drop table if exists dbo.MsSqlConnectorsTests; create table dbo.MsSqlConnectorsTests(x int not null, y int)"
    statement.executeUpdate(query)

    val createPKCmd = "use arcane; alter table dbo.MsSqlConnectorsTests add constraint pk_MsSqlConnectorsTests primary key(x);"
    statement.executeUpdate(createPKCmd)

    val enableCtCmd = "use arcane; alter table dbo.MsSqlConnectorsTests enable change_tracking;"
    statement.executeUpdate(enableCtCmd)

    for i <- 1 to 10 do
      val insertCmd = s"use arcane; insert into dbo.MsSqlConnectorsTests values($i, ${i+1})"
      statement.execute(insertCmd)
  }


def removeDb(): Unit =
val query = "DROP DATABASE arcane"
Expand All @@ -32,32 +57,24 @@ object DbServer:
val statement = con.createStatement()
statement.execute(query)

trait DbFixture:

this: FixtureTestSuite =>

type FixtureParam = Connection

// Allow clients to populate the database after
// it is created
def populateDb(db: Connection): Unit = {}

def withFixture(test: OneArgTest): Outcome =
def withDatabase(test: TestConnectionInfo => Future[Assertion]): Future[Assertion] =
val conn = createDb()
try {
populateDb(conn)
withFixture(test.toNoArgTest(conn)) // "loan" the fixture to the test
}
finally removeDb() // clean up the fixture



class MsSqlConnectorsTests extends flatspec.FixtureAnyFlatSpec with Matchers with DbFixture:

override def populateDb(db: Connection): Unit =
println("ScalaTest is ")
test(conn)

"QueryProvider" should "generate columns query" in withDatabase { dbInfo =>
  // The connector opens a JDBC connection on construction; release it once the check is done.
  scala.util.Using.resource(MsSqlConnection(dbInfo.connectionOptions)) { connector =>
    val options = connector.connectionOptions
    val query = QueryProvider.getColumnSummariesQuery(options.schemaName, options.tableName, options.databaseName)
    query should include ("case when kcu.CONSTRAINT_NAME is not null then 1 else 0 end as IsPrimaryKey")
  }
}

"Testing" should "be easy" in { db =>
println("easy!")
assert(true)
"QueryProvider" should "generate schema query" in withDatabase { dbInfo =>
  val connector = MsSqlConnection(dbInfo.connectionOptions)
  // QueryProvider declares this method as `getSchemaQuery`; Scala identifiers are case-sensitive.
  QueryProvider.getSchemaQuery(connector)
    .map { query =>
      query should (
        include ("tq.SYS_CHANGE_VERSION") and include ("ARCANE_MERGE_KEY") and include("format(getdate(), 'yyyyMM')")
      )
    }
    // Release the JDBC connection whether the query generation succeeded or failed.
    .andThen { case _ => connector.close() }
}
Loading