
Commit

feat: UNIC-1512 Implement Jdbc Upsert
Laura Bégin committed May 30, 2024
1 parent 4a6a66c commit a5a884f
Showing 5 changed files with 136 additions and 13 deletions.
2 changes: 2 additions & 0 deletions build.sbt
@@ -70,6 +70,8 @@ lazy val `datalake-spark3` = (project in file("datalake-spark3"))
"com.softwaremill.sttp.client3" %% "slf4j-backend" % "3.9.2",
"com.dimafeng" %% "testcontainers-scala-scalatest" % "0.41.0" % Test,
"com.dimafeng" %% "testcontainers-scala-elasticsearch" % "0.41.2" % Test,
"com.dimafeng" %% "testcontainers-scala-postgresql" % "0.41.3" % Test,
"org.postgresql" % "postgresql" % "42.5.1" % Test,
"org.scalatest" %% "scalatest" % scalatestVersion % Test,
"org.apache.spark" %% "spark-hive" % spark3Version % Test,

@@ -1,12 +1,15 @@
package bio.ferlab.datalake.spark3.loader
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

import org.apache.spark.sql.functions.coalesce
import org.apache.spark.sql.{Column, DataFrame, SaveMode, SparkSession}

import java.time.LocalDate

object JdbcLoader extends Loader {

/**
* Default read logic for a loader
*
* @param location absolute path of where the data is
* @param format string representing the format
* @param readOptions read options
@@ -36,15 +39,15 @@ object JdbcLoader {
}

/**
 * @param location     where to write the data
 * @param databaseName database name
 * @param tableName    table name
 * @param df           new data to write into the table
 * @param partitioning how the data is partitioned
 * @param format       format
 * @param options      write options
 * @param spark        a spark session
 * @return updated data
 */
override def writeOnce(location: String,
databaseName: String,
@@ -91,11 +94,12 @@ object JdbcLoader {
/**
 * Update or insert data into a table.
 * Resolves duplicates by using the list of primary keys passed as argument.
 *
 * @param location    full path of where the data will be located
 * @param tableName   the name of the updated/created table
 * @param updates     new data to be merged with existing data
 * @param primaryKeys name of the columns holding the unique id
 * @param spark       a valid spark session
 * @return the data as a dataframe
 */
override def upsert(location: String,
@@ -105,7 +109,30 @@
primaryKeys: Seq[String],
partitioning: List[String],
format: String,
options: Map[String, String])(implicit spark: SparkSession): DataFrame = ???
options: Map[String, String])(implicit spark: SparkSession): DataFrame = {
import spark.implicits._

require(primaryKeys.nonEmpty, "Primary keys are required for an Upsert write.")
require(primaryKeys.forall(updates.columns.contains), s"Columns [${primaryKeys.mkString(", ")}] are required in the DataFrame.")

val readOptions = options + ("dbtable" -> s"$databaseName.$tableName")
val existingDf = read(location, format, readOptions, Some(databaseName), Some(tableName))
.persist() // Read the existing table only once, before it is overwritten below

val updatedDf = if (existingDf.isEmpty) updates
else {
// JDBC has no native upsert operation:
// do the merge in Spark, then overwrite the table with the result
val keysAreIdentical: Column = primaryKeys.map(col => $"new.$col" <=> $"existing.$col").reduce(_ && _)
val updatedColumns: Seq[Column] = updates.columns.map(col => coalesce($"new.$col", $"existing.$col") as col)

updates.as("new")
.join(existingDf.as("existing"), keysAreIdentical, "full")
.select(updatedColumns: _*)
}

writeOnce(location, databaseName, tableName, updatedDf, partitioning, format, options)
}

/**
* Update the data only if the data has changed
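Because the Spark JDBC connector has no native upsert, the new upsert above reads the current table, merges it with the incoming rows in Spark, and rewrites the whole table through writeOnce. A minimal, self-contained sketch of just that merge step, on hypothetical columns (id, label, value) that are not part of this commit, using the same null-safe key comparison, full outer join and coalesce pattern as the diff:

import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions.coalesce

object UpsertMergeSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("upsert-merge-sketch").getOrCreate()
  import spark.implicits._

  // Hypothetical existing and incoming data, keyed by "id"
  val existing = Seq((1, "old", 10), (2, "keep", 20)).toDF("id", "label", "value")
  val updates  = Seq((2, "new", 21), (3, "insert", 30)).toDF("id", "label", "value")
  val primaryKeys = Seq("id")

  // Rows match when every primary key column is equal (null-safe, as in the diff)
  val keysAreIdentical: Column = primaryKeys.map(c => $"new.$c" <=> $"existing.$c").reduce(_ && _)

  // Prefer the incoming value, fall back to the existing one
  val merged = updates.as("new")
    .join(existing.as("existing"), keysAreIdentical, "full")
    .select(updates.columns.map(c => coalesce($"new.$c", $"existing.$c") as c): _*)

  merged.show() // expected rows: (1, old, 10), (2, new, 21), (3, insert, 30)
  spark.stop()
}

Note that the real method then overwrites the destination table with the merged result, so every upsert costs a full read and rewrite of the table rather than an in-place MERGE.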
@@ -44,6 +44,9 @@ object LoadResolver {
case (format, Insert) if format == JDBC || format == SQL_SERVER => (ds: DatasetConf, df: DataFrame) =>
JdbcLoader.insert(ds.location, ds.table.map(_.database).getOrElse(""), ds.table.map(_.name).getOrElse(""), df, ds.partitionby, ds.format.sparkFormat, ds.writeoptions)

case (format, Upsert) if format == JDBC || format == SQL_SERVER => (ds: DatasetConf, df: DataFrame) =>
JdbcLoader.upsert(ds.location, ds.table.map(_.database).getOrElse(""), ds.table.map(_.name).getOrElse(""), df, ds.keys, ds.partitionby, ds.format.sparkFormat, ds.writeoptions)

case (ELASTICSEARCH, OverWrite) => (ds: DatasetConf, df: DataFrame) =>
ElasticsearchLoader.writeOnce(ds.location, ds.table.map(_.database).getOrElse(""), ds.table.map(_.name).getOrElse(ds.location), df, ds.partitionby, ds.format.sparkFormat, ds.writeoptions)

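The new case above routes (JDBC or SQL_SERVER, Upsert) to JdbcLoader.upsert, passing the dataset's configured keys (ds.keys) as the primary keys for the merge. The resolver's public API is not visible in this diff, so the following is only a standalone illustration of the dispatch pattern, with hypothetical stand-in types rather than the library's real Format, LoadType and DatasetConf:

import org.apache.spark.sql.DataFrame

// Hypothetical stand-ins; the real types live in the datalake configuration model
sealed trait Format
case object JDBC extends Format
case object SQL_SERVER extends Format
sealed trait LoadType
case object Upsert extends LoadType
final case class DatasetConf(location: String, keys: Seq[String])

object ResolverSketch {
  // A writer takes the dataset configuration plus the new data and returns the written data
  type Writer = (DatasetConf, DataFrame) => DataFrame

  // Dispatch on (format, load type), mirroring the partial function in the diff
  val resolve: PartialFunction[(Format, LoadType), Writer] = {
    case (JDBC | SQL_SERVER, Upsert) => (ds, df) =>
      df // the real case delegates to JdbcLoader.upsert(ds.location, ..., df, ds.keys, ...)
  }
}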
12 changes: 12 additions & 0 deletions datalake-spark3/src/test/resources/jdbc/init-dbt.sql
@@ -0,0 +1,12 @@
create schema test_schema

create table test_schema.test
(
uid varchar not null,
oid varchar not null,
createdOn timestamp not null,
updatedOn timestamp not null,
data bigint,
chromosome varchar(2) not null,
start bigint
);
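The spec below writes and reads rows of this table through a TestData case class that is not part of this commit. Based on the columns above and how the spec constructs it, it presumably looks roughly like the following; the field types and the Option default are assumptions:

import java.sql.Timestamp

// Hypothetical reconstruction of the TestData case class used by JdbcLoaderSpec;
// it is defined elsewhere in the repository, so names, types and defaults are
// inferred from the init script's columns and may differ from the real definition.
case class TestData(uid: String,
                    oid: String,
                    createdOn: Timestamp,
                    updatedOn: Timestamp,
                    data: Long,
                    chromosome: String,
                    start: Option[Long] = None)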
@@ -0,0 +1,79 @@
package bio.ferlab.datalake.spark3.loader

import bio.ferlab.datalake.testutils.SparkSpec
import com.dimafeng.testcontainers.scalatest.TestContainerForEach
import com.dimafeng.testcontainers.{JdbcDatabaseContainer, PostgreSQLContainer}
import org.scalatest.matchers.should.Matchers
import org.testcontainers.utility.DockerImageName

import java.sql.Timestamp
import java.time.LocalDateTime

class JdbcLoaderSpec extends SparkSpec with Matchers with TestContainerForEach {

import spark.implicits._

val databaseName = "test_db"
val schemaName = "test_schema"
val tableName = "test"

override val containerDef: PostgreSQLContainer.Def = PostgreSQLContainer.Def(
dockerImageName = DockerImageName.parse("postgres:15.1"),
databaseName = databaseName,
username = "scala",
password = "scala",
commonJdbcParams = JdbcDatabaseContainer.CommonParams(initScriptPath = Some("jdbc/init-dbt.sql"))
)

def withPostgresContainer(options: Map[String, String] => Unit): Unit = withContainers { psqlContainer =>
options(Map(
"url" -> psqlContainer.jdbcUrl,
"driver" -> "org.postgresql.Driver",
"user" -> psqlContainer.username,
"password" -> psqlContainer.password
))
}

"upsert" should "update existing data and insert new data" in {
withPostgresContainer { options =>
val readOptions = options + ("dbtable" -> s"$schemaName.$tableName")
val day1 = LocalDateTime.of(2020, 1, 1, 1, 1, 1)
val day2 = day1.plusDays(1)

// Write existing data in database
val existingData = Seq(
TestData(`uid` = "a", `oid` = "a", `chromosome` = "1", `createdOn` = Timestamp.valueOf(day1), `updatedOn` = Timestamp.valueOf(day1), `data` = 1),
TestData(`uid` = "aa", `oid` = "aa", `chromosome` = "2", `createdOn` = Timestamp.valueOf(day1), `updatedOn` = Timestamp.valueOf(day1), `data` = 1),
)
JdbcLoader.writeOnce("", schemaName, tableName, existingData.toDF(), List(), "jdbc", options)

// Check existing data was written in the database
val existingDfInPsql = JdbcLoader.read("", "jdbc", readOptions, Some(schemaName), Some(tableName))
existingDfInPsql
.as[TestData]
.collect() should contain theSameElementsAs existingData

// Upsert data
val upsertData = Seq(
TestData(`uid` = "aa", `oid` = "aa", `chromosome` = "1", `createdOn` = Timestamp.valueOf(day1), `updatedOn` = Timestamp.valueOf(day2), `data` = 2), // update
TestData(`uid` = "aaa", `oid` = "aaa", `chromosome` = "3", `createdOn` = Timestamp.valueOf(day2), `updatedOn` = Timestamp.valueOf(day2), `data` = 2), // insert
)
val upsertResult = JdbcLoader.upsert("", schemaName, tableName, upsertData.toDF(), primaryKeys = Seq("uid", "oid"), List(), "jdbc", options)
val updatedDfInPsql = JdbcLoader.read("", "jdbc", readOptions, Some(schemaName), Some(tableName))

val expectedData = Seq(
TestData(`uid` = "a", `oid` = "a", `chromosome` = "1", `createdOn` = Timestamp.valueOf(day1), `updatedOn` = Timestamp.valueOf(day1), `data` = 1),
TestData(`uid` = "aa", `oid` = "aa", `chromosome` = "1", `createdOn` = Timestamp.valueOf(day1), `updatedOn` = Timestamp.valueOf(day2), `data` = 2),
TestData(`uid` = "aaa", `oid` = "aaa", `chromosome` = "3", `createdOn` = Timestamp.valueOf(day2), `updatedOn` = Timestamp.valueOf(day2), `data` = 2),
)

upsertResult
.as[TestData]
.collect() should contain theSameElementsAs expectedData

updatedDfInPsql
.as[TestData]
.collect() should contain theSameElementsAs expectedData
}
}
}
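The spec uses Testcontainers (TestContainerForEach with PostgreSQLContainer), so a local Docker daemon must be running: each test gets a fresh postgres:15.1 container initialised with jdbc/init-dbt.sql. Assuming the sbt project id matches the build.sbt excerpt above, a command along the lines of sbt "datalake-spark3/testOnly bio.ferlab.datalake.spark3.loader.JdbcLoaderSpec" should run just this spec; the exact invocation is an unverified sketch.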
