Skip to content

Commit

Permalink
test: improve fixture definition by constraining dataset IDs (#954)
Browse files Browse the repository at this point in the history
* feat(study_index): add `StudyIndex.VALID_TYPES` enum

* test: mock data have more stringent ID definition + other improvements

* refactor: set ID specification in fixtures with expression to avoid changing nullability status

* refactor: set ID specification in fixtures with expression to avoid changing nullability status

* chore: fix issues
  • Loading branch information
ireneisdoomed authored Dec 19, 2024
1 parent cb64cc5 commit 5534908
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 26 deletions.
12 changes: 12 additions & 0 deletions src/gentropy/dataset/study_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ class StudyIndex(Dataset):
A study index dataset captures all the metadata for all studies including GWAS and Molecular QTL.
"""

VALID_TYPES = [
"gwas",
"eqtl",
"pqtl",
"sqtl",
"tuqtl",
"sceqtl",
"scpqtl",
"scsqtl",
"sctuqtl",
]

@staticmethod
def _aggregate_samples_by_ancestry(merged: Column, ancestry: Column) -> Column:
"""Aggregate sample counts by ancestry in a list of struct columns.
Expand Down
149 changes: 123 additions & 26 deletions tests/gentropy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ def mock_colocalisation(spark: SparkSession) -> Colocalisation:
randomSeedMethod="hash_fieldname",
)
.withSchema(coloc_schema)
.withColumnSpec(
"leftStudyLocusId",
expr="cast(id as string)",
)
.withColumnSpec(
"rightStudyLocusId",
expr="cast(id as string)",
)
.withColumnSpec("h0", percentNulls=0.1)
.withColumnSpec("h1", percentNulls=0.1)
.withColumnSpec("h2", percentNulls=0.1)
Expand Down Expand Up @@ -103,6 +111,10 @@ def mock_study_index_data(spark: SparkSession) -> DataFrame:
randomSeedMethod="hash_fieldname",
)
.withSchema(si_schema)
.withColumnSpec(
"studyId",
expr="cast(id as string)",
)
.withColumnSpec(
"traitFromSourceMappedIds",
expr="array(cast(rand() AS string))",
Expand All @@ -123,7 +135,10 @@ def mock_study_index_data(spark: SparkSession) -> DataFrame:
expr='array(named_struct("sampleSize", cast(rand() as string), "ancestry", cast(rand() as string)))',
percentNulls=0.1,
)
.withColumnSpec("geneId", percentNulls=0.1)
.withColumnSpec(
"geneId",
expr="cast(id as string)",
)
.withColumnSpec("pubmedId", percentNulls=0.1)
.withColumnSpec("publicationFirstAuthor", percentNulls=0.1)
.withColumnSpec("publicationDate", percentNulls=0.1)
Expand All @@ -134,9 +149,7 @@ def mock_study_index_data(spark: SparkSession) -> DataFrame:
.withColumnSpec("nControls", percentNulls=0.1)
.withColumnSpec("nSamples", percentNulls=0.1)
.withColumnSpec("summarystatsLocation", percentNulls=0.1)
.withColumnSpec(
"studyType", percentNulls=0.0, values=["eqtl", "pqtl", "sqtl", "gwas"]
)
.withColumnSpec("studyType", percentNulls=0.0, values=StudyIndex.VALID_TYPES)
)
return data_spec.build()

Expand Down Expand Up @@ -164,12 +177,30 @@ def mock_study_locus_overlap(spark: SparkSession) -> StudyLocusOverlap:
"""Mock StudyLocusOverlap dataset."""
overlap_schema = StudyLocusOverlap.get_schema()

data_spec = dg.DataGenerator(
spark,
rows=400,
partitions=4,
randomSeedMethod="hash_fieldname",
).withSchema(overlap_schema)
data_spec = (
dg.DataGenerator(
spark,
rows=400,
partitions=4,
randomSeedMethod="hash_fieldname",
)
.withSchema(overlap_schema)
.withColumnSpec(
"leftStudyLocusId",
expr="cast(id as string)",
)
.withColumnSpec(
"rightStudyLocusId",
expr="cast(id as string)",
)
.withColumnSpec(
"tagVariantId",
expr="cast(id as string)",
)
.withColumnSpec(
"rightStudyType", percentNulls=0.0, values=StudyIndex.VALID_TYPES
)
)

return StudyLocusOverlap(_df=data_spec.build(), _schema=overlap_schema)

Expand All @@ -186,6 +217,10 @@ def mock_study_locus_data(spark: SparkSession) -> DataFrame:
randomSeedMethod="hash_fieldname",
)
.withSchema(sl_schema)
.withColumnSpec(
"variantId",
expr="cast(id as string)",
)
.withColumnSpec("chromosome", percentNulls=0.1)
.withColumnSpec("position", minValue=100, percentNulls=0.1)
.withColumnSpec("beta", percentNulls=0.1)
Expand All @@ -202,7 +237,7 @@ def mock_study_locus_data(spark: SparkSession) -> DataFrame:
.withColumnSpec("finemappingMethod", percentNulls=0.1)
.withColumnSpec(
"locus",
expr='array(named_struct("is95CredibleSet", cast(rand() > 0.5 as boolean), "is99CredibleSet", cast(rand() > 0.5 as boolean), "logBF", rand(), "posteriorProbability", rand(), "variantId", cast(rand() as string), "beta", rand(), "standardError", rand(), "r2Overall", rand(), "pValueMantissa", rand(), "pValueExponent", rand()))',
expr='array(named_struct("is95CredibleSet", cast(rand() > 0.5 as boolean), "is99CredibleSet", cast(rand() > 0.5 as boolean), "logBF", rand(), "posteriorProbability", rand(), "variantId", cast(floor(rand() * 400) + 1 as string), "beta", rand(), "standardError", rand(), "r2Overall", rand(), "pValueMantissa", rand(), "pValueExponent", rand()))',
percentNulls=0.1,
)
)
Expand Down Expand Up @@ -262,6 +297,10 @@ def mock_variant_index(spark: SparkSession) -> VariantIndex:
randomSeedMethod="hash_fieldname",
)
.withSchema(vi_schema)
.withColumnSpec(
"variantId",
expr="cast(id as string)",
)
.withColumnSpec("mostSevereConsequenceId", percentNulls=0.1)
# Nested column handling workaround
# https://github.com/databrickslabs/dbldatagen/issues/135
Expand All @@ -275,7 +314,7 @@ def mock_variant_index(spark: SparkSession) -> VariantIndex:
"assessment", cast(rand() as string),
"score", rand(),
"assessmentFlag", cast(rand() as string),
"targetId", cast(rand() as string),
"targetId", cast(floor(rand() * 400) + 1 as string),
"normalizedScore", cast(rand() as float)
)
)
Expand All @@ -298,19 +337,19 @@ def mock_variant_index(spark: SparkSession) -> VariantIndex:
"uniprotAccessions", array(cast(rand() as string)),
"isEnsemblCanonical", cast(rand() as boolean),
"codons", cast(rand() as string),
"distanceFromTss", cast(rand() as long),
"distanceFromFootprint", cast(rand() as long),
"distanceFromTss", cast(floor(rand() * 500000) as long),
"distanceFromFootprint", cast(floor(rand() * 500000) as long),
"appris", cast(rand() as string),
"maneSelect", cast(rand() as string),
"targetId", cast(rand() as string),
"targetId", cast(floor(rand() * 400) + 1 as string),
"impact", cast(rand() as string),
"lofteePrediction", cast(rand() as string),
"siftPrediction", rand(),
"polyphenPrediction", rand(),
"consequenceScore", cast(rand() as float),
"transcriptIndex", cast(rand() as integer),
"transcriptId", cast(rand() as string),
"biotype", cast(rand() as string),
"biotype", 'protein_coding',
"approvedSymbol", cast(rand() as string)
)
)
Expand Down Expand Up @@ -355,13 +394,20 @@ def mock_summary_statistics_data(spark: SparkSession) -> DataFrame:
name="summaryStats",
)
.withSchema(ss_schema)
.withColumnSpec(
"studyId",
expr="cast(id as string)",
)
.withColumnSpec(
"variantId",
expr="cast(id as string)",
)
# Allow missingness in effect allele frequency and enforce an upper limit of 1.0:
.withColumnSpec(
"effectAlleleFrequencyFromSource", percentNulls=0.1, maxValue=1.0
)
# Allowing missingness:
.withColumnSpec("standardError", percentNulls=0.1)
# TODO: make sure p-values are below 1 (no column spec for this follows yet)
)

return data_spec.build()
Expand Down Expand Up @@ -390,6 +436,10 @@ def mock_ld_index(spark: SparkSession) -> LDIndex:
randomSeedMethod="hash_fieldname",
)
.withSchema(ld_schema)
.withColumnSpec(
"variantId",
expr="cast(id as string)",
)
.withColumnSpec(
"ldSet",
expr="array(named_struct('tagVariantId', cast(rand() as string), 'rValues', array(named_struct('population', cast(rand() as string), 'r', cast(rand() as double)))))",
Expand Down Expand Up @@ -526,18 +576,24 @@ def mock_gene_index(spark: SparkSession) -> GeneIndex:
data_spec = (
dg.DataGenerator(
spark,
rows=400,
rows=30,
partitions=4,
randomSeedMethod="hash_fieldname",
)
.withSchema(gi_schema)
.withColumnSpec(
"geneId",
expr="cast(id as string)",
)
.withColumnSpec("approvedSymbol", percentNulls=0.1)
.withColumnSpec("biotype", percentNulls=0.1)
.withColumnSpec(
"biotype", percentNulls=0.1, values=["protein_coding", "lncRNA"]
)
.withColumnSpec("approvedName", percentNulls=0.1)
.withColumnSpec("tss", percentNulls=0.1)
.withColumnSpec("start", percentNulls=0.1)
.withColumnSpec("end", percentNulls=0.1)
.withColumnSpec("strand", percentNulls=0.1)
.withColumnSpec("strand", percentNulls=0.1, values=[1, -1])
)

return GeneIndex(_df=data_spec.build(), _schema=gi_schema)
Expand All @@ -559,6 +615,10 @@ def mock_biosample_index(spark: SparkSession) -> BiosampleIndex:
randomSeedMethod="hash_fieldname",
)
.withSchema(bi_schema)
.withColumnSpec(
"biosampleId",
expr="cast(id as string)",
)
.withColumnSpec("biosampleName", percentNulls=0.1)
.withColumnSpec("description", percentNulls=0.1)
.withColumnSpec("xrefs", expr=array_expression, percentNulls=0.1)
Expand Down Expand Up @@ -613,9 +673,35 @@ def mock_l2g_feature_matrix(spark: SparkSession) -> L2GFeatureMatrix:
def mock_l2g_gold_standard(spark: SparkSession) -> L2GGoldStandard:
"""Mock l2g gold standard dataset."""
schema = L2GGoldStandard.get_schema()
data_spec = dg.DataGenerator(
spark, rows=400, partitions=4, randomSeedMethod="hash_fieldname"
).withSchema(schema)
data_spec = (
dg.DataGenerator(
spark, rows=400, partitions=4, randomSeedMethod="hash_fieldname"
)
.withSchema(schema)
.withColumnSpec(
"studyLocusId",
expr="cast(id as string)",
)
.withColumnSpec(
"variantId",
expr="cast(id as string)",
)
.withColumnSpec(
"geneId",
expr="cast(id as string)",
)
.withColumnSpec(
"traitFromSourceMappedId",
expr="cast(id as string)",
)
.withColumnSpec(
"goldStandardSet",
values=[
L2GGoldStandard.GS_NEGATIVE_LABEL,
L2GGoldStandard.GS_POSITIVE_LABEL,
],
)
)

return L2GGoldStandard(_df=data_spec.build(), _schema=schema)

Expand All @@ -624,9 +710,20 @@ def mock_l2g_gold_standard(spark: SparkSession) -> L2GGoldStandard:
def mock_l2g_predictions(spark: SparkSession) -> L2GPrediction:
"""Mock l2g predictions dataset."""
schema = L2GPrediction.get_schema()
data_spec = dg.DataGenerator(
spark, rows=400, partitions=4, randomSeedMethod="hash_fieldname"
).withSchema(schema)
data_spec = (
dg.DataGenerator(
spark, rows=400, partitions=4, randomSeedMethod="hash_fieldname"
)
.withSchema(schema)
.withColumnSpec(
"studyLocusId",
expr="cast(id as string)",
)
.withColumnSpec(
"geneId",
expr="cast(id as string)",
)
)

return L2GPrediction(_df=data_spec.build(), _schema=schema)

Expand Down

0 comments on commit 5534908

Please sign in to comment.