Skip to content

Commit

Permalink
remap indexes for GenericBastaLikelihoodDelegate; dynamically allocat…
Browse files Browse the repository at this point in the history
…e memories for beast/beagle implementation
  • Loading branch information
yucais committed Nov 7, 2024
1 parent 1a23c4d commit 824a635
Show file tree
Hide file tree
Showing 9 changed files with 251 additions and 34 deletions.
14 changes: 12 additions & 2 deletions src/beagle/basta/BastaJNIImpl.java
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,15 @@ public BastaJNIImpl(int tipCount,
super(tipCount, partialsBufferCount, compactBufferCount, stateCount, patternCount, eigenBufferCount,
matrixBufferCount, categoryCount, scaleBufferCount, resourceList, preferenceFlags, requirementFlags);

allocateCoalescentBuffers(coalescentBufferCount, maxCoalescentIntervalCount, partialsBufferCount,1);
}

@Override
public void allocateCoalescentBuffers(int coalescentBufferCount, int maxCoalescentIntervalCount, int partialsBufferCount, int initial) {
int errCode = BastaJNIWrapper.INSTANCE.allocateCoalescentBuffers(instance, coalescentBufferCount,
maxCoalescentIntervalCount);
maxCoalescentIntervalCount, partialsBufferCount, initial);
if (errCode != 0) {
throw new BeagleException("constructor", errCode);
throw new BeagleException("allocateCoalescentBuffers", errCode);
}
}

Expand Down Expand Up @@ -85,4 +90,9 @@ public void accumulateBastaPartials(int[] operations, int operationCount, int[]
throw new BeagleException("accumulateBastaPartials", errCode);
}
}

public int getInstance() {
return this.instance;
}

}
4 changes: 3 additions & 1 deletion src/beagle/basta/BastaJNIWrapper.java
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ private BastaJNIWrapper() { }

public native int allocateCoalescentBuffers(int instance,
int bufferCount,
int maxCoalescentIntervalCount); // TODO buffers have different sizes
int maxCoalescentIntervalCount,
int partialsBufferCount,
int initial); // TODO buffers have different sizes

public native int getBastaBuffer(int instance,
int index,
Expand Down
2 changes: 2 additions & 0 deletions src/beagle/basta/BeagleBasta.java
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ void accumulateBastaPartials(final int[] operations,
int coalescentIndex,
double[] result);

void allocateCoalescentBuffers(int coalescentBufferCount, int maxCoalescentIntervalCount, int partialsBufferCount, int initial);

void getBastaBuffer(int index, double[] buffer);

void updateBastaPartialsGrad(int[] operations, int operationCount, int[] intervals, int intervalCount, int populationSizeIndex, int coalescentProbabilityIndex);
Expand Down
7 changes: 5 additions & 2 deletions src/dr/evomodel/coalescent/basta/BastaInternalStorage.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,21 @@ public BastaInternalStorage(int maxNumCoalescentIntervals, int treeNodeCount, in
this.sizes = new double[2 * stateCount];
this.decompositions = new EigenDecomposition[1];

resize(getStartingPartialsCount(maxNumCoalescentIntervals, treeNodeCount), maxNumCoalescentIntervals);
resize(3 * treeNodeCount, maxNumCoalescentIntervals, null);
}

static private int getStartingPartialsCount(int maxNumCoalescentIntervals, int treeNodeCount) {
return maxNumCoalescentIntervals * (treeNodeCount + 1); // TODO much too large
}

public void resize(int newNumPartials, int newNumCoalescentIntervals) {
public void resize(int newNumPartials, int newNumCoalescentIntervals, BastaLikelihood likelihood) {

if (newNumPartials > currentNumPartials) {
this.partials = new double[newNumPartials * stateCount];
this.currentNumPartials = newNumPartials;
if (likelihood != null) {
likelihood.setTipData();
}
}

if (newNumCoalescentIntervals > this.currentNumCoalescentIntervals) {
Expand Down
7 changes: 4 additions & 3 deletions src/dr/evomodel/coalescent/basta/BastaLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public BastaLikelihood(String name,

public SubstitutionModel getSubstitutionModel() { return substitutionModel; } // TODO generify for multiple models (e.g. epochs)

private void setTipData() {
public void setTipData() {

int[] data = patternList.getPattern(0);

Expand Down Expand Up @@ -292,7 +292,7 @@ private double calculateLogLikelihood() {

final NodeRef root = tree.getRoot();
double logL = likelihoodDelegate.calculateLikelihood(branchOperations, matrixOperations,
intervalStarts, root.getNumber());
intervalStarts, root.getNumber(), this);

// after traverse all nodes and patterns have been updated --
//so change flags to reflect this.
Expand All @@ -319,7 +319,7 @@ public double[] getGradientLogDensity(StructuredCoalescentLikelihoodGradient wrt
calculateLogLikelihood(); // TODO Only execute if necessary

double[] gradient = likelihoodDelegate.calculateGradient(branchOperations, matrixOperations, intervalStarts,
root.getNumber(), wrt);
root.getNumber(), wrt, this);

return wrt.chainRule(gradient);
}
Expand Down Expand Up @@ -371,6 +371,7 @@ public String getReport() {
"\n partial rate updates = ").append(totalRateUpdateSingleCount).append(
"\n average likelihood time = ").append(totalLikelihoodTime / totalCalculateLikelihoodCount);


return sb.toString();
}

Expand Down
27 changes: 19 additions & 8 deletions src/dr/evomodel/coalescent/basta/BastaLikelihoodDelegate.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public interface BastaLikelihoodDelegate extends ProcessOnCoalescentIntervalDele
double calculateLikelihood(List<BranchIntervalOperation> branchOperations,
List<TransitionMatrixOperation> matrixOperations,
List<Integer> intervalStarts,
int rootNodeNumber);
int rootNodeNumber, BastaLikelihood likelihood);

default void setPartials(int index, double[] partials) {
throw new RuntimeException("Not yet implemented");
Expand All @@ -81,7 +81,14 @@ default void updatePopulationSizes(int index, double[] sizes, boolean flip) {
List<TransitionMatrixOperation> matrixOperation,
List<Integer> intervalStarts,
int rootNodeNumber,
StructuredCoalescentLikelihoodGradient wrt);
StructuredCoalescentLikelihoodGradient wrt, BastaLikelihood likelihood);

void updateStorage(int maxBufferCount,
int treeNodeCount,
BastaLikelihood likelihood);

int getMaxNumberOfCoalescentIntervals();


abstract class AbstractBastaLikelihoodDelegate extends AbstractModel implements BastaLikelihoodDelegate, Citable {

Expand Down Expand Up @@ -109,7 +116,11 @@ public AbstractBastaLikelihoodDelegate(String name,
this.parallelizationScheme = ParallelizationScheme.NONE;
}

private int getMaxNumberOfCoalescentIntervals(Tree tree) {
public int getMaxNumberOfCoalescentIntervals() {
return maxNumCoalescentIntervals;
}

public int getMaxNumberOfCoalescentIntervals(Tree tree) {
BigFastTreeIntervals intervals = new BigFastTreeIntervals((TreeModel) tree); // TODO fix BFTI to take a Tree
int zeroLengthSampling = 0;
for (int i = 0; i < intervals.getIntervalCount(); ++i) {
Expand Down Expand Up @@ -172,7 +183,7 @@ enum Mode {
abstract protected void computeBranchIntervalOperations(List<Integer> intervalStarts,
List<BranchIntervalOperation> branchIntervalOperations,
List<TransitionMatrixOperation> matrixOperations,
Mode mode);
Mode mode, BastaLikelihood likelihood);

abstract protected void computeTransitionProbabilityOperations(List<TransitionMatrixOperation> matrixOperations,
Mode mode);
Expand All @@ -187,7 +198,7 @@ abstract protected void computeCoalescentIntervalReduction(List<Integer> interva
public double calculateLikelihood(List<BranchIntervalOperation> branchOperations,
List<TransitionMatrixOperation> matrixOperation,
List<Integer> intervalStarts,
int rootNodeNumber) {
int rootNodeNumber, BastaLikelihood likelihood) {

if (PRINT_COMMANDS) {
System.err.println("Tree = " + tree);
Expand All @@ -199,7 +210,7 @@ public double calculateLikelihood(List<BranchIntervalOperation> branchOperations
while (!done) {

computeTransitionProbabilityOperations(matrixOperation, Mode.LIKELIHOOD);
computeBranchIntervalOperations(intervalStarts, branchOperations, matrixOperation, Mode.LIKELIHOOD);
computeBranchIntervalOperations(intervalStarts, branchOperations, matrixOperation, Mode.LIKELIHOOD, likelihood);

computeCoalescentIntervalReduction(intervalStarts, branchOperations, logL,
Mode.LIKELIHOOD, null);
Expand All @@ -221,7 +232,7 @@ public double[] calculateGradient(List<BranchIntervalOperation> branchOperations
List<TransitionMatrixOperation> matrixOperations,
List<Integer> intervalStarts,
int rootNodeNumber,
StructuredCoalescentLikelihoodGradient wrt) {
StructuredCoalescentLikelihoodGradient wrt, BastaLikelihood likelihood) {
if (PRINT_COMMANDS) {
System.err.println("Tree = " + tree);
}
Expand All @@ -232,7 +243,7 @@ public double[] calculateGradient(List<BranchIntervalOperation> branchOperations
computeTransitionProbabilityOperations(matrixOperations, Mode.GRADIENT);
}

computeBranchIntervalOperations(intervalStarts, branchOperations, matrixOperations, Mode.GRADIENT);
computeBranchIntervalOperations(intervalStarts, branchOperations, matrixOperations, Mode.GRADIENT, likelihood);

double[] gradient = new double[wrt.getIntermediateGradientDimension()];

Expand Down
137 changes: 128 additions & 9 deletions src/dr/evomodel/coalescent/basta/BeagleBastaLikelihoodDelegate.java
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package dr.evomodel.coalescent.basta;

import beagle.Beagle;
import beagle.BeagleFlag;
import beagle.*;
import beagle.basta.BeagleBasta;
import beagle.basta.BastaFactory;
import dr.evolution.tree.Tree;
Expand All @@ -13,8 +12,10 @@

import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;

import static beagle.basta.BeagleBasta.BASTA_OPERATION_SIZE;
import static dr.evomodel.treedatalikelihood.BeagleFunctionality.parseSystemPropertyIntegerArray;

/**
* @author Marc A. Suchard
Expand All @@ -27,30 +28,134 @@ public class BeagleBastaLikelihoodDelegate extends BastaLikelihoodDelegate.Abstr

private final BufferIndexHelper eigenBufferHelper;
private final OffsetBufferIndexHelper populationSizesBufferHelper;
private static final String RESOURCE_ORDER_PROPERTY = "beagle.resource.order";
private static final String PREFERRED_FLAGS_PROPERTY = "beagle.preferred.flags";
private static final String REQUIRED_FLAGS_PROPERTY = "beagle.required.flags";
int currentPartialsCount;
int currentIntervalsCount;
private static int instanceCount = 0;
private static List<Integer> resourceOrder = null;
private static List<Integer> preferredOrder = null;
private static List<Integer> requiredOrder = null;
private int currentOutputBuffer;
private int maxOutputBuffer;
private boolean updateStorage;


public BeagleBastaLikelihoodDelegate(String name,
Tree tree,
int stateCount,
boolean transpose) {
super(name, tree, stateCount, transpose);

int partialsCount = maxNumCoalescentIntervals * (tree.getNodeCount() + 1); // TODO much too large
int matricesCount = maxNumCoalescentIntervals; // TODO much too small (except for strict-clock)
this.currentPartialsCount = 3 * tree.getNodeCount();
this.currentIntervalsCount = tree.getNodeCount();

int coalescentBufferCount = 5; // E, F, G, H, probabilities
if (resourceOrder == null) {
resourceOrder = parseSystemPropertyIntegerArray(RESOURCE_ORDER_PROPERTY);
}
if (preferredOrder == null) {
preferredOrder = parseSystemPropertyIntegerArray(PREFERRED_FLAGS_PROPERTY);
}
if (requiredOrder == null) {
requiredOrder = parseSystemPropertyIntegerArray(REQUIRED_FLAGS_PROPERTY);
}

long requirementFlags = 0L;
requirementFlags |= BeagleFlag.EIGEN_COMPLEX.getMask();

int[] resourceList = null;
long preferenceFlags = 0;

if (resourceOrder.size() > 0) {
// added the zero on the end so that a CPU is selected if requested resource fails
resourceList = new int[]{resourceOrder.get(instanceCount % resourceOrder.size()), 0};
if (resourceList[0] > 0) {
preferenceFlags |= BeagleFlag.PROCESSOR_GPU.getMask(); // Add preference weight against CPU
}
}

if (preferredOrder.size() > 0) {
preferenceFlags = preferredOrder.get(instanceCount % preferredOrder.size());
}

if (requiredOrder.size() > 0) {
requirementFlags = requiredOrder.get(instanceCount % requiredOrder.size());
}


if (!BeagleFlag.PRECISION_SINGLE.isSet(preferenceFlags)) {
// if single precision not explicitly set then prefer double
preferenceFlags |= BeagleFlag.PRECISION_DOUBLE.getMask();
}

if ((resourceList == null &&
(BeagleFlag.PROCESSOR_GPU.isSet(preferenceFlags) ||
BeagleFlag.FRAMEWORK_CUDA.isSet(preferenceFlags) ||
BeagleFlag.FRAMEWORK_OPENCL.isSet(preferenceFlags)))
||
(resourceList != null && resourceList[0] > 0)) {
// non-CPU implementations don't have SSE so remove default preference for SSE
// when using non-CPU preferences or prioritising non-CPU resource
preferenceFlags &= ~BeagleFlag.VECTOR_SSE.getMask();
preferenceFlags &= ~BeagleFlag.THREADING_CPP.getMask();
}

beagle = BastaFactory.loadBastaInstance(0, coalescentBufferCount, maxNumCoalescentIntervals,
partialsCount, 0, stateCount,
1, 2, matricesCount, 1,
1, null, 0L, requirementFlags);
currentPartialsCount, 0, stateCount,
1, 2, currentIntervalsCount, 1,
1, resourceList, preferenceFlags, requirementFlags);

eigenBufferHelper = new BufferIndexHelper(1, 0);
populationSizesBufferHelper = new OffsetBufferIndexHelper(1, 0, 0);

beagle.setCategoryRates(new double[] { 1.0 });

final Logger logger = Logger.getLogger("dr.evomodel");
InstanceDetails instanceDetails = beagle.getDetails();
ResourceDetails resourceDetails = null;

if (instanceDetails != null) {
resourceDetails = BeagleFactory.getResourceDetails(instanceDetails.getResourceNumber());
if (resourceDetails != null) {
StringBuilder sb = new StringBuilder(" Using BEAGLE BASTA resource ");
sb.append(resourceDetails.getNumber()).append(": ");
sb.append(resourceDetails.getName()).append("\n");
if (resourceDetails.getDescription() != null) {
String[] description = resourceDetails.getDescription().split("\\|");
for (String desc : description) {
if (desc.trim().length() > 0) {
sb.append(" ").append(desc.trim()).append("\n");
}
}
}
sb.append(" with instance flags: ").append(instanceDetails.toString());
logger.info(sb.toString());
} else {
logger.info(" Error retrieving BEAGLE resource for instance: " + instanceDetails.toString());
}
} else {
logger.info(" No external BEAGLE resources available, or resource list/requirements not met, using Java implementation");
}
}


public void resize(int newNumPartials, int newNumCoalescentIntervals) {
updateStorage = false;
if (newNumPartials > currentPartialsCount) {
this.currentPartialsCount = newNumPartials + 1;
updateStorage = true;
}

if (newNumCoalescentIntervals > currentIntervalsCount) {
this.currentIntervalsCount = newNumCoalescentIntervals;
updateStorage = true;
}

if (updateStorage) {
beagle.allocateCoalescentBuffers(5, currentIntervalsCount, currentPartialsCount, 0);
}
}

@Override
Expand All @@ -62,14 +167,14 @@ protected void allocateGradientMemory() {
protected void computeBranchIntervalOperations(List<Integer> intervalStarts,
List<BranchIntervalOperation> branchIntervalOperations,
List<TransitionMatrixOperation> matrixOperations,
Mode mode) {
Mode mode, BastaLikelihood likelihood) {

int[] operations = new int[branchIntervalOperations.size() * BASTA_OPERATION_SIZE]; // TODO instantiate once
int[] intervals = new int[intervalStarts.size()]; // TODO instantiate once
double[] lengths = new double[intervalStarts.size() - 1]; // TODO instantiate once

vectorizeBranchIntervalOperations(intervalStarts, branchIntervalOperations, operations, intervals, lengths);

updateStorage(maxOutputBuffer, maxNumCoalescentIntervals, likelihood);
int populationSizeIndex = populationSizesBufferHelper.getOffsetIndex(0);

if (mode == Mode.LIKELIHOOD) {
Expand Down Expand Up @@ -228,6 +333,15 @@ public void updatePopulationSizes(int index, double[] sizes, boolean flip) {
beagle.setStateFrequencies(populationSizesBufferHelper.getOffsetIndex(0), sizes);
}

@Override
public void updateStorage(int maxBufferCount, int treeNodeCount, BastaLikelihood likelihood) {
int newNumPartials = maxBufferCount + 1;
resize(newNumPartials, maxNumCoalescentIntervals);
if (likelihood != null && updateStorage) {
likelihood.setTipData();
}
}

private void vectorizeTransitionMatrixOperations(List<TransitionMatrixOperation> matrixOperations,
int[] transitionMatrixIndices,
double[] branchLengths) {
Expand Down Expand Up @@ -296,6 +410,11 @@ private void vectorizeBranchIntervalOperations(List<Integer> intervalStarts,
operations[k + 6] = op.accBuffer2;
operations[k + 7] = op.intervalNumber;
}
currentOutputBuffer = operations[k + 6];

if (currentOutputBuffer > maxOutputBuffer) {
maxOutputBuffer = currentOutputBuffer;
}

k += BASTA_OPERATION_SIZE;
}
Expand Down
Loading

0 comments on commit 824a635

Please sign in to comment.