Skip to content

Commit

Permalink
ALS-5819: Implement intersection. Some refactoring and unit testing too
Browse files Browse the repository at this point in the history
  • Loading branch information
ramari16 committed Feb 26, 2024
1 parent c9a7579 commit bdfb093
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -152,28 +152,6 @@ public static BigInteger emptyBitmask(int length) {
return emptyBitmask;
}

/**
 * Joins two padded bitmasks into a single padded bitmask.
 * <p>
 * Both masks are assumed to be padded with '11' on each end (to prevent overflow issues).
 * The two trailing padding digits of {@code mask1} and the two leading padding digits of
 * {@code mask2} are dropped, and the remaining binary digits are concatenated with
 * {@code mask1} in the most-significant position.
 *
 * @param mask1 the mask that supplies the high-order digits; may be null
 * @param mask1Length length passed to {@code emptyBitmask} when {@code mask1} is null
 * @param mask2 the mask that supplies the low-order digits; may be null
 * @param mask2length length passed to {@code emptyBitmask} when {@code mask2} is null
 * @return the concatenated mask, or null when both masks are null
 */
public BigInteger appendMask(BigInteger mask1, int mask1Length, BigInteger mask2, int mask2length) {
    if (mask1 == null && mask2 == null) {
        return null;
    }
    // todo: unit test the substitution of a missing mask
    BigInteger highMask = (mask1 != null) ? mask1 : emptyBitmask(mask1Length);
    BigInteger lowMask = (mask2 != null) ? mask2 : emptyBitmask(mask2length);

    String highDigits = highMask.toString(2);
    String lowDigits = lowMask.toString(2);

    // Strip the inner padding so only the outermost '11' guards survive the join.
    String highWithoutTail = highDigits.substring(0, highDigits.length() - 2);
    String lowWithoutHead = lowDigits.substring(2);
    return new BigInteger(highWithoutTail.concat(lowWithoutHead), 2);
}

public VariableVariantMasks append(VariableVariantMasks variantMasks) {
VariableVariantMasks appendedMasks = new VariableVariantMasks();
appendedMasks.homozygousMask = appendMask(this.homozygousMask, variantMasks.homozygousMask, this.length, variantMasks.length);
Expand All @@ -193,7 +171,15 @@ public static VariantMask appendMask(VariantMask variantMask1, VariantMask varia
throw new RuntimeException("Unknown VariantMask implementation");
}
}
// todo: bitmask
else if (variantMask1 instanceof VariantMaskBitmaskImpl) {
if (variantMask2 instanceof VariantMaskSparseImpl) {
return append((VariantMaskBitmaskImpl) variantMask1, (VariantMaskSparseImpl) variantMask2, length1, length2);
} else if (variantMask2 instanceof VariantMaskBitmaskImpl) {
return append((VariantMaskBitmaskImpl) variantMask1, (VariantMaskBitmaskImpl) variantMask2, length1, length2);
} else {
throw new RuntimeException("Unknown VariantMask implementation");
}
}
else {
throw new RuntimeException("Unknown VariantMask implementation");
}
Expand All @@ -210,9 +196,32 @@ private static VariantMask append(VariantMaskSparseImpl variantMask1, VariantMas
binaryMask1.substring(2);
return new VariantMaskBitmaskImpl(new BigInteger(appendedString, 2));
}

/**
 * Appends a sparse mask to a bitmask and returns the result as a bitmask.
 * <p>
 * The sparse mask is first expanded into a bitmask via {@code emptyBitmask(length2)}
 * (presumably a '11'-padded mask with no patient bits set; confirm against its definition).
 * Patient indexes are shifted by 2 so they land above the two low-order padding bits, in line
 * with the other mask operations in this code base. The binary strings are then joined:
 * {@code variantMask2} supplies the high-order digits and {@code variantMask1} the low-order
 * digits, with the inner padding of each removed.
 *
 * @param variantMask1 bitmask whose patients keep the low-order positions
 * @param variantMask2 sparse mask whose patients are placed above those of {@code variantMask1}
 * @param length1 number of patients covered by {@code variantMask1} (unused here)
 * @param length2 number of patients covered by {@code variantMask2}
 * @return a bitmask containing both inputs
 */
private static VariantMask append(VariantMaskBitmaskImpl variantMask1, VariantMaskSparseImpl variantMask2, int length1, int length2) {
    String binaryMask1 = variantMask1.bitmask.toString(2);

    BigInteger mask2 = emptyBitmask(length2);
    for (Integer patientId : variantMask2.patientIndexes) {
        // Skip the two low-order padding bits; without the offset the patient lands on the padding.
        mask2 = mask2.setBit(patientId + 2);
    }
    String binaryMask2 = mask2.toString(2);

    // Drop mask2's trailing padding (bounded by mask2's own length) and mask1's leading padding.
    String appendedString = binaryMask2.substring(0, binaryMask2.length() - 2) +
            binaryMask1.substring(2);
    return new VariantMaskBitmaskImpl(new BigInteger(appendedString, 2));
}

/**
 * Appends one bitmask to another and returns the combined bitmask.
 * <p>
 * Both masks are treated as binary strings padded with '11' on each end. The trailing padding
 * of {@code variantMask2} and the leading padding of {@code variantMask1} are removed and the
 * digits are joined, so {@code variantMask2} occupies the high-order positions and
 * {@code variantMask1} the low-order positions.
 *
 * @param variantMask1 bitmask whose patients keep the low-order positions
 * @param variantMask2 bitmask whose patients are placed above those of {@code variantMask1}
 * @param length1 number of patients covered by {@code variantMask1} (unused here)
 * @param length2 number of patients covered by {@code variantMask2} (unused here)
 * @return a bitmask containing both inputs
 */
private static VariantMask append(VariantMaskBitmaskImpl variantMask1, VariantMaskBitmaskImpl variantMask2, int length1, int length2) {
    String binaryMask1 = variantMask1.bitmask.toString(2);
    String binaryMask2 = variantMask2.bitmask.toString(2);

    // The cut must be bounded by mask2's own length; using mask1's length truncates or
    // throws whenever the two masks differ in size.
    String appendedString = binaryMask2.substring(0, binaryMask2.length() - 2) +
            binaryMask1.substring(2);
    return new VariantMaskBitmaskImpl(new BigInteger(appendedString, 2));
}

private static VariantMask append(VariantMaskSparseImpl variantMask1, VariantMaskSparseImpl variantMask2, int length1, int length2) {
if (variantMask1.patientIndexes.size() + variantMask2.patientIndexes.size() > SPARSE_VARIANT_THRESHOLD) {
// todo: performance test this vs byte array
BigInteger mask = emptyBitmask(length1 + length2);
for (Integer patientId : variantMask1.patientIndexes) {
mask = mask.setBit(patientId + 2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import java.math.BigInteger;
import java.util.Set;
import java.util.stream.Collectors;

@JsonTypeInfo(
use = JsonTypeInfo.Id.NAME,
Expand All @@ -29,12 +30,4 @@ public interface VariantMask {
/**
 * Creates a mask with no patients in it.
 *
 * @return a new {@link VariantMaskSparseImpl} built from an empty set
 */
static VariantMask emptyInstance() {
return new VariantMaskSparseImpl(Set.of());
}

/**
 * Computes the union of a sparse mask and a bitmask.
 * <p>
 * Starts from the bitmask's bits and sets one additional bit per sparse patient index. Each
 * index is shifted by 2 before the bit is set.
 *
 * @param variantMaskSparse the sparse mask whose patient indexes are added
 * @param variantMaskBitmask the bitmask used as the starting point
 * @return a new bitmask-backed mask holding the combined bits
 */
static VariantMask union(VariantMaskSparseImpl variantMaskSparse, VariantMaskBitmaskImpl variantMaskBitmask) {
    BigInteger combinedBits = variantMaskBitmask.bitmask;
    for (Integer sparseIndex : variantMaskSparse.patientIndexes) {
        int shiftedPosition = sparseIndex + 2;
        combinedBits = combinedBits.setBit(shiftedPosition);
    }
    return new VariantMaskBitmaskImpl(combinedBits);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,29 @@ public VariantMaskBitmaskImpl(@JsonProperty("mask") BigInteger bitmask) {

/**
 * Intersects this bitmask with another mask.
 * <p>
 * Bitmask arguments are handled by the bitmask-specific overload. Sparse arguments are
 * handled by delegating to the sparse mask's own {@code intersection}, passing this mask.
 *
 * @param variantMask the mask to intersect with
 * @return the intersection of the two masks
 * @throws RuntimeException if {@code variantMask} is neither a bitmask nor a sparse mask
 */
@Override
public VariantMask intersection(VariantMask variantMask) {
    if (variantMask instanceof VariantMaskBitmaskImpl) {
        return intersection((VariantMaskBitmaskImpl) variantMask);
    } else if (variantMask instanceof VariantMaskSparseImpl) {
        return variantMask.intersection(this);
    } else {
        throw new RuntimeException("Unknown VariantMask implementation");
    }
}

/**
 * Computes the union of this bitmask with another mask.
 * <p>
 * Bitmask arguments are handled by the bitmask-specific overload. Sparse arguments are
 * handled by delegating to the sparse mask's own {@code union}, passing this mask.
 *
 * @param variantMask the mask to combine with
 * @return the union of the two masks
 * @throws RuntimeException if {@code variantMask} is neither a bitmask nor a sparse mask
 */
@Override
public VariantMask union(VariantMask variantMask) {
    if (variantMask instanceof VariantMaskBitmaskImpl) {
        return union((VariantMaskBitmaskImpl) variantMask);
    } else if (variantMask instanceof VariantMaskSparseImpl) {
        return variantMask.union(this);
    } else {
        throw new RuntimeException("Unknown VariantMask implementation");
    }
}

/**
 * Reports whether the bit for the given position is set in this mask.
 * <p>
 * The position is shifted by 2 before the underlying {@link BigInteger} is consulted, so that
 * position 0 refers to the first bit above the two low-order padding bits.
 *
 * @param bit zero-based position to test, not counting the padding
 * @return true if the corresponding bit is set
 */
@Override
public boolean testBit(int bit) {
    return bitmask.testBit(bit + 2);
}

@Override
Expand All @@ -51,6 +57,11 @@ public int bitCount() {
}

/**
 * Computes the union of this bitmask with another bitmask via a bitwise OR.
 *
 * @param variantMaskBitmask the other bitmask
 * @return a new bitmask-backed mask holding the OR of both bitmasks
 */
private VariantMask union(VariantMaskBitmaskImpl variantMaskBitmask) {
    BigInteger eitherSet = variantMaskBitmask.bitmask.or(this.bitmask);
    return new VariantMaskBitmaskImpl(eitherSet);
}
/**
 * Intersects this bitmask with another bitmask via a bitwise AND.
 *
 * @param variantMaskBitmask the other bitmask
 * @return a new bitmask-backed mask holding the AND of both bitmasks
 */
private VariantMask intersection(VariantMaskBitmaskImpl variantMaskBitmask) {
    // If the result is ever stored anywhere, returning a sparse mask here may be worth
    // considering instead of a bitmask.
    BigInteger bothSet = variantMaskBitmask.bitmask.and(this.bitmask);
    return new VariantMaskBitmaskImpl(bothSet);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import com.fasterxml.jackson.annotation.JsonProperty;

import java.math.BigInteger;
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

public class VariantMaskSparseImpl implements VariantMask {

Expand All @@ -21,13 +23,15 @@ public Set<Integer> getPatientIndexes() {

/**
 * Intersects this sparse mask with another mask.
 * <p>
 * Keeps each patient index of this mask for which the other mask's {@code testBit} returns
 * true, so the result is always sparse regardless of the other mask's implementation.
 *
 * @param variantMask the mask to intersect with
 * @return a new sparse mask containing only the patient indexes present in both masks
 */
@Override
public VariantMask intersection(VariantMask variantMask) {
    return new VariantMaskSparseImpl(this.patientIndexes.stream()
            .filter(variantMask::testBit)
            .collect(Collectors.toSet()));
}

@Override
public VariantMask union(VariantMask variantMask) {
if (variantMask instanceof VariantMaskBitmaskImpl) {
return VariantMask.union(this, (VariantMaskBitmaskImpl) variantMask);
return union((VariantMaskBitmaskImpl) variantMask);
} else if (variantMask instanceof VariantMaskSparseImpl) {
return union((VariantMaskSparseImpl) variantMask);
} else {
Expand All @@ -51,6 +55,14 @@ private VariantMask union(VariantMaskSparseImpl variantMask) {
return new VariantMaskSparseImpl(union);
}

/**
 * Computes the union of this sparse mask with a bitmask.
 * <p>
 * Starts from the bitmask's bits and sets one additional bit per patient index held by this
 * mask. Each index is shifted by 2 before the bit is set.
 *
 * @param variantMaskBitmask the bitmask used as the starting point
 * @return a new bitmask-backed mask holding the combined bits
 */
private VariantMask union(VariantMaskBitmaskImpl variantMaskBitmask) {
    BigInteger combinedBits = variantMaskBitmask.bitmask;
    for (Integer sparseIndex : this.patientIndexes) {
        int shiftedPosition = sparseIndex + 2;
        combinedBits = combinedBits.setBit(shiftedPosition);
    }
    return new VariantMaskBitmaskImpl(combinedBits);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package edu.harvard.hms.dbmi.avillach.hpds.data.genotype;

import org.junit.jupiter.api.Test;

import java.math.BigInteger;
import java.util.Set;

import static org.junit.jupiter.api.Assertions.assertEquals;

/**
 * Exercises {@code intersection} and {@code union} for every pairing of the bitmask and sparse
 * {@link VariantMask} implementations.
 * <p>
 * Bitmasks are written as binary strings with '11' padding on both ends, so patient 0 is the
 * third digit from the right.
 */
public class VariantMaskTest {

    /** Builds a bitmask-backed mask from a string of binary digits (padding included). */
    private static VariantMask bitmask(String binaryDigits) {
        return new VariantMaskBitmaskImpl(new BigInteger(binaryDigits, 2));
    }

    /** Builds a sparse mask from the given patient indexes. */
    private static VariantMask sparse(Integer... patientIndexes) {
        return new VariantMaskSparseImpl(Set.of(patientIndexes));
    }

    @Test
    public void intersection_bitmaskVsBitmask() {
        VariantMask left = bitmask("111001100011");
        VariantMask right = bitmask("111010010011");

        VariantMask actual = left.intersection(right);

        assertEquals(bitmask("111000000011"), actual);
    }

    @Test
    public void intersection_bitmaskVsSparse() {
        // with the '11' padding stripped from both ends, this bitmask holds patients 0, 2, 3, 7
        VariantMask left = bitmask("111000110111");
        VariantMask right = sparse(0, 3, 6);

        VariantMask actual = left.intersection(right);

        assertEquals(sparse(0, 3), actual);
    }

    @Test
    public void intersection_sparseVsBitmask() {
        VariantMask left = sparse(4, 7);
        VariantMask right = bitmask("110111110111");

        VariantMask actual = left.intersection(right);

        assertEquals(sparse(4), actual);
    }

    @Test
    public void intersection_sparseVsSparse() {
        VariantMask left = sparse(0, 2, 4, 6);
        VariantMask right = sparse(0, 1, 3, 5, 7);

        VariantMask actual = left.intersection(right);

        assertEquals(sparse(0), actual);
    }

    @Test
    public void union_bitmaskVsBitmask() {
        VariantMask left = bitmask("111001100011");
        VariantMask right = bitmask("111010010011");

        VariantMask actual = left.union(right);

        assertEquals(bitmask("111011110011"), actual);
    }

    @Test
    public void union_bitmaskVsSparse() {
        // with the '11' padding stripped from both ends, this bitmask holds patients 0, 2, 3, 7
        VariantMask left = bitmask("111000110111");
        VariantMask right = sparse(0, 3, 6);

        VariantMask actual = left.union(right);

        assertEquals(bitmask("111100110111"), actual);
    }

    @Test
    public void union_sparseVsBitmask() {
        VariantMask left = sparse(4, 7);
        VariantMask right = bitmask("110111110111");

        VariantMask actual = left.union(right);

        assertEquals(bitmask("111111110111"), actual);
    }

    @Test
    public void union_sparseVsSparse() {
        VariantMask left = sparse(0, 2, 4, 6);
        VariantMask right = sparse(1, 5, 7);

        VariantMask actual = left.union(right);

        assertEquals(sparse(0, 1, 2, 4, 5, 6, 7), actual);
    }
}

0 comments on commit bdfb093

Please sign in to comment.