Skip to content

Commit

Permalink
Use the maximum taxon-index value (rather than the taxa-index size) when sizing hash matrices, to avoid out-of-bounds access
Browse files Browse the repository at this point in the history
  • Loading branch information
cactuskid committed Aug 30, 2024
1 parent fb1556e commit f1246cb
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 14 deletions.
15 changes: 8 additions & 7 deletions src/HogProf/lshbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def __init__(self,h5_oma=None,fileglob = None, taxa=None,masterTree=None, saving
else:
raise Exception( 'please specify an output location' )
self.errorfile = self.saving_path + 'errors.txt'

if masterTree is None:
if h5_oma:
genomes = pd.DataFrame(h5_oma.root.Genome.read())["NCBITaxonId"].tolist()
Expand Down Expand Up @@ -120,11 +119,11 @@ def __init__(self,h5_oma=None,fileglob = None, taxa=None,masterTree=None, saving
self.tree_string = treein.read()
#self.tree_string = self.tree_ete3.write(format=0)
else:
raise Exception( 'please specify a tree' )
raise Exception( 'please specify a tree in either phylo xml or nwk format' )

if self.reformat_names:
self.tree_ete3, self.idmapper = pyhamutils.tree2numerical(self.tree_ete3)
with open( self.saving_path + 'reformatted_tree.nw', 'w') as treeout:
with open( self.saving_path + 'reformatted_tree.nwk', 'w') as treeout:
treeout.write(self.tree_ete3.write(format=0 ))
with open( self.saving_path + 'idmapper.pkl', 'wb') as idout:
idout.write( pickle.dumps(self.idmapper))
Expand Down Expand Up @@ -175,15 +174,17 @@ def __init__(self,h5_oma=None,fileglob = None, taxa=None,masterTree=None, saving

self.HASH_PIPELINE = functools.partial( hashutils.row2hash , taxaIndex=self.taxaIndex, treeweights=self.treeweights, wmg=wmg , lossonly = lossonly, duplonly = duplonly)
if self.h5OMA:

self.READ_ORTHO = functools.partial(pyhamutils.get_orthoxml_oma, db_obj=self.db_obj)

if self.h5OMA:
self.n_groups = len(self.h5OMA.root.OrthoXML.Index)
print( 'reading oma hdf5 with n groups:', self.n_groups)
elif self.fileglob:
print('reading orthoxml files:' , len(self.fileglob))
self.n_groups = len(self.fileglob)
else:
raise Exception( 'please specify an input file' )

self.hashes_path = self.saving_path + 'hashes.h5'
self.lshpath = self.saving_path + 'newlsh.pkl'
self.lshforestpath = self.saving_path + 'newlshforest.pkl'
Expand Down Expand Up @@ -258,7 +259,6 @@ def worker(self, i, q, retq, matq, l):
while True:
df = q.get()
if df is not None :

try:
df['tree'] = df[['Fam', 'ortho']].apply(self.HAM_PIPELINE, axis=1)
df[['hash','rows']] = df[['Fam', 'tree']].apply(self.HASH_PIPELINE, axis=1)
Expand Down Expand Up @@ -298,6 +298,7 @@ def saver(self, i, q, retq, matq, l ):
with open(self.errorfile, 'w') as hashes_error_files:
with h5py.File(self.hashes_path, 'w', libver='latest') as h5hashes:
datasets = {}

if taxstr not in h5hashes.keys():
if self.verbose == True:
print('creating dataset')
Expand Down Expand Up @@ -325,8 +326,7 @@ def saver(self, i, q, retq, matq, l ):
if savedf is None:
savedf = this_dataframe[['Fam', 'ortho']]
else:
savedf = pd.concat( [ savedf , this_dataframe[['Fam', 'ortho']] ] )

savedf = pd.concat( [ savedf , this_dataframe[['Fam', 'ortho']] ] )
if t.time() - save_start > 200:
print( 'saving at :' , t.time() - global_time )
forest.index()
Expand All @@ -341,6 +341,7 @@ def saver(self, i, q, retq, matq, l ):
#save the mapping of fam to orthoxml
print('saving orthoxml to fam mapping')
print(savedf)

savedf.to_csv(self.saving_path + 'fam2orthoxml.csv')
save_start = t.time()
else:
Expand Down
2 changes: 0 additions & 2 deletions src/HogProf/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def __init__(self,lshforestpath = None, hashes_h5=None, mat_path= None, oma = Fa
self.tax_mask = taxmask
self.tree_string = self.tree.write(format=1)


self.taxaIndex, self.ReverseTaxaIndex = files_utils.generate_taxa_index(self.tree)
self.treeweights = hashutils.generate_treeweights(self.tree , self.taxaIndex , None, None )
self.swap2taxcode = swap2taxcode
Expand Down Expand Up @@ -176,7 +175,6 @@ def return_profile_complements(self, fam):
losses = set([ n.name for n in tp.traverse() if n.lost and n.name in self.taxaIndex ])
#these are the roots of the fams we are looking for
#we just assume no duplications or losses from this point

ancestral_nodes = ([ n for n in profiler.tree.traverse() if n.name in losses])
losses=[]
dupl=[]
Expand Down
5 changes: 3 additions & 2 deletions src/HogProf/utils/hashutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ def hash_tree(tp , taxaIndex , treeweights , wmg , lossonly = False , duplonly =
if not tp:
return None, None

hog_matrix_weighted = np.zeros((1, 3*len(taxaIndex)))
hog_matrix_binary = np.zeros((1, 3*len(taxaIndex)))
taxaIndex_max = max(taxaIndex.values())+1
hog_matrix_weighted = np.zeros((1, 3*len(taxaIndex_max)))
hog_matrix_binary = np.zeros((1, 3*len(taxaIndex_max)))
if tp:
losses = [ taxaIndex[n.name] for n in tp.traverse() if n.lost and n.name in taxaIndex ]
dupl = [ taxaIndex[n.name] for n in tp.traverse() if n.dupl and n.name in taxaIndex ]
Expand Down
5 changes: 2 additions & 3 deletions src/HogProf/utils/pyhamutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ def get_orthoxml_tar(fam, tar):
else:
raise Exception( member + ' : not found in tarfile ')
return orthoxml

def get_species_from_orthoxml(orthoxml):
NCBI_taxid2name = {}
root = ET.fromstring(orthoxml)
for child in root:
if 'species' in child.tag:
NCBI_taxid2name[child.attrib['NCBITaxId']] = child.attrib['name']
return NCBI_taxid2name

def switch_name_ncbi_id(orthoxml , mapdict = None ):
#swap ncbi taxid for species name to avoid ambiguity
#mapdict should be a mapping from species name to taxid if the info isnt in the orthoxmls
Expand All @@ -38,7 +40,6 @@ def switch_name_ncbi_id(orthoxml , mapdict = None ):
def reformat_treenames( tree , mapdict = None ):
#tree is an ete3 tree instance
#replace ( ) - / . and spaces with underscores

#iterate over all nodes
for node in tree.traverse():
if mapdict:
Expand Down Expand Up @@ -83,8 +84,6 @@ def orthoxml2numerical(orthoxml , mapper):
orthoxml = ET.tostring(root, encoding='unicode', method='xml')
return orthoxml



def get_ham_treemap_from_row(row, tree , level = None , swap_ids = True , orthoXML_as_string = True , use_phyloxml = False , use_internal_name = True ,reformat_names= False, orthomapper = None ):
fam, orthoxml = row
format = 'newick_string'
Expand Down

0 comments on commit f1246cb

Please sign in to comment.