Skip to content

Commit

Permalink
use enum
Browse files Browse the repository at this point in the history
  • Loading branch information
jsstevenson committed Feb 6, 2025
1 parent bb1baa8 commit 5e09a3e
Showing 1 changed file with 30 additions and 25 deletions.
55 changes: 30 additions & 25 deletions src/ga4gh/vrs/extras/annotator/vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,15 @@ class SeqRepoProxyType(str, Enum):
REST = "rest"


# Field names for VCF
VRS_ALLELE_IDS_FIELD = "VRS_Allele_IDs"
VRS_STARTS_FIELD = "VRS_Starts"
VRS_ENDS_FIELD = "VRS_Ends"
VRS_STATES_FIELD = "VRS_States"
VRS_ERROR_FIELD = "VRS_Error"
class FieldName(str, Enum):
    """Define VCF field names for VRS annotations

    Each member's value is the literal INFO field ID that the annotator
    registers on the VCF header and uses as the key when writing
    ``record.info``.

    Because the enum mixes in ``str``, a member hashes and compares equal to
    its value, so a member and its ``.value`` address the same dictionary
    entry. Use ``.value`` wherever a plain ``str`` must be handed to an
    external API.
    """

    # VRS allele identifiers computed for the record's alleles.
    IDS_FIELD = "VRS_Allele_IDs"
    # Start, end and state values; only requested when `vrs_attributes` is set.
    STARTS_FIELD = "VRS_Starts"
    ENDS_FIELD = "VRS_Ends"
    STATES_FIELD = "VRS_States"
    # Error message recorded when computing a VRS identifier fails.
    ERROR_FIELD = "VRS_Error"


# VCF character escape map
VCF_ESCAPE_MAP = str.maketrans(
Expand Down Expand Up @@ -117,7 +120,7 @@ def annotate(

vcf = pysam.VariantFile(filename=str(input_vcf_path.absolute()))
vcf.header.info.add(
VRS_ALLELE_IDS_FIELD,
FieldName.IDS_FIELD.value,
info_field_num,
"String",
(
Expand All @@ -126,15 +129,15 @@ def annotate(
),
)
vcf.header.info.add(
VRS_ERROR_FIELD,
FieldName.ERROR_FIELD.value,
".",
"String",
("If an error occurred computing a VRS Identifier, the error message"),
)

if vrs_attributes:
vcf.header.info.add(
VRS_STARTS_FIELD,
FieldName.STARTS_FIELD.value,
info_field_num,
"String",
(
Expand All @@ -143,7 +146,7 @@ def annotate(
),
)
vcf.header.info.add(
VRS_ENDS_FIELD,
FieldName.ENDS_FIELD.value,
info_field_num,
"String",
(
Expand All @@ -152,7 +155,7 @@ def annotate(
),
)
vcf.header.info.add(
VRS_STATES_FIELD,
FieldName.STATES_FIELD.value,
info_field_num,
"String",
(
Expand All @@ -171,12 +174,12 @@ def annotate(
vrs_data = {} if output_pkl_path else None
for record in vcf:
if vcf_out:
additional_info_fields = [VRS_ALLELE_IDS_FIELD]
additional_info_fields = [FieldName.IDS_FIELD]
if vrs_attributes:
additional_info_fields += [
VRS_STARTS_FIELD,
VRS_ENDS_FIELD,
VRS_STATES_FIELD,
FieldName.STARTS_FIELD,
FieldName.ENDS_FIELD,
FieldName.STATES_FIELD,
]
else:
# no INFO field names need to be designated if not producing an annotated VCF
Expand All @@ -195,8 +198,8 @@ def annotate(
_logger.exception("VRS error on %s-%s", record.chrom, record.pos)
err_msg = f"{ex}" or f"{type(ex)}"
err_msg = err_msg.translate(VCF_ESCAPE_MAP)
additional_info_fields = [VRS_ERROR_FIELD]
vrs_field_data = {VRS_ERROR_FIELD: [err_msg]}
additional_info_fields = [FieldName.ERROR_FIELD]
vrs_field_data = {FieldName.ERROR_FIELD.value: [err_msg]}

_logger.debug(
"VCF record %s-%s generated vrs_field_data %s",
Expand All @@ -207,7 +210,9 @@ def annotate(

if output_vcf_path and vcf_out:
for k in additional_info_fields:
record.info[k] = [value or "." for value in vrs_field_data[k]]
record.info[k.value] = [
value or "." for value in vrs_field_data[k.value]
]
vcf_out.write(record)

vcf.close()
Expand Down Expand Up @@ -287,7 +292,7 @@ def _get_vrs_object(

if vrs_field_data:
allele_id = vrs_obj.id if vrs_obj else ""
vrs_field_data[VRS_ALLELE_IDS_FIELD].append(allele_id)
vrs_field_data[FieldName.IDS_FIELD].append(allele_id)

if vrs_attributes:
if vrs_obj:
Expand All @@ -301,16 +306,16 @@ def _get_vrs_object(
else:
start = end = alt = ""

vrs_field_data[VRS_STARTS_FIELD].append(start)
vrs_field_data[VRS_ENDS_FIELD].append(end)
vrs_field_data[VRS_STATES_FIELD].append(alt)
vrs_field_data[FieldName.STARTS_FIELD].append(start)
vrs_field_data[FieldName.ENDS_FIELD].append(end)
vrs_field_data[FieldName.STATES_FIELD].append(alt)

def _get_vrs_data(
self,
record: pysam.VariantRecord,
vrs_data: dict | None,
assembly: str,
additional_info_fields: list[str],
additional_info_fields: list[FieldName],
vrs_attributes: bool = False,
compute_for_ref: bool = True,
require_validation: bool = True,
Expand All @@ -333,7 +338,7 @@ def _get_vrs_data(
:return: A dictionary mapping VRS-related INFO fields to lists of associated
values. Will be empty if `create_annotated_vcf` is false.
"""
vrs_field_data = {field: [] for field in additional_info_fields}
vrs_field_data = {field.value: [] for field in additional_info_fields}

# Get VRS data for reference allele
gnomad_loc = f"{record.chrom}-{record.pos}"
Expand All @@ -356,7 +361,7 @@ def _get_vrs_data(
if "*" in allele:
_logger.debug("Star allele found: %s", allele)
for field in additional_info_fields:
vrs_field_data[field].append("")
vrs_field_data[field.value].append("")
else:
self._get_vrs_object(
allele,
Expand Down

0 comments on commit 5e09a3e

Please sign in to comment.