diff --git a/savana/helper.py b/savana/helper.py index 4c653d3..9d9162e 100755 --- a/savana/helper.py +++ b/savana/helper.py @@ -14,7 +14,7 @@ from time import time from datetime import datetime -__version__ = "1.0.4" +__version__ = "1.0.5" samflag_desc_to_number = { "BAM_CMATCH": 0, # M diff --git a/savana/train.py b/savana/train.py index 31bad12..fd0bba3 100755 --- a/savana/train.py +++ b/savana/train.py @@ -53,8 +53,12 @@ def format_data(data_matrix): # convert the SVTYPE to 0/1 data_matrix['SVTYPE'] = data_matrix['SVTYPE'].map({'BND':0,'INS':1}) - # ONE-HOT ENCODING + # one-hot-encoding of BP_NOTATION sv_type_one_hot = pd.get_dummies(data_matrix['BP_NOTATION']) + # check to make sure all bp types are present + for bp_type in ["++","+-","-+","--"]: + if bp_type not in sv_type_one_hot: + sv_type_one_hot[bp_type] = False data_matrix.drop('BP_NOTATION', axis=1) data_matrix = data_matrix.join(sv_type_one_hot)