Skip to content

Commit

Permalink
refactor: CLIN-3037 Enrich genes with SpliceAI in EnrichedGenes ETL
Browse files Browse the repository at this point in the history
  • Loading branch information
Laura Bégin committed Jul 31, 2024
1 parent 4a1b6b2 commit 57f0c95
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import java.time.LocalDateTime

/**
* This ETL creates an aggregated table of occurrences of SNV variants. Occurrences are aggregated by calculating the frequencies specified in parameter frequencies.
* The table is enriched with information from other datasets such as genes, dbsnp, clinvar, spliceai, 1000 genomes, topmed_bravo, gnomad_genomes_v2, gnomad_exomes_v2, gnomad_genomes_v3.
* The table is enriched with information from other datasets such as genes, dbsnp, clinvar, 1000 genomes, topmed_bravo, gnomad_genomes_v2, gnomad_exomes_v2, gnomad_genomes_v3.
*
* @param participantId column used to distinguish participants in order to calculate total number of participants (pn) and total allele number (an)
* @param affectedStatus column used to calculate frequencies for affected / unaffected participants
Expand All @@ -40,7 +40,6 @@ case class Variants(rc: RuntimeETLContext, participantId: Column = col("particip
protected val dbsnp: DatasetConf = conf.getDataset("normalized_dbsnp")
protected val clinvar: DatasetConf = conf.getDataset("normalized_clinvar")
protected val genes: DatasetConf = conf.getDataset("enriched_genes")
protected val spliceai: DatasetConf = conf.getDataset("enriched_spliceai")
protected val cosmic: DatasetConf = conf.getDataset("normalized_cosmic_mutation_set")

override def extract(lastRunValue: LocalDateTime = minValue,
Expand All @@ -54,7 +53,6 @@ case class Variants(rc: RuntimeETLContext, participantId: Column = col("particip
dbsnp.id -> dbsnp.read,
clinvar.id -> clinvar.read,
genes.id -> genes.read,
spliceai.id -> spliceai.read,
cosmic.id -> cosmic.read,
snvDatasetId -> conf.getDataset(snvDatasetId).read
)
Expand Down Expand Up @@ -85,7 +83,6 @@ case class Variants(rc: RuntimeETLContext, participantId: Column = col("particip
.withDbSNP(data(dbsnp.id))
.withClinvar(data(clinvar.id))
.withGenes(data(genes.id))
.withSpliceAi(data(spliceai.id))
.withCosmic(data(cosmic.id))
.withGeneExternalReference
.withVariantExternalReference
Expand Down Expand Up @@ -217,27 +214,6 @@ object Variants {
}
}

/**
 * Adds a `spliceai` struct (fields `ds` and `type`) to every element of the `genes` array of `df`,
 * where `df` is the DataFrame provided by the enclosing scope.
 *
 * The SpliceAI rows are left-joined on the locus columns plus the gene `symbol`, so genes without a
 * matching SpliceAI row keep a null `spliceai` field.
 *
 * @param spliceai SpliceAI dataset; this method reads its `symbol` and `max_score` (`ds`, `type`) columns
 * @param spark    session used only to bring the `$"..."` column syntax into scope
 * @return `df` with its `genes` column rebuilt so that each gene struct carries a `spliceai` field
 */
def withSpliceAi(spliceai: DataFrame)(implicit spark: SparkSession): DataFrame = {
import spark.implicits._

// NOTE(review): selectLocus presumably keeps the locus columns next to the listed ones - confirm in GenomicImplicits.
// `type` is nulled out when the max score `ds` equals 0, then `spliceai` is rebuilt as struct(ds, type).
val scores = spliceai.selectLocus($"symbol", $"max_score" as "spliceai")
.withColumn("type", when($"spliceai.ds" === 0, null).otherwise($"spliceai.type"))
.withColumn("spliceai", struct($"spliceai.ds" as "ds", $"type"))
.drop("type")

// One row per (variant, gene) while joining, then regrouped per locus at the end.
df
.select($"*", explode_outer($"genes") as "gene", $"gene.symbol" as "symbol") // explode_outer since genes can be null
.join(scores, locusColumnNames :+ "symbol", "left")
.drop("symbol") // only used for joining
.withColumn("gene", struct($"gene.*", $"spliceai")) // add spliceai struct as nested field of gene struct
.groupByLocus()
.agg(
// all original columns except `genes`, packed so they can be restored with "variant.*" below
first(struct(df.drop("genes")("*"))) as "variant",
collect_list("gene") as "genes" // re-create genes list for each locus, now containing spliceai struct
)
.select("variant.*", "genes")
}

def withCosmic(cosmic: DataFrame)(implicit spark: SparkSession): DataFrame = {
import spark.implicits._

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package bio.ferlab.datalake.spark3.publictables.enriched
import bio.ferlab.datalake.commons.config.{Coalesce, DatasetConf, RuntimeETLContext}
import bio.ferlab.datalake.spark3.etl.v4.SimpleSingleETL
import bio.ferlab.datalake.spark3.implicits.DatasetConfImplicits._
import bio.ferlab.datalake.spark3.implicits.GenomicImplicits._
import bio.ferlab.datalake.spark3.implicits.GenomicImplicits.columns.locusColumnNames
import bio.ferlab.datalake.spark3.implicits.SparkUtils.removeEmptyObjectsIn
import bio.ferlab.datalake.spark3.publictables.enriched.Genes._
import mainargs.{ParserForMethods, main}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.{Column, DataFrame, SparkSession}

import java.time.LocalDateTime

Expand All @@ -20,6 +23,7 @@ case class Genes(rc: RuntimeETLContext) extends SimpleSingleETL(rc) {
val ddd_gene_set: DatasetConf = conf.getDataset("normalized_ddd_gene_set")
val cosmic_gene_set: DatasetConf = conf.getDataset("normalized_cosmic_gene_set")
val gnomad_constraint: DatasetConf = conf.getDataset("normalized_gnomad_constraint_v2_1_1")
val spliceai: DatasetConf = conf.getDataset("enriched_spliceai")

override def extract(lastRunValue: LocalDateTime,
currentRunValue: LocalDateTime): Map[String, DataFrame] = {
Expand All @@ -30,7 +34,8 @@ case class Genes(rc: RuntimeETLContext) extends SimpleSingleETL(rc) {
human_genes.id -> human_genes.read,
ddd_gene_set.id -> ddd_gene_set.read,
cosmic_gene_set.id -> cosmic_gene_set.read,
gnomad_constraint.id -> gnomad_constraint.read
gnomad_constraint.id -> gnomad_constraint.read,
spliceai.id -> spliceai.read
)
}

Expand All @@ -55,11 +60,41 @@ case class Genes(rc: RuntimeETLContext) extends SimpleSingleETL(rc) {
.withDDD(data(ddd_gene_set.id))
.withCosmic(data(cosmic_gene_set.id))
.withGnomadConstraint(data(gnomad_constraint.id))
.withSpliceAi(data(spliceai.id))
}

override def defaultRepartition: DataFrame => DataFrame = Coalesce()
}

object Genes {
// Entry point exposed through the mainargs `@main` annotation: builds the Genes ETL from the
// given runtime context and runs it.
@main
def run(rc: RuntimeETLContext): Unit = {
Genes(rc).run()
}

// JVM entry point: hands the raw command-line arguments to mainargs' ParserForMethods for this object;
// `runOrThrow` is used, so failures are expected to surface as exceptions.
def main(args: Array[String]): Unit = ParserForMethods(this).runOrThrow(args)

implicit class DataFrameOps(df: DataFrame) {

/**
 * Left-joins the enclosing DataFrame `df` with `gene_set`, groups the result by `symbol` and merges
 * the non-key columns of `gene_set` into a single column.
 *
 * @param gene_set     dataset to merge into `df`
 * @param joinOn       names of the columns used as join keys; they are left out of the merged struct
 * @param asColumnName name of the column that receives the merged `gene_set` struct(s)
 * @param aggFirst     when true, keeps only the first matching struct per symbol; when false,
 *                     collects all matching structs into a list and passes that list through
 *                     `removeEmptyObjectsIn`
 * @return the columns of `df` (first row per symbol) followed by the `asColumnName` column
 */
def joinAndMergeWith(gene_set: DataFrame,
                     joinOn: Seq[String],
                     asColumnName: String,
                     aggFirst: Boolean = false): DataFrame = {
  // Every gene_set column except the join keys, packed into one struct.
  val geneSetStruct: Column = struct(gene_set.drop(joinOn: _*)("*"))
  val mergedColumn: Column =
    if (aggFirst) first(geneSetStruct) else collect_list(geneSetStruct)

  val joined = df.join(gene_set, joinOn, "left")
  val merged = joined
    .groupBy("symbol")
    .agg(
      first(struct(df("*"))) as "hg", // original row, restored below through "hg.*"
      mergedColumn as asColumnName
    )
    .select(col("hg.*"), col(asColumnName))

  // Only the list form goes through removeEmptyObjectsIn; the single-struct form is returned as is.
  if (aggFirst) merged
  else merged.withColumn(asColumnName, removeEmptyObjectsIn(asColumnName))
}

def withGnomadConstraint(gnomad: DataFrame): DataFrame = {
val gnomadConstraint = gnomad
.groupBy("chromosome", "symbol")
Expand Down Expand Up @@ -105,34 +140,17 @@ case class Genes(rc: RuntimeETLContext) extends SimpleSingleETL(rc) {
df.joinAndMergeWith(hpoPrepared, Seq("entrez_gene_id"), "hpo")
}

/**
 * Merges `gene_set` into the enclosing DataFrame `df`: left join on `joinOn`, one output row per
 * `symbol`, with the remaining `gene_set` columns packed into the `asColumnName` column.
 *
 * @param gene_set     dataset to merge into `df`
 * @param joinOn       join key column names; excluded from the merged struct
 * @param asColumnName name of the output column holding the merged struct(s)
 * @param aggFirst     true to keep the first matching struct only; false to collect every matching
 *                     struct into a list that is then passed through `removeEmptyObjectsIn`
 * @return the columns of `df` (first row per symbol) plus the `asColumnName` column
 */
def joinAndMergeWith(gene_set: DataFrame, joinOn: Seq[String], asColumnName: String, aggFirst: Boolean = false): DataFrame = {
  // Struct of all gene_set columns that are not join keys.
  val packed: Column = struct(gene_set.drop(joinOn: _*)("*"))
  val aggregated: Column = if (aggFirst) first(packed) else collect_list(packed)

  val perSymbol = df
    .join(gene_set, joinOn, "left")
    .groupBy("symbol")
    .agg(first(struct(df("*"))) as "hg", aggregated as asColumnName)
    .select(col("hg.*"), col(asColumnName))

  // The list variant is post-processed by removeEmptyObjectsIn; the single-struct variant is not.
  if (aggFirst) perSymbol
  else perSymbol.withColumn(asColumnName, removeEmptyObjectsIn(asColumnName))
}
def withSpliceAi(spliceai: DataFrame)(implicit spark: SparkSession): DataFrame = {
import spark.implicits._

}


override def defaultRepartition: DataFrame => DataFrame = Coalesce()
}
val spliceAiPrepared = spliceai
.select($"symbol", $"max_score.*")
.withColumn("type", when($"ds" === 0, null).otherwise($"type"))

object Genes {
@main
def run(rc: RuntimeETLContext): Unit = {
Genes(rc).run()
df
.joinAndMergeWith(spliceAiPrepared, Seq("symbol"), "spliceai", aggFirst = true)
.withColumn("spliceai", when($"spliceai.ds".isNull and $"spliceai.type".isNull, null).otherwise($"spliceai"))
}
}

def main(args: Array[String]): Unit = ParserForMethods(this).runOrThrow(args)
}

Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ package bio.ferlab.datalake.spark3.genomics.enriched
import bio.ferlab.datalake.commons.config.DatasetConf
import bio.ferlab.datalake.spark3.genomics.enriched.Variants.DataFrameOps
import bio.ferlab.datalake.spark3.genomics.{FrequencySplit, SimpleAggregation}
import bio.ferlab.datalake.spark3.testutils.WithTestConfig
import bio.ferlab.datalake.testutils.models.enriched.EnrichedVariant.CMC
import bio.ferlab.datalake.testutils.models.enriched.{EnrichedGenes, EnrichedSpliceAi, EnrichedVariant, MAX_SCORE}
import bio.ferlab.datalake.testutils.models.enriched.{EnrichedGenes, EnrichedVariant}
import bio.ferlab.datalake.testutils.models.normalized._
import bio.ferlab.datalake.spark3.testutils.WithTestConfig
import bio.ferlab.datalake.testutils.models.normalized.NormalizedCosmicMutationSet
import bio.ferlab.datalake.testutils.{SparkSpec, TestETLContext}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, collect_set, max}
Expand All @@ -26,7 +25,6 @@ class EnrichedVariantsSpec extends SparkSpec with WithTestConfig {
val dbsnp: DatasetConf = conf.getDataset("normalized_dbsnp")
val clinvar: DatasetConf = conf.getDataset("normalized_clinvar")
val genes: DatasetConf = conf.getDataset("enriched_genes")
val spliceai: DatasetConf = conf.getDataset("enriched_spliceai")
val cosmic: DatasetConf = conf.getDataset("normalized_cosmic_mutation_set")

val occurrencesDf: DataFrame = Seq(
Expand All @@ -42,7 +40,6 @@ class EnrichedVariantsSpec extends SparkSpec with WithTestConfig {
val dbsnpDf: DataFrame = Seq(NormalizedDbsnp()).toDF
val clinvarDf: DataFrame = Seq(NormalizedClinvar(chromosome = "1", start = 69897, reference = "T", alternate = "C")).toDF
val genesDf: DataFrame = Seq(EnrichedGenes()).toDF()
val spliceaiDf: DataFrame = Seq(EnrichedSpliceAi(chromosome = "1", start = 69897, reference = "T", alternate = "C", symbol = "OR4F5", ds_ag = 0.01, `max_score` = MAX_SCORE(ds = 0.01, `type` = Seq("AG")))).toDF()
val cosmicDf: DataFrame = Seq(NormalizedCosmicMutationSet(chromosome = "1", start = 69897, reference = "T", alternate = "C")).toDF()

val etl = Variants(TestETLContext(), snvDatasetId = snvKeyId, splits = Seq(FrequencySplit("frequency", extraAggregations = Seq(SimpleAggregation(name = "zygosities", c = col("zygosity"))))))
Expand All @@ -57,7 +54,6 @@ class EnrichedVariantsSpec extends SparkSpec with WithTestConfig {
dbsnp.id -> dbsnpDf,
clinvar.id -> clinvarDf,
genes.id -> genesDf,
spliceai.id -> spliceaiDf,
cosmic.id -> cosmicDf
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ package bio.ferlab.datalake.spark3.publictables.enriched

import bio.ferlab.datalake.commons.config.DatasetConf
import bio.ferlab.datalake.spark3.implicits.DatasetConfImplicits._
import bio.ferlab.datalake.testutils.models.enriched.{EnrichedGenes, OMIM, ORPHANET, COSMIC}
import bio.ferlab.datalake.testutils.models.normalized._
import bio.ferlab.datalake.spark3.publictables.enriched.Genes._
import bio.ferlab.datalake.spark3.testutils.WithTestConfig
import bio.ferlab.datalake.testutils.models.normalized.NormalizedCosmicGeneSet
import bio.ferlab.datalake.testutils.models.enriched._
import bio.ferlab.datalake.testutils.models.normalized._
import bio.ferlab.datalake.testutils.{CleanUpBeforeAll, CreateDatabasesBeforeAll, SparkSpec, TestETLContext}
import org.apache.spark.sql.functions
import org.apache.spark.sql.functions.col
Expand All @@ -22,6 +22,7 @@ class GenesSpec extends SparkSpec with WithTestConfig with CreateDatabasesBefore
val ddd_gene_set: DatasetConf = conf.getDataset("normalized_ddd_gene_set")
val cosmic_gene_set: DatasetConf = conf.getDataset("normalized_cosmic_gene_set")
val gnomad_constraint: DatasetConf = conf.getDataset("normalized_gnomad_constraint_v2_1_1")
val spliceai: DatasetConf = conf.getDataset("enriched_spliceai")

private val inputData = Map(
omim_gene_set.id -> Seq(
Expand All @@ -35,7 +36,8 @@ class GenesSpec extends SparkSpec with WithTestConfig with CreateDatabasesBefore
gnomad_constraint.id -> Seq(
NormalizedGnomadConstraint(chromosome = "1", start = 69897, symbol = "OR4F5", `pLI` = 1.0f, oe_lof_upper = 0.01f),
NormalizedGnomadConstraint(chromosome = "1", start = 69900, symbol = "OR4F5", `pLI` = 0.9f, oe_lof_upper = 0.054f)
).toDF()
).toDF(),
spliceai.id -> Seq(EnrichedSpliceAi(`symbol` = "OR4F5")).toDF()
)

val job = new Genes(TestETLContext())
Expand Down Expand Up @@ -78,5 +80,36 @@ class GenesSpec extends SparkSpec with WithTestConfig with CreateDatabasesBefore
resultDF.where("symbol='OR4F5'").as[EnrichedGenes].collect().head shouldBe
EnrichedGenes(`orphanet` = expectedOrphanet, `omim` = expectedOmim, `cosmic` = expectedCosmic)
}

// Checks withSpliceAi on four genes: two with a non-zero score (gene1, gene3), one whose max score
// is 0 (gene2) and one with no SpliceAI row at all (gene4).
"withSpliceAi" should "enrich genes with SpliceAi scores" in {
// Input genes with the `spliceai` column dropped, so that the column in the result comes from withSpliceAi.
val genes = Seq(
EnrichedGenes(`chromosome` = "1", `symbol` = "gene1"),
EnrichedGenes(`chromosome` = "1", `symbol` = "gene2"),
EnrichedGenes(`chromosome` = "2", `symbol` = "gene3"),
EnrichedGenes(`chromosome` = "3", `symbol` = "gene4"),
).toDF().drop("spliceai")

val spliceai = Seq(
// snv
EnrichedSpliceAi(`chromosome` = "1", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "C", `symbol` = "gene1", `max_score` = MAX_SCORE(`ds` = 2.0, `type` = Seq("AL"))),
EnrichedSpliceAi(`chromosome` = "1", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "C", `symbol` = "gene2", `max_score` = MAX_SCORE(`ds` = 0.0, `type` = Seq("AG", "AL", "DG", "DL"))),

// indel
EnrichedSpliceAi(`chromosome` = "2", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "AT", `symbol` = "gene3", `max_score` = MAX_SCORE(`ds` = 1.0, `type` = Seq("AG", "AL")))
).toDF()

val result = genes.withSpliceAi(spliceai)

val expected = Seq(
EnrichedGenes(`chromosome` = "1", `symbol` = "gene1", `spliceai` = Some(SPLICEAI(`ds` = 2.0, `type` = Some(List("AL"))))),
// ds = 0.0: the score is kept but `type` is expected to be empty
EnrichedGenes(`chromosome` = "1", `symbol` = "gene2", `spliceai` = Some(SPLICEAI(`ds` = 0.0, `type` = None))),
EnrichedGenes(`chromosome` = "2", `symbol` = "gene3", `spliceai` = Some(SPLICEAI(`ds` = 1.0, `type` = Some(List("AG", "AL"))))),
// no SpliceAI row for gene4: the whole `spliceai` struct is expected to be empty
EnrichedGenes(`chromosome` = "3", `symbol` = "gene4", `spliceai` = None),
)

// Row order is not part of the contract, hence theSameElementsAs.
result
.as[EnrichedGenes]
.collect() should contain theSameElementsAs expected
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ case class EnrichedGenes(`symbol`: String = "OR4F5",
`chromosome`: String = "1",
`ddd`: List[DDD] = List(DDD()),
`cosmic`: List[COSMIC] = List(COSMIC()),
gnomad: Option[GNOMAD] = Some(GNOMAD()))
gnomad: Option[GNOMAD] = Some(GNOMAD()),
spliceai: Option[SPLICEAI] = Some(SPLICEAI()))

case class ORPHANET(`disorder_id`: Long = 17827,
`panel`: String = "Immunodeficiency due to a classical component pathway complement deficiency",
Expand All @@ -40,3 +41,6 @@ case class COSMIC(`tumour_types_germline`: List[String] = List("breast", "colon"

case class GNOMAD(pli: Float = 1.0f,
loeuf: Float = 0.054f)

/**
 * Model of the `spliceai` struct carried by [[EnrichedGenes]].
 *
 * @param ds   score value; defaults to 0.1
 * @param type optional list of score type labels; defaults to "AG", "AL", "DG" and "DL"
 */
case class SPLICEAI(ds: Double = 0.1,
`type`: Option[Seq[String]] = Some(Seq("AG", "AL", "DG", "DL")))
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ case class EnrichedSpliceAi(`chromosome`: String = "1",
`alternate`: String = "G",
`allele`: String = "G",
`symbol`: String = "KCNH1",
`ds_ag`: Double = 0.0,
`ds_al`: Double = 0.0,
`ds_dg`: Double = 0.0,
`ds_dl`: Double = 0.0,
`ds_ag`: Double = 0.1,
`ds_al`: Double = 0.1,
`ds_dg`: Double = 0.1,
`ds_dl`: Double = 0.1,
`dp_ag`: Int = -3,
`dp_al`: Int = 36,
`dp_dg`: Int = 32,
`dp_dl`: Int = 22,
`max_score`: MAX_SCORE = MAX_SCORE())

case class MAX_SCORE(`ds`: Double = 0.0,
case class MAX_SCORE(`ds`: Double = 0.1,
`type`: Seq[String] = Seq("AG", "AL", "DG", "DL"))
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ object EnrichedVariant {
)


case class SPLICEAI(ds: Double = 0.01,
`type`: Seq[String] = Seq("AG"))
case class SPLICEAI(ds: Double = 0.1,
`type`: Option[Seq[String]] = Some(Seq("AG", "AL", "DG", "DL")))

case class CMC(mutation_url: String = "https://cancer.sanger.ac.uk/cosmic/mutation/overview?id=29491889&genome=37",
shared_aa: Int = 9,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ case class NormalizedSpliceAi(`chromosome`: String = "2",
`alternate`: String = "G",
`allele`: String = "G",
`symbol`: String = "KCNH1",
`ds_ag`: Double = 0.0,
`ds_al`: Double = 0.0,
`ds_dg`: Double = 0.0,
`ds_dl`: Double = 0.0,
`ds_ag`: Double = 0.1,
`ds_al`: Double = 0.1,
`ds_dg`: Double = 0.1,
`ds_dl`: Double = 0.1,
`dp_ag`: Int = -3,
`dp_al`: Int = 36,
`dp_dg`: Int = 32,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ object PreparedVariantCentric {
`hpo_term_name`: String = "Hyperreflexia",
`hpo_term_label`: String = "Hyperreflexia (HP:0001347)")

case class SPLICEAI(`ds`: Double = 0.01,
`type`: Seq[String] = Seq("AG"))
case class SPLICEAI(`ds`: Double = 0.1,
`type`: Option[Seq[String]] = Some(Seq("AG", "AL", "DG", "DL")))

case class TOPMED_BRAVO(`ac`: Long = 2,
`an`: Long = 125568,
Expand Down

0 comments on commit 57f0c95

Please sign in to comment.