removing old code and outdated bam file references
helrick committed Aug 4, 2023
1 parent c3f69c5 commit 041435b
Showing 4 changed files with 50 additions and 257 deletions.
89 changes: 8 additions & 81 deletions savana/breakpoints.py
@@ -131,7 +131,7 @@ def count_num_labels(source_breakpoints):
return label_counts

def get_potential_breakpoints(aln_filename, args, label, contig_order, chrom=None, start=None, end=None):
""" iterate through bam file, tracking potential breakpoints and saving relevant reads to fastq """
""" iterate through alignment file, tracking potential breakpoints and saving relevant reads to fastq """
potential_breakpoints = {}
if args.is_cram:
aln_file = pysam.AlignmentFile(aln_filename, "rc", reference_filename=args.ref)
@@ -208,16 +208,16 @@ def add_local_depth(intervals, aln_filenames, is_cram, ref):
start = max(int(intervals[0][1])-1, 0) # first start
end = int(intervals[-1][2]) # last end
read_stats = {}
for bam_type, aln_filename in aln_filenames.items():
for file_type, aln_filename in aln_filenames.items():
if is_cram:
aln_file = pysam.AlignmentFile(aln_filename, "rc", reference_filename=ref)
else:
aln_file = pysam.AlignmentFile(aln_filename, "rb")
read_stats[bam_type] = []
read_stats[file_type] = []
for read in aln_file.fetch(chrom, start, end):
if read.mapping_quality == 0 or read.is_duplicate:
continue
read_stats[bam_type].append([int(read.reference_start), int(read.reference_end), read.query_name])
read_stats[file_type].append([int(read.reference_start), int(read.reference_end), read.query_name])
del read
aln_file.close()
del aln_file
@@ -226,16 +226,10 @@ def add_local_depth(intervals, aln_filenames, is_cram, ref):
interval_start = int(i[1])
interval_end = int(i[2])
edge = int(i[4])
for bam_type, reads in read_stats.items():
#subtraction = [((interval_end-r[0])-r[1]) for r in reads]
#dp = sum(1 for x in subtraction if x <= 0)
#del subtraction
#comparison = [(interval_start <= r[1]) and (r[0] <= interval_end) for r in reads]

for file_type, reads in read_stats.items():
comparison = [[(interval_start - r[1]), (r[0] - interval_end)] for r in reads]
dp = sum(1 for x,y in comparison if x <= 0 and y <= 0)
del comparison

# for some reason, these methods aren't faster
'''
comparison = [[(interval_start <= r[1]), (r[0] <= interval_end)] for r in reads]
@@ -244,81 +238,14 @@ def add_local_depth(intervals, aln_filenames, is_cram, ref):
'''
# nor
#dp = sum(1 for r in reads if (interval_start - r[1]) <= 0 and (r[0] - interval_end) <= 0)

if uid not in uid_dp_dict:
uid_dp_dict[uid] = {}
if bam_type not in uid_dp_dict[uid]:
uid_dp_dict[uid][bam_type] = [None, None]
uid_dp_dict[uid][bam_type][edge] = str(dp)

"""
else:
# ALTERNATELY: IF USING THIS, WILL NEED TO REDUCE MAX TASKS PER CHILD CALCULATION
uid_dp_dict = {}
for bam_type, bam_filename in bam_filenames.items():
bam_file = pysam.AlignmentFile(bam_filename, "rb")
for i in intervals:
i_chrom, i_start, i_end, uid, edge = i
reads = [read for read in bam_file.fetch(i_chrom, int(i_start), int(i_end))]
reads = [read for read in reads if not read.is_duplicate and read.mapping_quality >= 0]
if uid not in uid_dp_dict:
uid_dp_dict[uid] = {}
if bam_type not in uid_dp_dict[uid]:
uid_dp_dict[uid][bam_type] = [None, None]
uid_dp_dict[uid][bam_type][int(edge)] = str(len(reads))
"""
if file_type not in uid_dp_dict[uid]:
uid_dp_dict[uid][file_type] = [None, None]
uid_dp_dict[uid][file_type][edge] = str(dp)

return uid_dp_dict

def add_local_depth_old(breakpoints, bam_filenames):
""" given breakpoints, add the local depth of tumour/normal """
for label, bam_filename in bam_filenames.items():
bam_file = pysam.AlignmentFile(bam_filename, "rb")
option_one=False
option_two=False
if option_one:
for bp in breakpoints:
coverage = bam_file.count_coverage(bp.start_chr, bp.start_loc-1, bp.start_loc)
bp.local_depths.setdefault(label,[]).append(str(np.sum(coverage)))
if not bp.source == "INS":
# add the second edge (append to list)
coverage = bam_file.count_coverage(bp.end_chr, bp.end_loc-1, bp.end_loc)
bp.local_depths.setdefault(label,[]).append(str(np.sum(coverage)))
elif option_two:
for bp in breakpoints:
got_dp = False
for pileupcolumn in bam_file.pileup(bp.start_chr, bp.start_loc-1, bp.start_loc, min_mapping_quality=0, ignore_overlaps=False):
if pileupcolumn.pos == bp.start_loc - 1:
bp.local_depths.setdefault(label,[]).append(str(pileupcolumn.n))
got_dp = True
continue
if not got_dp:
bp.local_depths.setdefault(label,[]).append(str(0))
if not bp.source == "INS":
# add the second edge (append to list)
got_dp = False
for pileupcolumn in bam_file.pileup(bp.end_chr, bp.end_loc-1, bp.end_loc, min_mapping_quality=0, ignore_overlaps=False):
if pileupcolumn.pos == bp.end_loc - 1:
bp.local_depths.setdefault(label,[]).append(str(pileupcolumn.n))
got_dp = True
continue
if not got_dp:
bp.local_depths.setdefault(label,[]).append(str(0))
else:
for bp in breakpoints:
""" e.g.) {'tumour': [start_depth, end_depth], 'normal': [start_depth, end_depth]} """
#reads = [read for read in bam_file.fetch(bp.start_chr, bp.start_loc, bp.start_loc+1, multiple_iterators=True)]
reads = [read for read in bam_file.fetch(bp.start_chr, bp.start_loc, bp.start_loc+1)]
reads = [read for read in reads if read.is_duplicate == False and read.mapping_quality >= 0]
bp.local_depths.setdefault(label,[]).append(str(len(reads)))
if not bp.source == "INS":
# add the second edge (append to list)
#reads = [read for read in bam_file.fetch(bp.end_chr, bp.end_loc, bp.end_loc+1, multiple_iterators=True)]
reads = [read for read in bam_file.fetch(bp.end_chr, bp.end_loc, bp.end_loc+1)]
reads = [read for read in reads if read.is_duplicate == False and read.mapping_quality >= 0]
bp.local_depths.setdefault(label,[]).append(str(len(reads)))
return breakpoints

def call_breakpoints(clusters, buffer, min_length, min_depth, chrom):
""" identify consensus breakpoints from list of clusters """
# N.B. all breakpoints in a cluster must be from same chromosome!
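Note on the retained depth logic: `add_local_depth` opens each alignment file by type (CRAM requires the reference for decoding), collects read spans once per fetched region, and counts a read toward an interval when the two spans overlap. Below is a minimal standalone sketch of that pattern; the function name, file path, and interval format are illustrative only, not the exact SAVANA data structures.

```python
import pysam

def local_depth(aln_path, chrom, intervals, is_cram=False, ref=None):
    """Count reads overlapping each (start, end) interval of one fetched region.

    Simplified sketch of the overlap test in add_local_depth; the path and
    interval format are hypothetical.
    """
    if is_cram:
        aln_file = pysam.AlignmentFile(aln_path, "rc", reference_filename=ref)
    else:
        aln_file = pysam.AlignmentFile(aln_path, "rb")
    region_start = max(min(s for s, _ in intervals) - 1, 0)
    region_end = max(e for _, e in intervals)
    spans = []
    for read in aln_file.fetch(chrom, region_start, region_end):
        if read.mapping_quality == 0 or read.is_duplicate:
            continue  # skip duplicates and MAPQ-0 reads, as in the diff above
        spans.append((read.reference_start, read.reference_end))
    aln_file.close()
    # a read overlaps an interval when it starts before the interval ends
    # and ends after the interval starts
    return [sum(1 for r_start, r_end in spans
                if start - r_end <= 0 and r_start - end <= 0)
            for start, end in intervals]
```

Called per region with, e.g., `local_depth("tumour.bam", "chr1", [(1000, 1001)])` (hypothetical inputs), this returns one count per interval; the real function additionally keys the counts by interval uid, edge, and file type.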
10 changes: 0 additions & 10 deletions savana/helper.py
@@ -315,16 +315,6 @@ def time_function(desc, checkpoints, time_str, final=False):
print(formatted_time)
return

def get_local_coverage(chrom, start, end, bam_files):
""" given a location, return the local coverage for each bam file in dict """
coverages = {}
for label, bam_file in bam_files.items():
reads = [read for read in bam_file.fetch(chrom, start, end, multiple_iterators=True)]
reads = [read for read in reads if read.is_duplicate == False and read.mapping_quality >= 0]
coverages[label] = len(reads)

return coverages

def check_outdir(args_outdir):
# create output dir if it doesn't exist
outdir = os.path.join(os.getcwd(), args_outdir)
207 changes: 42 additions & 165 deletions savana/run.py
@@ -16,7 +16,7 @@
import pybedtools

import savana.helper as helper
from savana.breakpoints import get_potential_breakpoints, call_breakpoints, add_local_depth, add_local_depth_old
from savana.breakpoints import get_potential_breakpoints, call_breakpoints, add_local_depth
from savana.clusters import cluster_breakpoints, output_clusters

# developer dependencies
@@ -128,131 +128,47 @@ def pool_output_clusters(args, clusters, outdir):
pool_output.close()
pool_output.join()

def pool_add_local_depth_old(threads, breakpoint_dict_chrom, bam_files):
pool_local_depth = Pool(processes=threads)
pool_local_depth_args = []
# convert bam_files into filenames (rather than objects - breaks parallelization)
for label in bam_files.keys():
bam_files[label] = bam_files[label].filename
for chrom, breakpoints in breakpoint_dict_chrom.items():
pool_local_depth_args.append((breakpoints, bam_files))
local_depth_results = pool_local_depth.starmap(add_local_depth_old, pool_local_depth_args)
pool_local_depth.close()
pool_local_depth.join()
breakpoint_by_chrom = {}
for result in local_depth_results:
chrom = result[0].start_chr
if chrom not in breakpoint_by_chrom:
breakpoint_by_chrom[chrom] = result
else:
breakpoint_by_chrom[chrom].extend(result)

return breakpoint_by_chrom

def single_add_local_depth(intervals, bam_filenames):
""" SINGLE THREAD OPTION
given intervals and uids, get the local depth for each interval """
if True:
uid_dp_dict = {}
chrom = intervals[0][0]
start = int(intervals[0][1]) # first start
end = int(intervals[-1][2]) # last end
read_stats = {}
for bam_type, bam_filename in bam_filenames.items():
with pysam.AlignmentFile(bam_filename, "rb") as bam_file:
bam_file = pysam.AlignmentFile(bam_filename, "rb")
read_stats[bam_type] = []
for read in bam_file.fetch(chrom, start, end):
continue
for read in bam_file.fetch(chrom, start, end):
if read.mapping_quality == 0 or read.is_duplicate:
continue
read_stats[bam_type].append([int(read.reference_start), int(read.reference_length)])
bam_file.close()
#print(f'Calculating DP for {len(intervals)} Intervals')
for i in intervals:
uid = i[3]
interval_end = int(i[2])
edge = int(i[4])
for bam_type, reads in read_stats.items():
subtraction = [((interval_end-r[0])+r[1]) for r in reads]
dp = sum(1 for x in subtraction if x >= 0)
del subtraction
if uid not in uid_dp_dict:
uid_dp_dict[uid] = {}
if bam_type not in uid_dp_dict[uid]:
uid_dp_dict[uid][bam_type] = [None, None]
uid_dp_dict[uid][bam_type][edge] = str(dp)
#print(f'Done calculating DP for {len(intervals)} Intervals')

else:
# ALTERNATELY
uid_dp_dict = {}
for bam_type, bam_filename in bam_filenames.items():
bam_file = pysam.AlignmentFile(bam_filename, "rb")
for i in intervals:
i_chrom, i_start, i_end, uid, edge = i
reads = [read for read in bam_file.fetch(i_chrom, int(i_start), int(i_end))]
reads = [read for read in reads if not read.is_duplicate and read.mapping_quality >= 0]
if uid not in uid_dp_dict:
uid_dp_dict[uid] = {}
if bam_type not in uid_dp_dict[uid]:
uid_dp_dict[uid][bam_type] = [None, None]
uid_dp_dict[uid][bam_type][int(edge)] = str(len(reads))
bam_file.close()

return uid_dp_dict

def pool_add_local_depth(threads, sorted_bed, breakpoint_dict_chrom, aln_files, is_cram=False, ref=False):
""" """
from itertools import groupby

# OPTION FOR THREADS = 1
if threads == 1:
intervals_by_chrom = sorted([list(intervals) for _chrom, intervals in groupby(sorted_bed, lambda x: x[0])], key=len)
local_depth_results = []
for label in aln_files.keys():
aln_files[label] = aln_files[label].filename
for chrom_split in intervals_by_chrom:
local_depth_results.append(single_add_local_depth(chrom_split, aln_files))
else:
intervals_by_chrom = sorted([list(intervals) for _chrom, intervals in groupby(sorted_bed, lambda x: x[0])], key=len, reverse=True)
total_length = sum([len(c) for c in intervals_by_chrom])
redistributed_intervals = []
#ideal_binsize = floor(total_length/(threads-2))
#ideal_binsize = floor(total_length/(threads*2))
ideal_binsize = max(floor(total_length/(threads*threads)),1)
for chrom_chunk in intervals_by_chrom:
if len(chrom_chunk) > 2*ideal_binsize:
num_subchunks = floor(len(chrom_chunk)/ideal_binsize)
# split list into equal chunks from https://stackoverflow.com/a/2135920
quotient, remainder = divmod(len(chrom_chunk), num_subchunks)
chunk_split = (chrom_chunk[i*quotient+min(i, remainder):(i+1)*quotient+min(i+1, remainder)] for i in range(num_subchunks))
redistributed_intervals.extend(chunk_split)
else:
# don't bother splitting
redistributed_intervals.append(chrom_chunk)
print(f'Using {ideal_binsize} as binsize, there are {len(redistributed_intervals)} redistributed intervals')
if not redistributed_intervals:
import sys
sys.exit('Issue calculating redistributed_intervals. Check input parameters')
max_bin = max([len(c) for c in redistributed_intervals])
min_bin = min([len(c) for c in redistributed_intervals])
print(f'Max binsize {max_bin}, min binsize {min_bin}')

# calculate max tasksperchild (max num)
max_total_intervals_per_child = max(10000, max_bin+1) # figure this out by experiments
max_tasks = floor(max_total_intervals_per_child/max_bin)
print(f'Setting maxtasksperchild to {max_tasks}')

pool_local_depth = Pool(processes=threads, maxtasksperchild=max_tasks)
pool_local_depth_args = []
# convert bam_files into filenames (rather than objects - breaks parallelization)
for label in aln_files.keys():
aln_files[label] = aln_files[label].filename
for chrom_split in redistributed_intervals:
pool_local_depth_args.append((chrom_split, aln_files, is_cram, ref))
local_depth_results = pool_local_depth.starmap(add_local_depth, pool_local_depth_args)
intervals_by_chrom = sorted([list(intervals) for _chrom, intervals in groupby(sorted_bed, lambda x: x[0])], key=len, reverse=True)
total_length = sum([len(c) for c in intervals_by_chrom])
redistributed_intervals = []
#ideal_binsize = floor(total_length/(threads-2))
#ideal_binsize = floor(total_length/(threads*2))
ideal_binsize = max(floor(total_length/(threads*threads)),1)
for chrom_chunk in intervals_by_chrom:
if len(chrom_chunk) > 2*ideal_binsize:
num_subchunks = floor(len(chrom_chunk)/ideal_binsize)
# split list into equal chunks from https://stackoverflow.com/a/2135920
quotient, remainder = divmod(len(chrom_chunk), num_subchunks)
chunk_split = (chrom_chunk[i*quotient+min(i, remainder):(i+1)*quotient+min(i+1, remainder)] for i in range(num_subchunks))
redistributed_intervals.extend(chunk_split)
else:
# don't bother splitting
redistributed_intervals.append(chrom_chunk)
print(f'Using {ideal_binsize} as binsize, there are {len(redistributed_intervals)} redistributed intervals')
if not redistributed_intervals:
import sys
sys.exit('Issue calculating redistributed_intervals. Check input parameters')
max_bin = max([len(c) for c in redistributed_intervals])
min_bin = min([len(c) for c in redistributed_intervals])
print(f'Max binsize {max_bin}, min binsize {min_bin}')

# calculate max tasksperchild (max num)
max_total_intervals_per_child = max(10000, max_bin+1) # figure this out by experiments
max_tasks = floor(max_total_intervals_per_child/max_bin)
print(f'Setting maxtasksperchild to {max_tasks}')

pool_local_depth = Pool(processes=threads, maxtasksperchild=max_tasks)
pool_local_depth_args = []
# convert aln_files into filenames (rather than objects - breaks parallelization)
for label in aln_files.keys():
aln_files[label] = aln_files[label].filename
for chrom_split in redistributed_intervals:
pool_local_depth_args.append((chrom_split, aln_files, is_cram, ref))
local_depth_results = pool_local_depth.starmap(add_local_depth, pool_local_depth_args)

uid_dp_dict = {}
"""
@@ -265,11 +181,11 @@ def pool_add_local_depth(threads, sorted_bed, breakpoint_dict_chrom, aln_files,
for uid, counts in result.items():
if uid not in uid_dp_dict:
uid_dp_dict[uid] = {}
for bam_file, values in counts.items():
if bam_file not in uid_dp_dict[uid]:
uid_dp_dict[uid][bam_file] = [None, None]
for aln_file, values in counts.items():
if aln_file not in uid_dp_dict[uid]:
uid_dp_dict[uid][aln_file] = [None, None]
for i, dp in enumerate(values):
uid_dp_dict[uid][bam_file][i] = dp if dp else uid_dp_dict[uid][bam_file][i]
uid_dp_dict[uid][aln_file][i] = dp if dp else uid_dp_dict[uid][aln_file][i]

for breakpoints in breakpoint_dict_chrom.values():
for bp in breakpoints:
@@ -308,45 +224,6 @@ def pool_call_breakpoints(threads, buffer, length, depth, clusters, debug):

return breakpoint_dict_chrom, pruned_clusters

def profiling_experiment(intervals, bam_filenames):
""" to identify and characterise potential memory leak in pysam """
import gc
import sys
chrom = intervals[0][0]
start = int(intervals[0][1]) # first start
end = int(intervals[-1][2]) # last end
# INVENTORY OF OBJECTS BEFORE
objects_before = muppy.get_objects()
print(f'Num. objects before: {len(objects_before)}')
summary_before = summary.summarize(objects_before)
objgraph.show_growth(limit=3)
for bam_type, bam_filename in bam_filenames.items():
with pysam.AlignmentFile(bam_filename, "rb") as bam_file:
bam_file = pysam.AlignmentFile(bam_filename, "rb")
# ITERATE THROUGH FETCHED READS
count_reads = 0
for read in bam_file.fetch(chrom, start, end):
count_reads+=1
continue
#del read
print(f'Num reads: {count_reads}')
#del bam_file
gc.collect()
objgraph.show_growth(limit=3)
# INVENTORY OF OBJECTS AFTER
objects_after = muppy.get_objects()
print(f'Num. objects after: {len(objects_after)}')
summary_after = summary.summarize(objects_after)
summary.print_(summary_after)
print('Diff:')
diff = summary.get_diff(summary_before, summary_after)
summary.print_(diff)
pysam_objs = muppy.filter(objects_after, Type=(pysam.libcalignedsegment.AlignedSegment))
print(f'Num Aligned Segment Objects: {len(pysam_objs)}')
for o in pysam_objs:
cb = refbrowser.ConsoleBrowser(o, maxdepth=2, str_func=lambda x: str(type(x)))
cb.print_tree()

def spawn_processes(args, aln_files, checkpoints, time_str, outdir):
""" run main algorithm steps in parallel processes """
print(f'Using multiprocessing with {args.threads} threads\n')
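Note on the multiprocessing changes above: `pool_add_local_depth` now always redistributes the per-chromosome interval lists into roughly `threads * threads` bins, so one large chromosome does not become a single worker's entire task, and it caps `maxtasksperchild` so pool workers are periodically recycled. A self-contained sketch of the redistribution arithmetic (same quotient/remainder split as the diff, minus the pooling and I/O):

```python
from math import floor

def redistribute(chunks_by_chrom, threads):
    """Split per-chromosome interval lists into near-equal bins.

    Sketch of the redistribution retained in pool_add_local_depth: the target
    bin size is total/(threads*threads); any chromosome chunk larger than
    twice that is cut into near-equal slices.
    """
    total = sum(len(c) for c in chunks_by_chrom)
    ideal_binsize = max(floor(total / (threads * threads)), 1)
    redistributed = []
    for chunk in sorted(chunks_by_chrom, key=len, reverse=True):
        if len(chunk) > 2 * ideal_binsize:
            num_subchunks = floor(len(chunk) / ideal_binsize)
            # split the list into near-equal slices (quotient/remainder trick)
            quotient, remainder = divmod(len(chunk), num_subchunks)
            redistributed.extend(
                chunk[i * quotient + min(i, remainder):(i + 1) * quotient + min(i + 1, remainder)]
                for i in range(num_subchunks))
        else:
            redistributed.append(chunk)
    return redistributed, ideal_binsize
```

With the bins in hand, the pool is built as shown in the diff: `maxtasksperchild = floor(max(10000, max_bin + 1) / max_bin)`, where `max_bin` is the largest bin, so each child processes a bounded number of interval batches before being replaced.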
1 change: 0 additions & 1 deletion savana/savana.py
@@ -183,7 +183,6 @@ def savana_main(args):
args.stats = os.path.join(args.outdir,f'{args.sample}.evaluation.stats')
savana_evaluate(args)


def main():
""" main function for SAVANA - collects command line arguments and executes algorithm """

