test: improve fixture definition by constraining dataset IDs #954

Merged · 8 commits · Dec 19, 2024
12 changes: 12 additions & 0 deletions src/gentropy/dataset/study_index.py
@@ -67,6 +67,18 @@ class StudyIndex(Dataset):
A study index dataset captures all the metadata for all studies including GWAS and Molecular QTL.
"""

VALID_TYPES = [
    "gwas",
    "eqtl",
    "pqtl",
    "sqtl",
    "tuqtl",
    "sceqtl",
    "scpqtl",
    "scsqtl",
    "sctuqtl",
]

Review comment (Contributor) on VALID_TYPES: care to note that this should become an enum in the future?
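As a sketch of the reviewer's suggestion (not part of this PR), the list could later be promoted to an enum while keeping a plain list available for dbldatagen's values= argument. The class name StudyType and its shape here are assumptions:

from enum import Enum

class StudyType(str, Enum):
    """Hypothetical enum of valid study types (future refactor suggested in review)."""

    GWAS = "gwas"
    EQTL = "eqtl"
    PQTL = "pqtl"
    SQTL = "sqtl"
    TUQTL = "tuqtl"
    SCEQTL = "sceqtl"
    SCPQTL = "scpqtl"
    SCSQTL = "scsqtl"
    SCTUQTL = "sctuqtl"

# Subclassing str keeps comparisons like study_type == "gwas" working,
# and the list form stays available wherever VALID_TYPES is used today:
VALID_TYPES = [member.value for member in StudyType]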

@staticmethod
def _aggregate_samples_by_ancestry(merged: Column, ancestry: Column) -> Column:
"""Aggregate sample counts by ancestry in a list of struct colmns.
149 changes: 123 additions & 26 deletions tests/gentropy/conftest.py
@@ -76,6 +76,14 @@ def mock_colocalisation(spark: SparkSession) -> Colocalisation:
randomSeedMethod="hash_fieldname",
)
.withSchema(coloc_schema)
.withColumnSpec(
"leftStudyLocusId",
expr="cast(id as string)",
)
.withColumnSpec(
"rightStudyLocusId",
expr="cast(id as string)",
)
.withColumnSpec("h0", percentNulls=0.1)
.withColumnSpec("h1", percentNulls=0.1)
.withColumnSpec("h2", percentNulls=0.1)
@@ -103,6 +111,10 @@ def mock_study_index_data(spark: SparkSession) -> DataFrame:
randomSeedMethod="hash_fieldname",
)
.withSchema(si_schema)
.withColumnSpec(
"studyId",
expr="cast(id as string)",
)
.withColumnSpec(
"traitFromSourceMappedIds",
expr="array(cast(rand() AS string))",
@@ -123,7 +135,10 @@ def mock_study_index_data(spark: SparkSession) -> DataFrame:
expr='array(named_struct("sampleSize", cast(rand() as string), "ancestry", cast(rand() as string)))',
percentNulls=0.1,
)
.withColumnSpec("geneId", percentNulls=0.1)
.withColumnSpec(
"geneId",
expr="cast(id as string)",
)
.withColumnSpec("pubmedId", percentNulls=0.1)
.withColumnSpec("publicationFirstAuthor", percentNulls=0.1)
.withColumnSpec("publicationDate", percentNulls=0.1)
@@ -134,9 +149,7 @@ def mock_study_index_data(spark: SparkSession) -> DataFrame:
.withColumnSpec("nControls", percentNulls=0.1)
.withColumnSpec("nSamples", percentNulls=0.1)
.withColumnSpec("summarystatsLocation", percentNulls=0.1)
- .withColumnSpec(
-     "studyType", percentNulls=0.0, values=["eqtl", "pqtl", "sqtl", "gwas"]
- )
+ .withColumnSpec("studyType", percentNulls=0.0, values=StudyIndex.VALID_TYPES)
)
return data_spec.build()
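Switching the studyType spec from a hard-coded subset to StudyIndex.VALID_TYPES keeps the fixture in lockstep with the dataset definition. A sketch of how the values= option behaves (illustrative, not project code):

import dbldatagen as dg
from pyspark.sql import SparkSession

from gentropy.dataset.study_index import StudyIndex

spark = SparkSession.builder.master("local[1]").getOrCreate()

# values= makes dbldatagen draw each row's value from the given list, and
# percentNulls=0.0 rules out nulls, so every generated studyType is a
# member of StudyIndex.VALID_TYPES.
df = (
    dg.DataGenerator(spark, rows=100, partitions=1)
    .withColumn("studyType", "string", values=StudyIndex.VALID_TYPES, percentNulls=0.0)
    .build()
)
observed = {r["studyType"] for r in df.select("studyType").distinct().collect()}
assert observed <= set(StudyIndex.VALID_TYPES)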

@@ -164,12 +177,30 @@ def mock_study_locus_overlap(spark: SparkSession) -> StudyLocusOverlap:
"""Mock StudyLocusOverlap dataset."""
overlap_schema = StudyLocusOverlap.get_schema()

- data_spec = dg.DataGenerator(
-     spark,
-     rows=400,
-     partitions=4,
-     randomSeedMethod="hash_fieldname",
- ).withSchema(overlap_schema)
+ data_spec = (
+     dg.DataGenerator(
+         spark,
+         rows=400,
+         partitions=4,
+         randomSeedMethod="hash_fieldname",
+     )
+     .withSchema(overlap_schema)
+     .withColumnSpec(
+         "leftStudyLocusId",
+         expr="cast(id as string)",
+     )
+     .withColumnSpec(
+         "rightStudyLocusId",
+         expr="cast(id as string)",
+     )
+     .withColumnSpec(
+         "tagVariantId",
+         expr="cast(id as string)",
+     )
+     .withColumnSpec(
+         "rightStudyType", percentNulls=0.0, values=StudyIndex.VALID_TYPES
+     )
+ )

return StudyLocusOverlap(_df=data_spec.build(), _schema=overlap_schema)
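The fixtures follow a consistent construction pattern: one schema drives both the generator and the wrapper. Condensed into a sketch using the names already in scope in this fixture:

overlap_schema = StudyLocusOverlap.get_schema()  # single source of truth for the shape
df = data_spec.build()                           # synthetic rows in that shape
overlap = StudyLocusOverlap(_df=df, _schema=overlap_schema)  # wrap, presumably so schema checks can run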

@@ -186,6 +217,10 @@ def mock_study_locus_data(spark: SparkSession) -> DataFrame:
randomSeedMethod="hash_fieldname",
)
.withSchema(sl_schema)
.withColumnSpec(
"variantId",
expr="cast(id as string)",
)
.withColumnSpec("chromosome", percentNulls=0.1)
.withColumnSpec("position", minValue=100, percentNulls=0.1)
.withColumnSpec("beta", percentNulls=0.1)
@@ -202,7 +237,7 @@ def mock_study_locus_data(spark: SparkSession) -> DataFrame:
.withColumnSpec("finemappingMethod", percentNulls=0.1)
.withColumnSpec(
"locus",
expr='array(named_struct("is95CredibleSet", cast(rand() > 0.5 as boolean), "is99CredibleSet", cast(rand() > 0.5 as boolean), "logBF", rand(), "posteriorProbability", rand(), "variantId", cast(rand() as string), "beta", rand(), "standardError", rand(), "r2Overall", rand(), "pValueMantissa", rand(), "pValueExponent", rand()))',
expr='array(named_struct("is95CredibleSet", cast(rand() > 0.5 as boolean), "is99CredibleSet", cast(rand() > 0.5 as boolean), "logBF", rand(), "posteriorProbability", rand(), "variantId", cast(floor(rand() * 400) + 1 as string), "beta", rand(), "standardError", rand(), "r2Overall", rand(), "pValueMantissa", rand(), "pValueExponent", rand()))',
percentNulls=0.1,
)
)
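For IDs nested inside structs and arrays the PR uses a bounded random draw instead of the row id: floor(rand() * 400) + 1 yields integers in 1..400, presumably so nested references stay varied while still landing inside the key range of the 400-row fixtures. A sketch of the expression on its own (column name assumed):

from pyspark.sql import SparkSession, functions as f

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Draw uniformly from 1..400 and cast to string, so nested variantId
# references fall within a 400-row fixture's key range rather than
# being arbitrary random strings.
df = spark.range(5).select(
    f.expr("cast(floor(rand() * 400) + 1 as string)").alias("nestedVariantId")
)
df.show()  # e.g. "17", "342", ... always within 1..400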
@@ -262,6 +297,10 @@ def mock_variant_index(spark: SparkSession) -> VariantIndex:
randomSeedMethod="hash_fieldname",
)
.withSchema(vi_schema)
.withColumnSpec(
"variantId",
expr="cast(id as string)",
)
.withColumnSpec("mostSevereConsequenceId", percentNulls=0.1)
# Nested column handling workaround
# https://github.com/databrickslabs/dbldatagen/issues/135
@@ -275,7 +314,7 @@ def mock_variant_index(spark: SparkSession) -> VariantIndex:
"assessment", cast(rand() as string),
"score", rand(),
"assessmentFlag", cast(rand() as string),
"targetId", cast(rand() as string),
"targetId", cast(floor(rand() * 400) + 1 as string),
"normalizedScore", cast(rand() as float)
)
)
@@ -298,19 +337,19 @@ def mock_variant_index(spark: SparkSession) -> VariantIndex:
"uniprotAccessions", array(cast(rand() as string)),
"isEnsemblCanonical", cast(rand() as boolean),
"codons", cast(rand() as string),
"distanceFromTss", cast(rand() as long),
"distanceFromFootprint", cast(rand() as long),
"distanceFromTss", cast(floor(rand() * 500000) as long),
"distanceFromFootprint", cast(floor(rand() * 500000) as long),
"appris", cast(rand() as string),
"maneSelect", cast(rand() as string),
"targetId", cast(rand() as string),
"targetId", cast(floor(rand() * 400) + 1 as string),
"impact", cast(rand() as string),
"lofteePrediction", cast(rand() as string),
"siftPrediction", rand(),
"polyphenPrediction", rand(),
"consequenceScore", cast(rand() as float),
"transcriptIndex", cast(rand() as integer),
"transcriptId", cast(rand() as string),
"biotype", cast(rand() as string),
"biotype", 'protein_coding',
"approvedSymbol", cast(rand() as string)
)
)
@@ -355,13 +394,20 @@ def mock_summary_statistics_data(spark: SparkSession) -> DataFrame:
name="summaryStats",
)
.withSchema(ss_schema)
.withColumnSpec(
"studyId",
expr="cast(id as string)",
)
.withColumnSpec(
"variantId",
expr="cast(id as string)",
)
# Allowing missingness in effect allele frequency and enforcing an upper limit:
.withColumnSpec(
"effectAlleleFrequencyFromSource", percentNulls=0.1, maxValue=1.0
)
# Allowing missingness:
.withColumnSpec("standardError", percentNulls=0.1)
# Making sure p-values are below 1:
)

return data_spec.build()
@@ -390,6 +436,10 @@ def mock_ld_index(spark: SparkSession) -> LDIndex:
randomSeedMethod="hash_fieldname",
)
.withSchema(ld_schema)
.withColumnSpec(
"variantId",
expr="cast(id as string)",
)
.withColumnSpec(
"ldSet",
expr="array(named_struct('tagVariantId', cast(rand() as string), 'rValues', array(named_struct('population', cast(rand() as string), 'r', cast(rand() as double)))))",
@@ -526,18 +576,24 @@ def mock_gene_index(spark: SparkSession) -> GeneIndex:
data_spec = (
dg.DataGenerator(
spark,
- rows=400,
+ rows=30,
partitions=4,
randomSeedMethod="hash_fieldname",
)
.withSchema(gi_schema)
.withColumnSpec(
"geneId",
expr="cast(id as string)",
)
.withColumnSpec("approvedSymbol", percentNulls=0.1)
.withColumnSpec("biotype", percentNulls=0.1)
.withColumnSpec(
"biotype", percentNulls=0.1, values=["protein_coding", "lncRNA"]
)
.withColumnSpec("approvedName", percentNulls=0.1)
.withColumnSpec("tss", percentNulls=0.1)
.withColumnSpec("start", percentNulls=0.1)
.withColumnSpec("end", percentNulls=0.1)
.withColumnSpec("strand", percentNulls=0.1)
.withColumnSpec("strand", percentNulls=0.1, values=[1, -1])
)

return GeneIndex(_df=data_spec.build(), _schema=gi_schema)
@@ -559,6 +615,10 @@ def mock_biosample_index(spark: SparkSession) -> BiosampleIndex:
randomSeedMethod="hash_fieldname",
)
.withSchema(bi_schema)
.withColumnSpec(
"biosampleId",
expr="cast(id as string)",
)
.withColumnSpec("biosampleName", percentNulls=0.1)
.withColumnSpec("description", percentNulls=0.1)
.withColumnSpec("xrefs", expr=array_expression, percentNulls=0.1)
@@ -613,9 +673,35 @@ def mock_l2g_feature_matrix(spark: SparkSession) -> L2GFeatureMatrix:
def mock_l2g_gold_standard(spark: SparkSession) -> L2GGoldStandard:
"""Mock l2g gold standard dataset."""
schema = L2GGoldStandard.get_schema()
- data_spec = dg.DataGenerator(
-     spark, rows=400, partitions=4, randomSeedMethod="hash_fieldname"
- ).withSchema(schema)
+ data_spec = (
+     dg.DataGenerator(
+         spark, rows=400, partitions=4, randomSeedMethod="hash_fieldname"
+     )
+     .withSchema(schema)
+     .withColumnSpec(
+         "studyLocusId",
+         expr="cast(id as string)",
+     )
+     .withColumnSpec(
+         "variantId",
+         expr="cast(id as string)",
+     )
+     .withColumnSpec(
+         "geneId",
+         expr="cast(id as string)",
+     )
+     .withColumnSpec(
+         "traitFromSourceMappedId",
+         expr="cast(id as string)",
+     )
+     .withColumnSpec(
+         "goldStandardSet",
+         values=[
+             L2GGoldStandard.GS_NEGATIVE_LABEL,
+             L2GGoldStandard.GS_POSITIVE_LABEL,
+         ],
+     )
+ )

return L2GGoldStandard(_df=data_spec.build(), _schema=schema)
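goldStandardSet is pinned to the class's own label constants rather than free-form strings, so the fixture cannot drift from the labels the model code expects. A sketch of the idea with assumed label values (the real strings live on L2GGoldStandard):

import dbldatagen as dg
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Assumed values for illustration; in project code use
# L2GGoldStandard.GS_POSITIVE_LABEL / GS_NEGATIVE_LABEL.
GS_POSITIVE_LABEL = "positive"
GS_NEGATIVE_LABEL = "negative"

df = (
    dg.DataGenerator(spark, rows=20, partitions=1)
    .withColumn(
        "goldStandardSet",
        "string",
        values=[GS_NEGATIVE_LABEL, GS_POSITIVE_LABEL],
    )
    .build()
)
# Every generated label is one of the two constants.
assert {r["goldStandardSet"] for r in df.collect()} <= {GS_POSITIVE_LABEL, GS_NEGATIVE_LABEL}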

@@ -624,9 +710,20 @@ def mock_l2g_predictions(spark: SparkSession) -> L2GPrediction:
def mock_l2g_predictions(spark: SparkSession) -> L2GPrediction:
"""Mock l2g predictions dataset."""
schema = L2GPrediction.get_schema()
- data_spec = dg.DataGenerator(
-     spark, rows=400, partitions=4, randomSeedMethod="hash_fieldname"
- ).withSchema(schema)
+ data_spec = (
+     dg.DataGenerator(
+         spark, rows=400, partitions=4, randomSeedMethod="hash_fieldname"
+     )
+     .withSchema(schema)
+     .withColumnSpec(
+         "studyLocusId",
+         expr="cast(id as string)",
+     )
+     .withColumnSpec(
+         "geneId",
+         expr="cast(id as string)",
+     )
+ )

return L2GPrediction(_df=data_spec.build(), _schema=schema)
