Implement SQL server type parsing (#72)
* Implement SQL server type parsing

## Scope:
- Implement an intermediate abstraction, `ArcaneSchema`, for data schema translation
- Add the ability to read the column schema, including column types, to `MsSqlConnection` (see the usage sketch below)

* Remove `StructType`, as it is not used right now

* Review fixes
s-vitaliy authored Nov 4, 2024
1 parent c1cc7cf commit b27b6da
Showing 4 changed files with 121 additions and 15 deletions.
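To make the new abstraction concrete before the diff, here is a minimal usage sketch in Scala; the column names and the `@main` wrapper are illustrative and not part of the commit:

```scala
import com.sneaksanddata.arcane.framework.models.{ArcaneSchema, ArcaneType, Field}

@main def arcaneSchemaSketch(): Unit =
  // ArcaneSchema is just a Seq[Field], so a schema can be built with ordinary appends.
  val schema: ArcaneSchema =
    ArcaneSchema.empty() :+ Field("id", ArcaneType.LongType) :+ Field("name", ArcaneType.StringType)

  // Standard collection operations apply; print each column with its Arcane type.
  schema.foreach { case Field(name, fieldType) => println(s"$name -> $fieldType") }
```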
@@ -0,0 +1,43 @@
package com.sneaksanddata.arcane.framework
package models

/**
* ArcaneSchema is a type alias for a sequence of fields or structs.
*/
type ArcaneSchema = Seq[Field]

/**
* Companion object for ArcaneSchema.
*/
object ArcaneSchema:
  /**
   * Creates an empty ArcaneSchema.
   * @return An empty ArcaneSchema.
   */
  def empty(): ArcaneSchema = Seq.empty


/**
* Types of fields in ArcaneSchema.
*/
enum ArcaneType:
  case LongType
  case ByteArrayType
  case BooleanType
  case StringType
  case DateType
  case TimestampType
  case DateTimeOffsetType
  case BigDecimalType
  case DoubleType
  case IntType
  case FloatType
  case ShortType
  case TimeType

/**
* Field is a case class that represents a field in ArcaneSchema
* @param name The name of the field.
* @param fieldType The type of the field.
*/
case class Field(name: String, fieldType: ArcaneType)
@@ -1,19 +1,41 @@
package com.sneaksanddata.arcane.framework
package services.base

import models.ArcaneType

import scala.concurrent.Future

/**
* Type class that represents the ability to add a field to a schema.
* @tparam Schema The type of the schema.
*/
trait CanAdd[Schema]:
  /**
   * Adds a field to the schema.
   *
   * @return The schema with the field added.
   */
  extension (a: Schema) def addField(fieldName: String, fieldType: ArcaneType): Schema

/**
* Represents a provider of a schema for the data produced by Arcane.
*
* @tparam Schema The type of the schema.
*/
trait SchemaProvider[Schema] {
trait SchemaProvider[Schema: CanAdd] {

  type SchemaType = Schema
  /**
   * Gets the schema for the data produced by Arcane.
   *
   * @return A future containing the schema for the data produced by Arcane.
   */
  def getSchema: Future[Schema]
  def getSchema: Future[SchemaType]

  /**
   * Gets an empty schema.
   *
   * @return An empty schema.
   */
  def empty: SchemaType
}
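The `Schema: CanAdd` context bound above is what lets callers grow a schema without knowing its concrete representation. A brief, hypothetical sketch of that pattern (the `withAuditColumn` helper and its field name are not part of the commit; it relies on a `CanAdd` given being in scope, such as the `CanAdd[ArcaneSchema]` instance defined in the MsSqlConnection file below):

```scala
import com.sneaksanddata.arcane.framework.models.ArcaneType
import com.sneaksanddata.arcane.framework.services.base.CanAdd

// Compiles for any Schema that has a CanAdd instance in scope;
// the addField extension method is resolved from that given.
def withAuditColumn[Schema: CanAdd](schema: Schema): Schema =
  schema.addField("AUDIT_TIMESTAMP", ArcaneType.TimestampType)
```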
@@ -1,11 +1,11 @@
package com.sneaksanddata.arcane.framework
package services.mssql

import services.base.SchemaProvider
import models.{ArcaneSchema, ArcaneType, Field}
import services.base.{CanAdd, SchemaProvider}
import services.mssql.MsSqlConnection.{DATE_PARTITION_KEY, UPSERT_MERGE_KEY}

import com.microsoft.sqlserver.jdbc.SQLServerDriver
import io.delta.kernel.types.{IntegerType, StructType}

import java.sql.ResultSet
import java.util.Properties
@@ -45,12 +45,18 @@ case class ConnectionOptions(connectionUrl: String,
                             tableName: String,
                             partitionExpression: Option[String])

/**
* The CanAdd typeclass instance for ArcaneSchema, required by SchemaProvider.
*/
given CanAdd[ArcaneSchema] with
  extension (a: ArcaneSchema) def addField(fieldName: String, fieldType: ArcaneType): ArcaneSchema = a :+ Field(fieldName, fieldType)

/**
* Represents a connection to a Microsoft SQL Server database.
*
* @param connectionOptions The connection options for the database.
*/
class MsSqlConnection(val connectionOptions: ConnectionOptions) extends AutoCloseable with SchemaProvider[StructType]:
class MsSqlConnection(val connectionOptions: ConnectionOptions) extends AutoCloseable with SchemaProvider[ArcaneSchema]:
  private val driver = new SQLServerDriver()
  private val connection = driver.connect(connectionOptions.connectionUrl, new Properties())
  private implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global
@@ -79,15 +85,22 @@ class MsSqlConnection(val connectionOptions: ConnectionOptions) extends AutoClos
   */
  override def close(): Unit = connection.close()

  /**
   * Gets an empty schema.
   *
   * @return An empty schema.
   */
  override def empty: this.SchemaType = ArcaneSchema.empty()

  /**
   * Gets the schema for the data produced by Arcane.
   *
   * @return A future containing the schema for the data produced by Arcane.
   */
  override def getSchema: Future[StructType] =
  override def getSchema: Future[this.SchemaType] =
    for query <- QueryProvider.getSchemaQuery(this)
        sqlSchema <- getSqlSchema(query)
    yield toSchema(sqlSchema, StructType())
    yield toSchema(sqlSchema, empty)

  private def getSqlSchema(query: String): Future[SqlSchema] = Future {
    val columns = Using.Manager { use =>
@@ -102,10 +115,12 @@ class MsSqlConnection(val connectionOptions: ConnectionOptions) extends AutoClos
  }

  @tailrec
  private def toSchema(sqlSchema: SqlSchema, schema: StructType): StructType =
  private def toSchema(sqlSchema: SqlSchema, schema: this.SchemaType): this.SchemaType =
    sqlSchema match
      case Nil => schema
      case x +: xs => toSchema(xs, schema.add(x._1, IntegerType.INTEGER))
      case x +: xs =>
        val (name, fieldType) = x
        toSchema(xs, schema.addField(name, toArcaneType(fieldType)))

  @tailrec
  private def readColumns(resultSet: ResultSet, result: List[ColumnSummary]): List[ColumnSummary] =
@@ -115,6 +130,22 @@ class MsSqlConnection(val connectionOptions: ConnectionOptions) extends AutoClos
      return result
    readColumns(resultSet, result ++ List((resultSet.getString(1), resultSet.getInt(2) == 1)))

  private def toArcaneType(sqlType: Int): ArcaneType = sqlType match
    case java.sql.Types.BIGINT => ArcaneType.LongType
    case java.sql.Types.BINARY => ArcaneType.ByteArrayType
    case java.sql.Types.BIT => ArcaneType.BooleanType
    case java.sql.Types.CHAR => ArcaneType.StringType
    case java.sql.Types.DATE => ArcaneType.DateType
    case java.sql.Types.TIMESTAMP => ArcaneType.TimestampType
    case java.sql.Types.TIMESTAMP_WITH_TIMEZONE => ArcaneType.DateTimeOffsetType
    case java.sql.Types.DECIMAL => ArcaneType.BigDecimalType
    case java.sql.Types.DOUBLE => ArcaneType.DoubleType
    case java.sql.Types.INTEGER => ArcaneType.IntType
    case java.sql.Types.FLOAT => ArcaneType.FloatType
    case java.sql.Types.SMALLINT => ArcaneType.ShortType
    case java.sql.Types.TIME => ArcaneType.TimeType
    case java.sql.Types.NCHAR => ArcaneType.StringType
    case java.sql.Types.NVARCHAR => ArcaneType.StringType

object MsSqlConnection:
/**
@@ -1,6 +1,8 @@
package com.sneaksanddata.arcane.framework
package services.connectors.mssql

import models.ArcaneType.{IntType, LongType, StringType}
import models.Field
import services.mssql.{ConnectionOptions, MsSqlConnection, QueryProvider}

import com.microsoft.sqlserver.jdbc.SQLServerDriver
Expand All @@ -10,7 +12,7 @@ import org.scalatest.matchers.should.Matchers.*

import java.sql.Connection
import java.util.Properties
import scala.jdk.CollectionConverters.ListHasAsScala
import scala.List
import scala.concurrent.Future
import scala.language.postfixOps

@@ -40,10 +42,10 @@ class MsSqlConnectorsTests extends flatspec.AsyncFlatSpec with Matchers:
    val statement = con.createStatement()
    statement.executeUpdate(query)

    val createPKCmd = "use arcane; alter table dbo.MsSqlConnectorsTests add constraint pk_MsSqlConnectorsTests primary key(x);";
    val createPKCmd = "use arcane; alter table dbo.MsSqlConnectorsTests add constraint pk_MsSqlConnectorsTests primary key(x);"
    statement.executeUpdate(createPKCmd)

    val enableCtCmd = "use arcane; alter table dbo.MsSqlConnectorsTests enable change_tracking;";
    val enableCtCmd = "use arcane; alter table dbo.MsSqlConnectorsTests enable change_tracking;"
    statement.executeUpdate(enableCtCmd)

    for i <- 1 to 10 do
@@ -81,10 +83,18 @@ class MsSqlConnectorsTests extends flatspec.AsyncFlatSpec with Matchers:
  }

  "MsSqlConnection" should "be able to extract schema column names from the database" in withDatabase { dbInfo =>
    val dataColumns = List("x", "y")
    val generatedColumns = List("SYS_CHANGE_VERSION", "SYS_CHANGE_OPERATION", "ChangeTrackingVersion", "ARCANE_MERGE_KEY", "DATE_PARTITION_KEY")
    val connection = MsSqlConnection(dbInfo.connectionOptions)
    connection.getSchema map { schema =>
      schema.fields.asScala map { f => f.getName } should contain theSameElementsAs dataColumns ++ generatedColumns
      val fields = for column <- schema if column.isInstanceOf[Field] yield column.name
      fields should be (List("x", "SYS_CHANGE_VERSION", "SYS_CHANGE_OPERATION", "y", "ChangeTrackingVersion", "ARCANE_MERGE_KEY", "DATE_PARTITION_KEY"))
    }
  }


  "MsSqlConnection" should "be able to extract schema column types from the database" in withDatabase { dbInfo =>
    val connection = MsSqlConnection(dbInfo.connectionOptions)
    connection.getSchema map { schema =>
      val fields = for column <- schema if column.isInstanceOf[Field] yield column.fieldType
      fields should be(List(IntType, LongType, StringType, IntType, LongType, StringType, StringType))
    }
  }
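Taken together, consumers of the connector now receive typed Arcane columns instead of a Delta `StructType`. A hypothetical consumer sketch (the `logSchema` helper is illustrative and assumes an already-constructed `MsSqlConnection`):

```scala
import com.sneaksanddata.arcane.framework.models.Field
import com.sneaksanddata.arcane.framework.services.mssql.MsSqlConnection

import scala.concurrent.{ExecutionContext, Future}

// Prints each column name together with the ArcaneType parsed from the SQL Server metadata.
def logSchema(connection: MsSqlConnection)(using ExecutionContext): Future[Unit] =
  connection.getSchema.map { schema =>
    schema.foreach { case Field(name, fieldType) => println(s"$name: $fieldType") }
  }
```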
