Clin 3037 (#235)
* perf: CLIN-3037 Split Enriched SpliceAI into indel and snv

* refactor: CLIN-3037 Move withSpliceAI from genes to variants
Laura Bégin authored Aug 21, 2024
1 parent fa52e6e commit 94924fd
Showing 12 changed files with 245 additions and 123 deletions.
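
In short: the single enriched SpliceAI dataset becomes two (indel and snv), and the gene-level SpliceAI join moves out of the Genes ETL into the Variants ETL behind a new `spliceAi` flag (default `true`). Below is a minimal sketch of driving the split enriched jobs; it assumes a `RuntimeETLContext` named `rc` built elsewhere, everything else comes from this commit.

```scala
import bio.ferlab.datalake.commons.config.RuntimeETLContext
import bio.ferlab.datalake.spark3.publictables.enriched

// One run per variant type replaces the old single enriched_spliceai job.
def runEnrichedSpliceAi(rc: RuntimeETLContext): Unit = {
  enriched.SpliceAi.run(rc, "indel") // writes variant.spliceai_enriched_indel under /public/spliceai/enriched/indel
  enriched.SpliceAi.run(rc, "snv")   // writes variant.spliceai_enriched_snv under /public/spliceai/enriched/snv
}
```

The Variants ETL then reads both tables and nests the scores under each gene; when `spliceAi = false` it skips the join and attaches a null `spliceai` struct instead (see the sketch after the Variants diff below).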
43 changes: 39 additions & 4 deletions datalake-spark3/src/main/resources/reference_kf.conf
@@ -1112,13 +1112,13 @@ datalake {
},
{
format=DELTA
id="enriched_spliceai"
id="enriched_spliceai_indel"
keys=[]
loadtype=OverWrite
partitionby=[
chromosome
]
path="/public/spliceai/enriched"
path="/public/spliceai/enriched/indel"
readoptions {}
repartition {
column-names=[
@@ -1131,11 +1131,46 @@
storageid="public_database"
table {
database=variant
name="spliceai_enriched"
name="spliceai_enriched_indel"
}
view {
database="variant_live"
name="spliceai_enriched"
name="spliceai_enriched_indel"
}
writeoptions {
"created_on_column"="created_on"
"is_current_column"="is_current"
"updated_on_column"="updated_on"
"valid_from_column"="valid_from"
"valid_to_column"="valid_to"
}
},
{
format=DELTA
id="enriched_spliceai_snv"
keys=[]
loadtype=OverWrite
partitionby=[
chromosome
]
path="/public/spliceai/enriched/snv"
readoptions {}
repartition {
column-names=[
chromosome,
start
]
kind=RepartitionByRange
sort-columns=[]
}
storageid="public_database"
table {
database=variant
name="spliceai_enriched_snv"
}
view {
database="variant_live"
name="spliceai_enriched_snv"
}
writeoptions {
"created_on_column"="created_on"
@@ -12,6 +12,7 @@ import bio.ferlab.datalake.spark3.implicits.SparkUtils.firstAs
import org.apache.spark.SparkContext
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType}
import org.apache.spark.sql.{Column, DataFrame, SparkSession}

import java.time.LocalDateTime
@@ -25,9 +26,13 @@ import java.time.LocalDateTime
* @param snvDatasetId the id of the dataset containing the SNV variants
* @param frequencies the frequencies to calculate. See [[FrequencyOperations.freq]]
* @param extraAggregations extra aggregations to be computed when grouping occurrences by locus. Will be added to the root of the data
* @param spliceAi boolean indicating whether to join variants with SpliceAI. Defaults to true.
* @param rc the etl context
*/
case class Variants(rc: RuntimeETLContext, participantId: Column = col("participant_id"), affectedStatus: Column = col("affected_status"), filterSnv: Option[Column] = Some(col("has_alt")), snvDatasetId: String, splits: Seq[OccurrenceSplit], extraAggregations: Seq[Column] = Nil, checkpoint: Boolean = false) extends SimpleSingleETL(rc) {
case class Variants(rc: RuntimeETLContext, participantId: Column = col("participant_id"),
affectedStatus: Column = col("affected_status"), filterSnv: Option[Column] = Some(col("has_alt")),
snvDatasetId: String, splits: Seq[OccurrenceSplit], extraAggregations: Seq[Column] = Nil,
checkpoint: Boolean = false, spliceAi: Boolean = true) extends SimpleSingleETL(rc) {
override val mainDestination: DatasetConf = conf.getDataset("enriched_variants")
if (checkpoint) {
spark.sparkContext.setCheckpointDir(s"${mainDestination.rootPath}/checkpoints")
@@ -41,6 +46,8 @@ case class Variants(rc: RuntimeETLContext, participantId: Column = col("particip
protected val clinvar: DatasetConf = conf.getDataset("normalized_clinvar")
protected val genes: DatasetConf = conf.getDataset("enriched_genes")
protected val cosmic: DatasetConf = conf.getDataset("normalized_cosmic_mutation_set")
protected val spliceai_indel: DatasetConf = conf.getDataset("enriched_spliceai_indel")
protected val spliceai_snv: DatasetConf = conf.getDataset("enriched_spliceai_snv")

override def extract(lastRunValue: LocalDateTime = minValue,
currentRunValue: LocalDateTime = LocalDateTime.now()): Map[String, DataFrame] = {
@@ -54,6 +61,8 @@ case class Variants(rc: RuntimeETLContext, participantId: Column = col("particip
clinvar.id -> clinvar.read,
genes.id -> genes.read,
cosmic.id -> cosmic.read,
spliceai_indel.id -> (if (spliceAi) spliceai_indel.read else spark.emptyDataFrame),
spliceai_snv.id -> (if (spliceAi) spliceai_snv.read else spark.emptyDataFrame),
snvDatasetId -> conf.getDataset(snvDatasetId).read
)
}
@@ -84,14 +93,13 @@ case class Variants(rc: RuntimeETLContext, participantId: Column = col("particip
.withClinvar(data(clinvar.id))
.withGenes(data(genes.id))
.withCosmic(data(cosmic.id))
.withSpliceAi(snv = data(spliceai_snv.id), indel = data(spliceai_indel.id), compute = spliceAi)
.withGeneExternalReference
.withVariantExternalReference
.withColumn("locus", concat_ws("-", locus: _*))
.withColumn("hash", sha1(col("locus")))
.drop("genes_symbol")
}


}

object Variants {
@@ -235,6 +243,53 @@ object Variants {

df.joinAndMerge(cmc, "cmc", "left")
}

def withSpliceAi(snv: DataFrame, indel: DataFrame, compute: Boolean = true)(implicit spark: SparkSession): DataFrame = {
import spark.implicits._

def joinAndMergeIntoGenes(variants: DataFrame, spliceai: DataFrame): DataFrame = {
if (!variants.isEmpty) {
variants
.select($"*", explode_outer($"genes") as "gene", $"gene.symbol" as "symbol") // explode_outer since genes can be null
.join(spliceai, locusColumnNames :+ "symbol", "left")
.drop("symbol") // only used for joining
.withColumn("gene", struct($"gene.*", $"spliceai")) // add spliceai struct as nested field of gene struct
.groupByLocus()
.agg(
first(struct(variants.drop("genes")("*"))) as "variant",
collect_list("gene") as "genes" // re-create genes list for each locus, now containing spliceai struct
)
.select("variant.*", "genes")
} else variants
}

if (compute) {
val spliceAiSnvPrepared = snv
.selectLocus($"symbol", $"max_score" as "spliceai")

val spliceAiIndelPrepared = indel
.selectLocus($"symbol", $"max_score" as "spliceai")

val snvVariants = df
.where($"variant_class" === "SNV")

val otherVariants = df
.where($"variant_class" =!= "SNV")

val snvVariantsWithSpliceAi = joinAndMergeIntoGenes(snvVariants, spliceAiSnvPrepared)
val otherVariantsWithSpliceAi = joinAndMergeIntoGenes(otherVariants, spliceAiIndelPrepared)

snvVariantsWithSpliceAi.unionByName(otherVariantsWithSpliceAi, allowMissingColumns = true)
} else {
// Add empty spliceai struct
df
.withColumn("genes", transform($"genes", g => g.withField("spliceai", lit(null).cast(
StructType(Seq(
StructField("ds", DoubleType),
StructField("type", ArrayType(StringType))
))))))
}
}
}
}
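
The part downstream consumers care about is the schema contract: every element of `genes` now carries a nested `spliceai` struct (`ds: Double`, `type: Array[String]`), and the `compute = false` branch attaches a null struct with the same shape so the output schema does not depend on whether the join ran. A small self-contained sketch of that fallback on toy data (column names follow the commit; the toy rows and local SparkSession are assumptions of the example):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType}

val spark = SparkSession.builder().master("local[1]").appName("spliceai-null-struct-sketch").getOrCreate()
import spark.implicits._

// Toy variant frame with a nested `genes` array, mimicking the enriched_variants shape.
val variants = Seq(("1", 1000, Seq("GENE_A", "GENE_B"))).toDF("chromosome", "start", "symbols")
  .withColumn("genes", transform($"symbols", s => struct(s as "symbol")))
  .drop("symbols")

// Same schema as the commit's fallback branch.
val spliceAiSchema = StructType(Seq(
  StructField("ds", DoubleType),              // max delta score
  StructField("type", ArrayType(StringType))  // splice-effect type(s) reaching that score
))

// Attach a null spliceai struct to every gene, keeping the schema stable.
val withEmptySpliceAi = variants.withColumn(
  "genes",
  transform($"genes", g => g.withField("spliceai", lit(null).cast(spliceAiSchema)))
)

withEmptySpliceAi.printSchema()
```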

@@ -73,7 +73,10 @@ object ImportPublicTable {
def spliceai_snv(rc: RuntimeETLContext): Unit = SpliceAi.run(rc, "snv")

@main
def spliceai_enriched(rc: RuntimeETLContext): Unit = enriched.SpliceAi.run(rc)
def spliceai_enriched_indel(rc: RuntimeETLContext): Unit = enriched.SpliceAi.run(rc, "indel")

@main
def spliceai_enriched_snv(rc: RuntimeETLContext): Unit = enriched.SpliceAi.run(rc, "snv")

@main
def topmed_bravo(rc: RuntimeETLContext): Unit = TopMed.run(rc)
@@ -60,10 +60,11 @@ case class PublicDatasets(alias: String, tableDatabase: Option[String], viewData
DatasetConf("normalized_spliceai_snv" , alias, "/public/spliceai/snv" , DELTA, OverWrite , partitionby = List("chromosome"), table = table("spliceai_snv") , view = view("spliceai_snv")),

// enriched
DatasetConf("enriched_genes" , alias, "/public/genes" , DELTA, OverWrite , partitionby = List() , table = table("genes") , view = view("genes")),
DatasetConf("enriched_dbnsfp" , alias, "/public/dbnsfp/scores" , DELTA, OverWrite , partitionby = List("chromosome"), table = table("dbnsfp_original") , view = view("dbnsfp_original")),
DatasetConf("enriched_spliceai" , alias, "/public/spliceai/enriched" , DELTA, OverWrite , partitionby = List("chromosome"), repartition= Some(RepartitionByRange(columnNames = Seq("chromosome", "start"))), table = table("spliceai_enriched") , view = view("spliceai_enriched")),
DatasetConf("enriched_rare_variant" , alias, "/public/rare_variant/enriched", DELTA, OverWrite , partitionby = List("chromosome", "is_rare"), table = table("rare_variant_enriched"), view = view("rare_variant_enriched"))
DatasetConf("enriched_genes" , alias, "/public/genes" , DELTA, OverWrite , partitionby = List() , table = table("genes") , view = view("genes")),
DatasetConf("enriched_dbnsfp" , alias, "/public/dbnsfp/scores" , DELTA, OverWrite , partitionby = List("chromosome"), table = table("dbnsfp_original") , view = view("dbnsfp_original")),
DatasetConf("enriched_spliceai_indel" , alias, "/public/spliceai/enriched/indel", DELTA, OverWrite , partitionby = List("chromosome"), repartition= Some(RepartitionByRange(columnNames = Seq("chromosome", "start"))), table = table("spliceai_enriched_indel"), view = view("spliceai_enriched_indel")),
DatasetConf("enriched_spliceai_snv" , alias, "/public/spliceai/enriched/snv" , DELTA, OverWrite , partitionby = List("chromosome"), repartition= Some(RepartitionByRange(columnNames = Seq("chromosome", "start"))), table = table("spliceai_enriched_snv") , view = view("spliceai_enriched_snv")),
DatasetConf("enriched_rare_variant" , alias, "/public/rare_variant/enriched" , DELTA, OverWrite , partitionby = List("chromosome", "is_rare"), table = table("rare_variant_enriched"), view = view("rare_variant_enriched"))

)

@@ -21,7 +21,6 @@ case class Genes(rc: RuntimeETLContext) extends SimpleSingleETL(rc) {
val ddd_gene_set: DatasetConf = conf.getDataset("normalized_ddd_gene_set")
val cosmic_gene_set: DatasetConf = conf.getDataset("normalized_cosmic_gene_set")
val gnomad_constraint: DatasetConf = conf.getDataset("normalized_gnomad_constraint_v2_1_1")
val spliceai: DatasetConf = conf.getDataset("enriched_spliceai")

override def extract(lastRunValue: LocalDateTime,
currentRunValue: LocalDateTime): Map[String, DataFrame] = {
@@ -32,8 +31,7 @@ case class Genes(rc: RuntimeETLContext) extends SimpleSingleETL(rc) {
human_genes.id -> human_genes.read,
ddd_gene_set.id -> ddd_gene_set.read,
cosmic_gene_set.id -> cosmic_gene_set.read,
gnomad_constraint.id -> gnomad_constraint.read,
spliceai.id -> spliceai.read
gnomad_constraint.id -> gnomad_constraint.read
)
}

@@ -58,7 +56,6 @@ case class Genes(rc: RuntimeETLContext) extends SimpleSingleETL(rc) {
.withDDD(data(ddd_gene_set.id))
.withCosmic(data(cosmic_gene_set.id))
.withGnomadConstraint(data(gnomad_constraint.id))
.withSpliceAi(data(spliceai.id))
}

override def defaultRepartition: DataFrame => DataFrame = Coalesce()
@@ -138,20 +135,6 @@ object Genes {
.withColumn("hpo_term_label", concat(col("hpo_term_name"), lit(" ("), col("hpo_term_id"), lit(")")))
df.joinAndMergeWith(hpoPrepared, Seq("entrez_gene_id"), "hpo", broadcastOtherDf = true)
}

def withSpliceAi(spliceai: DataFrame)(implicit spark: SparkSession): DataFrame = {
import spark.implicits._

val spliceAiPrepared = spliceai
.groupBy("symbol")
.agg(first("max_score") as "max_score")
.select($"symbol", $"max_score.*")
.withColumn("type", when($"ds" === 0, null).otherwise($"type"))

df
.joinAndMergeWith(spliceAiPrepared, Seq("symbol"), "spliceai", aggFirst = true, broadcastOtherDf = true)
.withColumn("spliceai", when($"spliceai.ds".isNull and $"spliceai.type".isNull, null).otherwise($"spliceai"))
}
}
}

@@ -3,35 +3,29 @@ package bio.ferlab.datalake.spark3.publictables.enriched
import bio.ferlab.datalake.commons.config.{DatasetConf, RepartitionByRange, RuntimeETLContext}
import bio.ferlab.datalake.spark3.etl.v4.SimpleSingleETL
import bio.ferlab.datalake.spark3.implicits.DatasetConfImplicits.DatasetConfOperations
import mainargs.{ParserForMethods, main}
import mainargs.{ParserForMethods, arg, main}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{Column, DataFrame, functions}

import java.time.LocalDateTime

case class SpliceAi(rc: RuntimeETLContext) extends SimpleSingleETL(rc) {
case class SpliceAi(rc: RuntimeETLContext, variantType: String) extends SimpleSingleETL(rc) {

override val mainDestination: DatasetConf = conf.getDataset("enriched_spliceai")
val spliceai_indel: DatasetConf = conf.getDataset("normalized_spliceai_indel")
val spliceai_snv: DatasetConf = conf.getDataset("normalized_spliceai_snv")
override val mainDestination: DatasetConf = conf.getDataset(s"enriched_spliceai_$variantType")
val normalized_spliceai: DatasetConf = conf.getDataset(s"normalized_spliceai_$variantType")

override def extract(lastRunValue: LocalDateTime,
currentRunValue: LocalDateTime): Map[String, DataFrame] = {
Map(
spliceai_indel.id -> spliceai_indel.read,
spliceai_snv.id -> spliceai_snv.read
)
Map(normalized_spliceai.id -> normalized_spliceai.read)
}

override def transformSingle(data: Map[String, DataFrame],
lastRunValue: LocalDateTime,
currentRunValue: LocalDateTime): DataFrame = {
import spark.implicits._

val spliceai_snvDf = data(spliceai_snv.id)
val spliceai_indelDf = data(spliceai_indel.id)

val originalColumns = spliceai_snvDf.columns.map(col)
val df = data(normalized_spliceai.id)
val originalColumns = df.columns.map(col)

val getDs: Column => Column = _.getItem(0).getField("ds") // Get delta score
val scoreColumnNames = Array("AG", "AL", "DG", "DL")
@@ -43,8 +37,7 @@ case class SpliceAi(rc: RuntimeETLContext, variantType: String) extends SimpleS
.otherwise(c2)
}

spliceai_snvDf
.union(spliceai_indelDf)
df
.select(
originalColumns :+
$"ds_ag".as("AG") :+ // acceptor gain
@@ -57,17 +50,17 @@ case class SpliceAi(rc: RuntimeETLContext, variantType: String) extends SimpleS
getDs($"max_score_temp") as "ds",
functions.transform($"max_score_temp", c => c.getField("type")) as "type")
)
.withColumn("max_score", $"max_score".withField("type", when($"max_score.ds" === 0, null).otherwise($"max_score.type")))
.select(originalColumns :+ $"max_score": _*)
}

override def defaultRepartition: DataFrame => DataFrame = RepartitionByRange(columnNames = Seq("chromosome", "start"), n = Some(500))

}

object SpliceAi {
@main
def run(rc: RuntimeETLContext): Unit = {
SpliceAi(rc).run()
def run(rc: RuntimeETLContext, @arg(name = "variant_type", short = 'v', doc = "Variant Type") variantType: String): Unit = {
SpliceAi(rc, variantType).run()
}

def main(args: Array[String]): Unit = ParserForMethods(this).runOrThrow(args)
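
For the `max_score` column itself, my reading of this transform is: `ds` is the greatest of the four delta scores (acceptor gain/loss, donor gain/loss) and `type` lists the score type(s) reaching that maximum, with `type` nulled when the best score is 0. An independent toy illustration of that semantics, not the ETL's code (the actual fold over the four scores sits in the collapsed lines of this hunk; the rows and helper columns below are invented):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[1]").appName("max-score-sketch").getOrCreate()
import spark.implicits._

val scores = Seq(
  ("1", 1000, "GENE_A", 0.10, 0.50, 0.50, 0.20), // AL and DG tie for the maximum
  ("1", 2000, "GENE_B", 0.00, 0.00, 0.00, 0.00)  // no signal: type ends up null
).toDF("chromosome", "start", "symbol", "AG", "AL", "DG", "DL")

val maxDs = greatest($"AG", $"AL", $"DG", $"DL")
// Pair each delta score with its label so the type(s) reaching the maximum can be recovered.
val labelled = array(Seq("AG", "AL", "DG", "DL").map(t => struct(col(t) as "ds", lit(t) as "type")): _*)

val withMaxScore = scores.withColumn(
  "max_score",
  struct(
    maxDs as "ds",
    when(maxDs === 0, lit(null))
      .otherwise(transform(filter(labelled, c => c.getField("ds") === maxDs), _.getField("type"))) as "type"
  )
)

withMaxScore.select("symbol", "max_score").show(truncate = false)
```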
43 changes: 39 additions & 4 deletions datalake-spark3/src/test/resources/config/reference_kf.conf
@@ -1112,13 +1112,13 @@ datalake {
},
{
format=DELTA
id="enriched_spliceai"
id="enriched_spliceai_indel"
keys=[]
loadtype=OverWrite
partitionby=[
chromosome
]
path="/public/spliceai/enriched"
path="/public/spliceai/enriched/indel"
readoptions {}
repartition {
column-names=[
@@ -1131,11 +1131,46 @@
storageid="public_database"
table {
database=variant
name="spliceai_enriched"
name="spliceai_enriched_indel"
}
view {
database="variant_live"
name="spliceai_enriched"
name="spliceai_enriched_indel"
}
writeoptions {
"created_on_column"="created_on"
"is_current_column"="is_current"
"updated_on_column"="updated_on"
"valid_from_column"="valid_from"
"valid_to_column"="valid_to"
}
},
{
format=DELTA
id="enriched_spliceai_snv"
keys=[]
loadtype=OverWrite
partitionby=[
chromosome
]
path="/public/spliceai/enriched/snv"
readoptions {}
repartition {
column-names=[
chromosome,
start
]
kind=RepartitionByRange
sort-columns=[]
}
storageid="public_database"
table {
database=variant
name="spliceai_enriched_snv"
}
view {
database="variant_live"
name="spliceai_enriched_snv"
}
writeoptions {
"created_on_column"="created_on"
