Skip to content

Commit

Permalink
Optimised NonReferenceContigAssembler.graphByKmerNode
Browse files Browse the repository at this point in the history
High coverage RP kmers had different quals in each position meaning a O(n^2) array traversal.
Replaced SortedSet using to reduce cost O(nlogn)
Optimisation for n=1 to avoid creating a SortedSet at all (directly storing the KmerNode in the top-level Long2ObjectMap lookup)
  • Loading branch information
Daniel Cameron committed Mar 17, 2021
1 parent 3010c03 commit a34993a
Show file tree
Hide file tree
Showing 4 changed files with 254 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import au.edu.wehi.idsv.bed.IntervalBed;
import au.edu.wehi.idsv.debruijn.DeBruijnGraphBase;
import au.edu.wehi.idsv.debruijn.KmerEncodingHelper;
import au.edu.wehi.idsv.debruijn.positional.optimiseddatastructures.KmerNodeByLastKmerIntervalLookup;
import au.edu.wehi.idsv.graph.ScalingHelper;
import au.edu.wehi.idsv.model.Models;
import au.edu.wehi.idsv.util.FilenameUtil;
Expand Down Expand Up @@ -69,9 +70,7 @@ public class NonReferenceContigAssembler implements Iterator<SAMRecord> {
* expensive approach overall
*/
private static final boolean SIMPLIFY_AFTER_REMOVAL = false;
// TODO: OPT: don't use ArrayList<>() as child structure
// sort by end position so we can do fast overlap calculations
private Long2ObjectMap<Collection<KmerPathNodeKmerNode>> graphByKmerNode = new Long2ObjectOpenHashMap<Collection<KmerPathNodeKmerNode>>();
private KmerNodeByLastKmerIntervalLookup<KmerPathNodeKmerNode> graphByKmerNode = new KmerNodeByLastKmerIntervalLookup<>();
private TreeSet<KmerPathNode> graphByPosition = new TreeSet<KmerPathNode>(KmerNodeUtil.ByFirstStartKmer); // TODO: OPT: replace data structure
private SortedSet<KmerPathNode> nonReferenceGraphByPosition = new TreeSet<KmerPathNode>(KmerNodeUtil.ByFirstStartKmer); // TODO: OPT: replace data structure
private final EvidenceTracker evidenceTracker;
Expand Down Expand Up @@ -857,15 +856,8 @@ private void updateRemovalList(Map<KmerPathNode, List<List<KmerNode>>> toRemove,
}
}
private void updateRemovalList(Map<KmerPathNode, List<List<KmerNode>>> toRemove, KmerSupportNode support) {
Collection<KmerPathNodeKmerNode> kpnknList = graphByKmerNode.get(support.lastKmer());
if (kpnknList != null) {
// TODO: secondary sort order on graphByKmerNode so we can subset this iterator to only
// the overlapping nodes
for (KmerPathNodeKmerNode n : kpnknList) {
if (IntervalUtil.overlapsClosed(support.lastStart(), support.lastEnd(), n.lastStart(), n.lastEnd())) {
updateRemovalList(toRemove, n, support);
}
}
for (KmerPathNodeKmerNode n : graphByKmerNode.getOverlapping(support.lastKmer(), support.lastStart(), support.lastEnd())) {
updateRemovalList(toRemove, n, support);
}
}
private void updateRemovalList(Map<KmerPathNode, List<List<KmerNode>>> toRemove, KmerPathNodeKmerNode node, KmerSupportNode support) {
Expand Down Expand Up @@ -934,20 +926,10 @@ private void removeFromGraph(KmerPathNode node, boolean includeMemoizationRemova
}
}
private void addToGraph(KmerPathNodeKmerNode node) {
Collection<KmerPathNodeKmerNode> list = graphByKmerNode.get(node.firstKmer());
if (list == null) {
list = new ArrayList<>();
graphByKmerNode.put(node.firstKmer(), list);
}
list.add(node);
graphByKmerNode.add(node);
}
private void removeFromGraph(KmerPathNodeKmerNode node) {
Collection<KmerPathNodeKmerNode> list = graphByKmerNode.get(node.firstKmer());
if (list == null) return;
list.remove(node);
if (list.size() == 0) {
graphByKmerNode.remove(node.firstKmer());
}
graphByKmerNode.remove(node);
}

/**
Expand Down Expand Up @@ -1022,7 +1004,7 @@ private Range<Integer> readPairEvidence(KmerEvidence e) {
return bounds;
}
/**
* Determins where offset in the given evidence is included in the assembly.
* Determines where offset in the given evidence is included in the assembly.
* Read pairs can overlap multiple times if the sequence kmer is repeated.
* Since we don't actually know which position we placed the read in, we'll return them all.
*/
Expand Down Expand Up @@ -1076,7 +1058,7 @@ private void addToLookup(int offset, long kmer, int start, int end, Long2ObjectO
}
}
public boolean sanityCheck() {
graphByKmerNode.long2ObjectEntrySet().stream().flatMap(e -> e.getValue().stream()).forEach(kn -> {
graphByKmerNode.stream().forEach(kn -> {
assert(kn.node().isValid());
assert(graphByPosition.contains(kn.node()));
});
Expand Down Expand Up @@ -1219,7 +1201,7 @@ public int tracking_activeNodes() {
return graphByPosition.size();
}
public int tracking_maxKmerActiveNodeCount() {
return graphByKmerNode.values().stream().mapToInt(x -> x.size()).max().orElse(0);
return (int)graphByKmerNode.stream().count();
}
public long tracking_underlyingConsumed() {
return consumed;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package au.edu.wehi.idsv.debruijn.positional.optimiseddatastructures;

import au.edu.wehi.idsv.debruijn.KmerEncodingHelper;
import au.edu.wehi.idsv.debruijn.positional.KmerNode;
import au.edu.wehi.idsv.util.IntervalUtil;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import it.unimi.dsi.fastutil.ints.Int2ObjectRBTreeMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectSortedMap;
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* Kmer interval lookup in which no records overlap
*/
public abstract class KmerIntervalLookup<T> {
/**
* Lookup of node starts. Secondary key is end()
* Values are the KmerNode itself (when there is only 1) or a Int2ObjectSortedMap
*/
private final Long2ObjectOpenHashMap<Object> kmerLookup = new Long2ObjectOpenHashMap<>();
private int size = 0;
protected abstract int getStart(T node);
protected abstract int getEnd(T node);
protected abstract long getKmer(T node);
public void add(T node) {
long kmer = getKmer(node);
Object x = kmerLookup.get(kmer);
Int2ObjectSortedMap<T> positionLookup;
if (x == null) {
kmerLookup.put(kmer, node);
return;
} else if (x instanceof Int2ObjectSortedMap) {
positionLookup = (Int2ObjectSortedMap<T>) x;
} else {
T existing = (T)x;
positionLookup = new Int2ObjectRBTreeMap<>();
kmerLookup.put(kmer, positionLookup);
positionLookup.put(getEnd(existing), existing);
}
positionLookup.put(getEnd(node), node);
}
public void remove(T node) {
long kmer = getKmer(node);
Object x = kmerLookup.get(kmer);
if (x instanceof Int2ObjectSortedMap) {
Int2ObjectSortedMap<T> positionLookup = (Int2ObjectSortedMap<T>) x;
T found = positionLookup.remove(getEnd(node));
assert (found != null);
if (positionLookup.isEmpty()) {
kmerLookup.remove(kmer);
}
} else {
Object found = kmerLookup.remove(kmer);
assert(found != null);
}
}
/**
* Gets the KmerNode that overlaps exactly
* @param kmer
* @param start
* @param end
* @return
*/
public T get(long kmer, int start, int end) {
Object x = kmerLookup.get(kmer);
T node = null;
if (x instanceof Int2ObjectSortedMap) {
Int2ObjectSortedMap<T> positionLookup = (Int2ObjectSortedMap<T>) x;
node = positionLookup.get(end);
} else {
node = (T)x;
}
if (node != null && (getStart(node) != start || getEnd(node) != end)) {
// doesn't overlap exactly
return null;
}
return node;
}
public List<T> getOverlapping(long kmer, int start, int end) {
Object x = kmerLookup.get(kmer);
T node = null;
if (x == null) {
return Collections.EMPTY_LIST;
} if (x instanceof Int2ObjectSortedMap) {
Int2ObjectSortedMap<T> positionLookup = (Int2ObjectSortedMap<T>) x;
positionLookup = positionLookup.tailMap(start);
ArrayList result = new ArrayList();
Iterator<T> it = positionLookup.values().stream().iterator();
while (it.hasNext()) {
node = it.next();
if (IntervalUtil.overlapsClosed(start, end, getStart(node), getEnd(node))) {
result.add(node);
} else {
break;
}
}
return result;
} else {
node = (T)x;
if (IntervalUtil.overlapsClosed(start, end, getStart(node), getEnd(node))) {
return ImmutableList.of(node);
}
return Collections.EMPTY_LIST;
}
}
public Stream<T> stream() {
return kmerLookup.values().stream().flatMap(o -> stream(o));
}
private Stream<T> stream(Object x) {
if (x == null) return Stream.empty();
if (x instanceof Int2ObjectSortedMap) return ((Int2ObjectSortedMap<T>)x).values().stream();
return Stream.of((T)x);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package au.edu.wehi.idsv.debruijn.positional.optimiseddatastructures;

import au.edu.wehi.idsv.debruijn.positional.KmerNode;

public class KmerNodeByLastKmerIntervalLookup<T extends KmerNode> extends KmerIntervalLookup<T> {
@Override
protected int getStart(T node) { return node.lastStart(); }

@Override
protected int getEnd(T node) {
return node.lastEnd();
}

@Override
protected long getKmer(T node) {
return node.lastKmer();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package au.edu.wehi.idsv.debruijn.positional.optimiseddatastructures;

import au.edu.wehi.idsv.TestHelper;
import au.edu.wehi.idsv.debruijn.KmerEncodingHelper;
import au.edu.wehi.idsv.debruijn.positional.ImmutableKmerNode;
import au.edu.wehi.idsv.debruijn.positional.KmerNode;
import au.edu.wehi.idsv.util.IntervalUtil;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.junit.Assert;
import org.junit.Test;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

public class KmerNodeByFirstKmerIntervalLookupTest extends TestHelper {
private static final int k = 4;
private KmerNode kn(String kmer, int start, int end) {
return kn(KmerEncodingHelper.picardBaseToEncoded(k, kmer.getBytes()), start, end);
}
private KmerNode kn(long kmer, int start, int end) {
return new ImmutableKmerNode(kmer, start, end, false, 1);
}
@Test
public void get_should_match_overlap_logic() {
for (List<KmerNode> kns : ImmutableList.of(
ImmutableList.of(
kn(0, 1, 2),
kn(0, 4, 5),
kn(0, 6, 6),
kn(0, 7, 7),
kn(0, 9, 9)),
ImmutableList.of(
kn(0, 1, 3),
kn(0, 7, 7)),
ImmutableList.of(
kn(0, 1, 2)),
ImmutableList.of(
kn(0, 1, 2),
kn(1, 1, 2))
)) {
KmerNodeByLastKmerIntervalLookup<KmerNode> lookup = new KmerNodeByLastKmerIntervalLookup<>();
kns.stream().forEach(n -> lookup.add(n));
validate_against_direct_comparison(lookup, kns);
}
}

private void validate_against_direct_comparison(KmerNodeByLastKmerIntervalLookup<KmerNode> lookup, List<KmerNode> kns) {
for (long kmer : new long[] { 0, 1, 2}) {
for (int i = -1; i < 11; i++) {
for (int j = i; j < 12; j++) {
int start = i;
int end = j;
List<KmerNode> expected = kns.stream()
.filter(n -> n.firstKmer() == kmer && IntervalUtil.overlapsClosed(start, end, n.firstStart(), n.firstEnd()))
.collect(Collectors.toList());
Assert.assertEquals(expected, lookup.getOverlapping(kmer, i, j));
KmerNode exactMatch = kns.stream()
.filter(n -> n.firstKmer() == kmer && n.firstStart() == start && n.firstEnd() == end)
.findFirst().orElse(null);
Assert.assertEquals(exactMatch, lookup.get(kmer, i, j));
}
}
}
}

@Test
public void remove_should_match_overlap_logic() {
for (List<KmerNode> full : ImmutableList.of(
ImmutableList.of(
kn(0, 1, 2),
kn(0, 4, 5),
kn(0, 6, 6),
kn(0, 7, 7),
kn(0, 9, 9)),
ImmutableList.of(
kn(0, 1, 3),
kn(0, 7, 7)),
ImmutableList.of(
kn(0, 1, 2)),
ImmutableList.of(
kn(0, 1, 2),
kn(1, 1, 2))
)) {
for (int offset = 0; offset < full.size(); offset++) {
ArrayList<KmerNode> kns = Lists.newArrayList(full);
KmerNodeByLastKmerIntervalLookup<KmerNode> lookup = new KmerNodeByLastKmerIntervalLookup<>();
full.stream().forEach(n -> lookup.add(n));
// remove each element
KmerNode removed = kns.remove(offset);
lookup.remove(removed);
validate_against_direct_comparison(lookup, kns);
for (int offset2 = 0; offset2 < kns.size(); offset2++) {
KmerNodeByLastKmerIntervalLookup lookup2 = new KmerNodeByLastKmerIntervalLookup<>();
full.stream().forEach(n -> lookup2.add(n));
ArrayList<KmerNode> kns2 = Lists.newArrayList(kns);
KmerNode removed2 = kns2.remove(offset2);
lookup2.remove(removed);
lookup2.remove(removed2);
validate_against_direct_comparison(lookup2, kns2);
}
}
}
}
}

0 comments on commit a34993a

Please sign in to comment.