Skip to content

Commit

Permalink
refactor: CLIN-3037 Enrich genes with SpliceAI in EnrichedGenes ETL (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Laura Bégin authored Jul 31, 2024
1 parent 4a1b6b2 commit 95bd2e2
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import java.time.LocalDateTime

/**
* This ETL create an aggregated table on occurrences of SNV variants. Occurrences are aggregated by calculating the frequencies specified in parameter frequencies.
* The table is enriched with information from other datasets such as genes, dbsnp, clinvar, spliceai, 1000 genomes, topmed_bravo, gnomad_genomes_v2, gnomad_exomes_v2, gnomad_genomes_v3.
* The table is enriched with information from other datasets such as genes, dbsnp, clinvar, 1000 genomes, topmed_bravo, gnomad_genomes_v2, gnomad_exomes_v2, gnomad_genomes_v3.
*
* @param participantId column used to distinct participants in order to calculate total number of participants (pn) and total allele number (an)
* @param affectedStatus column used to calculate frequencies for affected / unaffected participants
Expand All @@ -40,7 +40,6 @@ case class Variants(rc: RuntimeETLContext, participantId: Column = col("particip
protected val dbsnp: DatasetConf = conf.getDataset("normalized_dbsnp")
protected val clinvar: DatasetConf = conf.getDataset("normalized_clinvar")
protected val genes: DatasetConf = conf.getDataset("enriched_genes")
protected val spliceai: DatasetConf = conf.getDataset("enriched_spliceai")
protected val cosmic: DatasetConf = conf.getDataset("normalized_cosmic_mutation_set")

override def extract(lastRunValue: LocalDateTime = minValue,
Expand All @@ -54,7 +53,6 @@ case class Variants(rc: RuntimeETLContext, participantId: Column = col("particip
dbsnp.id -> dbsnp.read,
clinvar.id -> clinvar.read,
genes.id -> genes.read,
spliceai.id -> spliceai.read,
cosmic.id -> cosmic.read,
snvDatasetId -> conf.getDataset(snvDatasetId).read
)
Expand Down Expand Up @@ -85,7 +83,6 @@ case class Variants(rc: RuntimeETLContext, participantId: Column = col("particip
.withDbSNP(data(dbsnp.id))
.withClinvar(data(clinvar.id))
.withGenes(data(genes.id))
.withSpliceAi(data(spliceai.id))
.withCosmic(data(cosmic.id))
.withGeneExternalReference
.withVariantExternalReference
Expand Down Expand Up @@ -217,27 +214,6 @@ object Variants {
}
}

def withSpliceAi(spliceai: DataFrame)(implicit spark: SparkSession): DataFrame = {
import spark.implicits._

val scores = spliceai.selectLocus($"symbol", $"max_score" as "spliceai")
.withColumn("type", when($"spliceai.ds" === 0, null).otherwise($"spliceai.type"))
.withColumn("spliceai", struct($"spliceai.ds" as "ds", $"type"))
.drop("type")

df
.select($"*", explode_outer($"genes") as "gene", $"gene.symbol" as "symbol") // explode_outer since genes can be null
.join(scores, locusColumnNames :+ "symbol", "left")
.drop("symbol") // only used for joining
.withColumn("gene", struct($"gene.*", $"spliceai")) // add spliceai struct as nested field of gene struct
.groupByLocus()
.agg(
first(struct(df.drop("genes")("*"))) as "variant",
collect_list("gene") as "genes" // re-create genes list for each locus, now containing spliceai struct
)
.select("variant.*", "genes")
}

def withCosmic(cosmic: DataFrame)(implicit spark: SparkSession): DataFrame = {
import spark.implicits._

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package bio.ferlab.datalake.spark3.publictables.enriched
import bio.ferlab.datalake.commons.config.{Coalesce, DatasetConf, RuntimeETLContext}
import bio.ferlab.datalake.spark3.etl.v4.SimpleSingleETL
import bio.ferlab.datalake.spark3.implicits.DatasetConfImplicits._
import bio.ferlab.datalake.spark3.implicits.GenomicImplicits._
import bio.ferlab.datalake.spark3.implicits.GenomicImplicits.columns.locusColumnNames
import bio.ferlab.datalake.spark3.implicits.SparkUtils.removeEmptyObjectsIn
import bio.ferlab.datalake.spark3.publictables.enriched.Genes._
import mainargs.{ParserForMethods, main}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.{Column, DataFrame, SparkSession}

import java.time.LocalDateTime

Expand All @@ -20,6 +23,7 @@ case class Genes(rc: RuntimeETLContext) extends SimpleSingleETL(rc) {
val ddd_gene_set: DatasetConf = conf.getDataset("normalized_ddd_gene_set")
val cosmic_gene_set: DatasetConf = conf.getDataset("normalized_cosmic_gene_set")
val gnomad_constraint: DatasetConf = conf.getDataset("normalized_gnomad_constraint_v2_1_1")
val spliceai: DatasetConf = conf.getDataset("enriched_spliceai")

override def extract(lastRunValue: LocalDateTime,
currentRunValue: LocalDateTime): Map[String, DataFrame] = {
Expand All @@ -30,7 +34,8 @@ case class Genes(rc: RuntimeETLContext) extends SimpleSingleETL(rc) {
human_genes.id -> human_genes.read,
ddd_gene_set.id -> ddd_gene_set.read,
cosmic_gene_set.id -> cosmic_gene_set.read,
gnomad_constraint.id -> gnomad_constraint.read
gnomad_constraint.id -> gnomad_constraint.read,
spliceai.id -> spliceai.read
)
}

Expand All @@ -55,11 +60,41 @@ case class Genes(rc: RuntimeETLContext) extends SimpleSingleETL(rc) {
.withDDD(data(ddd_gene_set.id))
.withCosmic(data(cosmic_gene_set.id))
.withGnomadConstraint(data(gnomad_constraint.id))
.withSpliceAi(data(spliceai.id))
}

override def defaultRepartition: DataFrame => DataFrame = Coalesce()
}

object Genes {
@main
def run(rc: RuntimeETLContext): Unit = {
Genes(rc).run()
}

def main(args: Array[String]): Unit = ParserForMethods(this).runOrThrow(args)

implicit class DataFrameOps(df: DataFrame) {

def joinAndMergeWith(gene_set: DataFrame,
joinOn: Seq[String],
asColumnName: String,
aggFirst: Boolean = false): DataFrame = {
val aggFn: Column => Column = c => if (aggFirst) first(c) else collect_list(c)
val aggDF = df
.join(gene_set, joinOn, "left")
.groupBy("symbol")
.agg(
first(struct(df("*"))) as "hg",
aggFn(struct(gene_set.drop(joinOn: _*)("*"))) as asColumnName,
)
.select(col("hg.*"), col(asColumnName))
if (aggFirst)
aggDF
else
aggDF.withColumn(asColumnName, removeEmptyObjectsIn(asColumnName))
}

def withGnomadConstraint(gnomad: DataFrame): DataFrame = {
val gnomadConstraint = gnomad
.groupBy("chromosome", "symbol")
Expand Down Expand Up @@ -105,34 +140,17 @@ case class Genes(rc: RuntimeETLContext) extends SimpleSingleETL(rc) {
df.joinAndMergeWith(hpoPrepared, Seq("entrez_gene_id"), "hpo")
}

def joinAndMergeWith(gene_set: DataFrame, joinOn: Seq[String], asColumnName: String, aggFirst: Boolean = false): DataFrame = {
val aggFn: Column => Column = c => if (aggFirst) first(c) else collect_list(c)
val aggDF = df
.join(gene_set, joinOn, "left")
.groupBy("symbol")
.agg(
first(struct(df("*"))) as "hg",
aggFn(struct(gene_set.drop(joinOn: _*)("*"))) as asColumnName,
)
.select(col("hg.*"), col(asColumnName))
if (aggFirst)
aggDF
else
aggDF.withColumn(asColumnName, removeEmptyObjectsIn(asColumnName))
}
def withSpliceAi(spliceai: DataFrame)(implicit spark: SparkSession): DataFrame = {
import spark.implicits._

}


override def defaultRepartition: DataFrame => DataFrame = Coalesce()
}
val spliceAiPrepared = spliceai
.select($"symbol", $"max_score.*")
.withColumn("type", when($"ds" === 0, null).otherwise($"type"))

object Genes {
@main
def run(rc: RuntimeETLContext): Unit = {
Genes(rc).run()
df
.joinAndMergeWith(spliceAiPrepared, Seq("symbol"), "spliceai", aggFirst = true)
.withColumn("spliceai", when($"spliceai.ds".isNull and $"spliceai.type".isNull, null).otherwise($"spliceai"))
}
}

def main(args: Array[String]): Unit = ParserForMethods(this).runOrThrow(args)
}

Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ package bio.ferlab.datalake.spark3.genomics.enriched
import bio.ferlab.datalake.commons.config.DatasetConf
import bio.ferlab.datalake.spark3.genomics.enriched.Variants.DataFrameOps
import bio.ferlab.datalake.spark3.genomics.{FrequencySplit, SimpleAggregation}
import bio.ferlab.datalake.spark3.testutils.WithTestConfig
import bio.ferlab.datalake.testutils.models.enriched.EnrichedVariant.CMC
import bio.ferlab.datalake.testutils.models.enriched.{EnrichedGenes, EnrichedSpliceAi, EnrichedVariant, MAX_SCORE}
import bio.ferlab.datalake.testutils.models.enriched.{EnrichedGenes, EnrichedVariant}
import bio.ferlab.datalake.testutils.models.normalized._
import bio.ferlab.datalake.spark3.testutils.WithTestConfig
import bio.ferlab.datalake.testutils.models.normalized.NormalizedCosmicMutationSet
import bio.ferlab.datalake.testutils.{SparkSpec, TestETLContext}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, collect_set, max}
Expand All @@ -26,7 +25,6 @@ class EnrichedVariantsSpec extends SparkSpec with WithTestConfig {
val dbsnp: DatasetConf = conf.getDataset("normalized_dbsnp")
val clinvar: DatasetConf = conf.getDataset("normalized_clinvar")
val genes: DatasetConf = conf.getDataset("enriched_genes")
val spliceai: DatasetConf = conf.getDataset("enriched_spliceai")
val cosmic: DatasetConf = conf.getDataset("normalized_cosmic_mutation_set")

val occurrencesDf: DataFrame = Seq(
Expand All @@ -42,7 +40,6 @@ class EnrichedVariantsSpec extends SparkSpec with WithTestConfig {
val dbsnpDf: DataFrame = Seq(NormalizedDbsnp()).toDF
val clinvarDf: DataFrame = Seq(NormalizedClinvar(chromosome = "1", start = 69897, reference = "T", alternate = "C")).toDF
val genesDf: DataFrame = Seq(EnrichedGenes()).toDF()
val spliceaiDf: DataFrame = Seq(EnrichedSpliceAi(chromosome = "1", start = 69897, reference = "T", alternate = "C", symbol = "OR4F5", ds_ag = 0.01, `max_score` = MAX_SCORE(ds = 0.01, `type` = Seq("AG")))).toDF()
val cosmicDf: DataFrame = Seq(NormalizedCosmicMutationSet(chromosome = "1", start = 69897, reference = "T", alternate = "C")).toDF()

val etl = Variants(TestETLContext(), snvDatasetId = snvKeyId, splits = Seq(FrequencySplit("frequency", extraAggregations = Seq(SimpleAggregation(name = "zygosities", c = col("zygosity"))))))
Expand All @@ -57,7 +54,6 @@ class EnrichedVariantsSpec extends SparkSpec with WithTestConfig {
dbsnp.id -> dbsnpDf,
clinvar.id -> clinvarDf,
genes.id -> genesDf,
spliceai.id -> spliceaiDf,
cosmic.id -> cosmicDf
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ package bio.ferlab.datalake.spark3.publictables.enriched

import bio.ferlab.datalake.commons.config.DatasetConf
import bio.ferlab.datalake.spark3.implicits.DatasetConfImplicits._
import bio.ferlab.datalake.testutils.models.enriched.{EnrichedGenes, OMIM, ORPHANET, COSMIC}
import bio.ferlab.datalake.testutils.models.normalized._
import bio.ferlab.datalake.spark3.publictables.enriched.Genes._
import bio.ferlab.datalake.spark3.testutils.WithTestConfig
import bio.ferlab.datalake.testutils.models.normalized.NormalizedCosmicGeneSet
import bio.ferlab.datalake.testutils.models.enriched._
import bio.ferlab.datalake.testutils.models.normalized._
import bio.ferlab.datalake.testutils.{CleanUpBeforeAll, CreateDatabasesBeforeAll, SparkSpec, TestETLContext}
import org.apache.spark.sql.functions
import org.apache.spark.sql.functions.col
Expand All @@ -22,6 +22,7 @@ class GenesSpec extends SparkSpec with WithTestConfig with CreateDatabasesBefore
val ddd_gene_set: DatasetConf = conf.getDataset("normalized_ddd_gene_set")
val cosmic_gene_set: DatasetConf = conf.getDataset("normalized_cosmic_gene_set")
val gnomad_constraint: DatasetConf = conf.getDataset("normalized_gnomad_constraint_v2_1_1")
val spliceai: DatasetConf = conf.getDataset("enriched_spliceai")

private val inputData = Map(
omim_gene_set.id -> Seq(
Expand All @@ -35,7 +36,8 @@ class GenesSpec extends SparkSpec with WithTestConfig with CreateDatabasesBefore
gnomad_constraint.id -> Seq(
NormalizedGnomadConstraint(chromosome = "1", start = 69897, symbol = "OR4F5", `pLI` = 1.0f, oe_lof_upper = 0.01f),
NormalizedGnomadConstraint(chromosome = "1", start = 69900, symbol = "OR4F5", `pLI` = 0.9f, oe_lof_upper = 0.054f)
).toDF()
).toDF(),
spliceai.id -> Seq(EnrichedSpliceAi(`symbol` = "OR4F5")).toDF()
)

val job = new Genes(TestETLContext())
Expand Down Expand Up @@ -78,5 +80,36 @@ class GenesSpec extends SparkSpec with WithTestConfig with CreateDatabasesBefore
resultDF.where("symbol='OR4F5'").as[EnrichedGenes].collect().head shouldBe
EnrichedGenes(`orphanet` = expectedOrphanet, `omim` = expectedOmim, `cosmic` = expectedCosmic)
}

"withSpliceAi" should "enrich genes with SpliceAi scores" in {
val genes = Seq(
EnrichedGenes(`chromosome` = "1", `symbol` = "gene1"),
EnrichedGenes(`chromosome` = "1", `symbol` = "gene2"),
EnrichedGenes(`chromosome` = "2", `symbol` = "gene3"),
EnrichedGenes(`chromosome` = "3", `symbol` = "gene4"),
).toDF().drop("spliceai")

val spliceai = Seq(
// snv
EnrichedSpliceAi(`chromosome` = "1", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "C", `symbol` = "gene1", `max_score` = MAX_SCORE(`ds` = 2.0, `type` = Seq("AL"))),
EnrichedSpliceAi(`chromosome` = "1", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "C", `symbol` = "gene2", `max_score` = MAX_SCORE(`ds` = 0.0, `type` = Seq("AG", "AL", "DG", "DL"))),

// indel
EnrichedSpliceAi(`chromosome` = "2", `start` = 1, `end` = 2, `reference` = "T", `alternate` = "AT", `symbol` = "gene3", `max_score` = MAX_SCORE(`ds` = 1.0, `type` = Seq("AG", "AL")))
).toDF()

val result = genes.withSpliceAi(spliceai)

val expected = Seq(
EnrichedGenes(`chromosome` = "1", `symbol` = "gene1", `spliceai` = Some(SPLICEAI(`ds` = 2.0, `type` = Some(List("AL"))))),
EnrichedGenes(`chromosome` = "1", `symbol` = "gene2", `spliceai` = Some(SPLICEAI(`ds` = 0.0, `type` = None))),
EnrichedGenes(`chromosome` = "2", `symbol` = "gene3", `spliceai` = Some(SPLICEAI(`ds` = 1.0, `type` = Some(List("AG", "AL"))))),
EnrichedGenes(`chromosome` = "3", `symbol` = "gene4", `spliceai` = None),
)

result
.as[EnrichedGenes]
.collect() should contain theSameElementsAs expected
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ case class EnrichedGenes(`symbol`: String = "OR4F5",
`chromosome`: String = "1",
`ddd`: List[DDD] = List(DDD()),
`cosmic`: List[COSMIC] = List(COSMIC()),
gnomad: Option[GNOMAD] = Some(GNOMAD()))
gnomad: Option[GNOMAD] = Some(GNOMAD()),
spliceai: Option[SPLICEAI] = Some(SPLICEAI()))

case class ORPHANET(`disorder_id`: Long = 17827,
`panel`: String = "Immunodeficiency due to a classical component pathway complement deficiency",
Expand All @@ -40,3 +41,6 @@ case class COSMIC(`tumour_types_germline`: List[String] = List("breast", "colon"

case class GNOMAD(pli: Float = 1.0f,
loeuf: Float = 0.054f)

case class SPLICEAI(ds: Double = 0.1,
`type`: Option[Seq[String]] = Some(Seq("AG", "AL", "DG", "DL")))
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ case class EnrichedSpliceAi(`chromosome`: String = "1",
`alternate`: String = "G",
`allele`: String = "G",
`symbol`: String = "KCNH1",
`ds_ag`: Double = 0.0,
`ds_al`: Double = 0.0,
`ds_dg`: Double = 0.0,
`ds_dl`: Double = 0.0,
`ds_ag`: Double = 0.1,
`ds_al`: Double = 0.1,
`ds_dg`: Double = 0.1,
`ds_dl`: Double = 0.1,
`dp_ag`: Int = -3,
`dp_al`: Int = 36,
`dp_dg`: Int = 32,
`dp_dl`: Int = 22,
`max_score`: MAX_SCORE = MAX_SCORE())

case class MAX_SCORE(`ds`: Double = 0.0,
case class MAX_SCORE(`ds`: Double = 0.1,
`type`: Seq[String] = Seq("AG", "AL", "DG", "DL"))
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ object EnrichedVariant {
)


case class SPLICEAI(ds: Double = 0.01,
`type`: Seq[String] = Seq("AG"))
case class SPLICEAI(ds: Double = 0.1,
`type`: Option[Seq[String]] = Some(Seq("AG", "AL", "DG", "DL")))

case class CMC(mutation_url: String = "https://cancer.sanger.ac.uk/cosmic/mutation/overview?id=29491889&genome=37",
shared_aa: Int = 9,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ case class NormalizedSpliceAi(`chromosome`: String = "2",
`alternate`: String = "G",
`allele`: String = "G",
`symbol`: String = "KCNH1",
`ds_ag`: Double = 0.0,
`ds_al`: Double = 0.0,
`ds_dg`: Double = 0.0,
`ds_dl`: Double = 0.0,
`ds_ag`: Double = 0.1,
`ds_al`: Double = 0.1,
`ds_dg`: Double = 0.1,
`ds_dl`: Double = 0.1,
`dp_ag`: Int = -3,
`dp_al`: Int = 36,
`dp_dg`: Int = 32,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ object PreparedVariantCentric {
`hpo_term_name`: String = "Hyperreflexia",
`hpo_term_label`: String = "Hyperreflexia (HP:0001347)")

case class SPLICEAI(`ds`: Double = 0.01,
`type`: Seq[String] = Seq("AG"))
case class SPLICEAI(`ds`: Double = 0.1,
`type`: Option[Seq[String]] = Some(Seq("AG", "AL", "DG", "DL")))

case class TOPMED_BRAVO(`ac`: Long = 2,
`an`: Long = 125568,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ case class RawSpliceAi(`contigName`: String = "1",
`qual`: Option[Double] = None,
`filters`: Option[Seq[String]] = None,
`splitFromMultiAllelic`: Boolean = false,
`INFO_SpliceAI`: Seq[String] = Seq("G|KCNH1|0.00|0.00|0.00|0.00|-3|36|32|22"),
`INFO_SpliceAI`: Seq[String] = Seq("G|KCNH1|0.1|0.1|0.1|0.1|-3|36|32|22"),
`INFO_OLD_MULTIALLELIC`: Option[String] = None,
`genotypes`: Seq[GENOTYPES] = Seq(GENOTYPES(sampleId = None)))

0 comments on commit 95bd2e2

Please sign in to comment.