diff --git a/src/ga4gh/vrs/extras/annotator/vcf.py b/src/ga4gh/vrs/extras/annotator/vcf.py
index 3d0d424f..ed5e4a92 100644
--- a/src/ga4gh/vrs/extras/annotator/vcf.py
+++ b/src/ga4gh/vrs/extras/annotator/vcf.py
@@ -34,12 +34,15 @@ class SeqRepoProxyType(str, Enum):
     REST = "rest"
 
 
-# Field names for VCF
-VRS_ALLELE_IDS_FIELD = "VRS_Allele_IDs"
-VRS_STARTS_FIELD = "VRS_Starts"
-VRS_ENDS_FIELD = "VRS_Ends"
-VRS_STATES_FIELD = "VRS_States"
-VRS_ERROR_FIELD = "VRS_Error"
+class FieldName(str, Enum):
+    """Define VCF field names for VRS annotations"""
+
+    IDS_FIELD = "VRS_Allele_IDs"
+    STARTS_FIELD = "VRS_Starts"
+    ENDS_FIELD = "VRS_Ends"
+    STATES_FIELD = "VRS_States"
+    ERROR_FIELD = "VRS_Error"
+
 
 # VCF character escape map
 VCF_ESCAPE_MAP = str.maketrans(
@@ -117,7 +120,7 @@ def annotate(
 
         vcf = pysam.VariantFile(filename=str(input_vcf_path.absolute()))
         vcf.header.info.add(
-            VRS_ALLELE_IDS_FIELD,
+            FieldName.IDS_FIELD.value,
             info_field_num,
             "String",
             (
@@ -126,7 +129,7 @@ def annotate(
             ),
         )
         vcf.header.info.add(
-            VRS_ERROR_FIELD,
+            FieldName.ERROR_FIELD.value,
             ".",
             "String",
             ("If an error occurred computing a VRS Identifier, the error message"),
@@ -134,7 +137,7 @@ def annotate(
 
         if vrs_attributes:
             vcf.header.info.add(
-                VRS_STARTS_FIELD,
+                FieldName.STARTS_FIELD.value,
                 info_field_num,
                 "String",
                 (
@@ -143,7 +146,7 @@ def annotate(
                 ),
             )
             vcf.header.info.add(
-                VRS_ENDS_FIELD,
+                FieldName.ENDS_FIELD.value,
                 info_field_num,
                 "String",
                 (
@@ -152,7 +155,7 @@ def annotate(
                 ),
             )
             vcf.header.info.add(
-                VRS_STATES_FIELD,
+                FieldName.STATES_FIELD.value,
                 info_field_num,
                 "String",
                 (
@@ -171,12 +174,12 @@ def annotate(
         vrs_data = {} if output_pkl_path else None
         for record in vcf:
             if vcf_out:
-                additional_info_fields = [VRS_ALLELE_IDS_FIELD]
+                additional_info_fields = [FieldName.IDS_FIELD]
                 if vrs_attributes:
                     additional_info_fields += [
-                        VRS_STARTS_FIELD,
-                        VRS_ENDS_FIELD,
-                        VRS_STATES_FIELD,
+                        FieldName.STARTS_FIELD,
+                        FieldName.ENDS_FIELD,
+                        FieldName.STATES_FIELD,
                     ]
             else:
                 # no INFO field names need to be designated if not producing an annotated VCF
@@ -195,8 +198,8 @@ def annotate(
                 _logger.exception("VRS error on %s-%s", record.chrom, record.pos)
                 err_msg = f"{ex}" or f"{type(ex)}"
                 err_msg = err_msg.translate(VCF_ESCAPE_MAP)
-                additional_info_fields = [VRS_ERROR_FIELD]
-                vrs_field_data = {VRS_ERROR_FIELD: [err_msg]}
+                additional_info_fields = [FieldName.ERROR_FIELD]
+                vrs_field_data = {FieldName.ERROR_FIELD.value: [err_msg]}
 
             _logger.debug(
                 "VCF record %s-%s generated vrs_field_data %s",
@@ -207,7 +210,9 @@ def annotate(
 
             if output_vcf_path and vcf_out:
                 for k in additional_info_fields:
-                    record.info[k] = [value or "." for value in vrs_field_data[k]]
+                    record.info[k.value] = [
+                        value or "." for value in vrs_field_data[k.value]
+                    ]
                 vcf_out.write(record)
 
         vcf.close()
@@ -287,7 +292,7 @@ def _get_vrs_object(
 
         if vrs_field_data:
             allele_id = vrs_obj.id if vrs_obj else ""
-            vrs_field_data[VRS_ALLELE_IDS_FIELD].append(allele_id)
+            vrs_field_data[FieldName.IDS_FIELD].append(allele_id)
 
             if vrs_attributes:
                 if vrs_obj:
@@ -301,16 +306,16 @@ def _get_vrs_object(
                 else:
                     start = end = alt = ""
 
-                vrs_field_data[VRS_STARTS_FIELD].append(start)
-                vrs_field_data[VRS_ENDS_FIELD].append(end)
-                vrs_field_data[VRS_STATES_FIELD].append(alt)
+                vrs_field_data[FieldName.STARTS_FIELD].append(start)
+                vrs_field_data[FieldName.ENDS_FIELD].append(end)
+                vrs_field_data[FieldName.STATES_FIELD].append(alt)
 
     def _get_vrs_data(
         self,
         record: pysam.VariantRecord,
         vrs_data: dict | None,
         assembly: str,
-        additional_info_fields: list[str],
+        additional_info_fields: list[FieldName],
         vrs_attributes: bool = False,
         compute_for_ref: bool = True,
         require_validation: bool = True,
@@ -333,7 +338,7 @@ def _get_vrs_data(
         :return: A dictionary mapping VRS-related INFO fields to lists of associated
             values. Will be empty if `create_annotated_vcf` is false.
         """
-        vrs_field_data = {field: [] for field in additional_info_fields}
+        vrs_field_data = {field.value: [] for field in additional_info_fields}
 
         # Get VRS data for reference allele
         gnomad_loc = f"{record.chrom}-{record.pos}"
@@ -356,7 +361,7 @@ def _get_vrs_data(
             if "*" in allele:
                 _logger.debug("Star allele found: %s", allele)
                 for field in additional_info_fields:
-                    vrs_field_data[field].append("")
+                    vrs_field_data[field.value].append("")
             else:
                 self._get_vrs_object(
                     allele,