Skip to content

Commit

Permalink
Closes #4075: MergeShuffle
Browse files Browse the repository at this point in the history
  • Loading branch information
ajpotts committed Feb 12, 2025
1 parent 220efdb commit 06edaf9
Show file tree
Hide file tree
Showing 4 changed files with 592 additions and 8 deletions.
3 changes: 2 additions & 1 deletion arkouda/numpy/random/_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ def standard_normal(self, size=None, method="zig"):
self._state += full_size * 2
return create_pdarray(rep_msg)

def shuffle(self, x):
def shuffle(self, x, method="FisherYates"):
"""
Randomly shuffle a pdarray in place.
Expand All @@ -605,6 +605,7 @@ def shuffle(self, x):
"x": x,
"shape": x.shape,
"state": self._state,
"method": method.lower(),
},
)
self._state += x.size
Expand Down
249 changes: 247 additions & 2 deletions src/RandMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -675,8 +675,18 @@ module RandMsg
}

@arkouda.instantiateAndRegister
proc shuffle(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
do return shuffleHelp(cmd, msgArgs, st, array_dtype, array_nd);
proc shuffle(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws{
const method = msgArgs["method"].toScalar(string);
if (method == "mergeshuffle"){
return mergeShuffleHelp(cmd, msgArgs, st, array_dtype, array_nd);
} else if (method == "fisheryates"){
return shuffleHelp(cmd, msgArgs, st, array_dtype, array_nd);
}else{
const errorMsg = "Error: Invalid method for shuffle. Allowed values: fisheryates, mergeshuffle";
return MsgTuple.error(errorMsg);
}

}

proc shuffleHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_nd == 1
Expand Down Expand Up @@ -733,6 +743,241 @@ module RandMsg
return MsgTuple.success();
}

proc mergeShuffleHelp(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab, type array_dtype, param array_nd: int): MsgTuple throws
where array_nd == 1
{
const name = msgArgs["name"],
xName = msgArgs["x"].toScalar(string),
shape = msgArgs["shape"].toScalarTuple(int, array_nd),
state = msgArgs["state"].toScalar(int);

randLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),
"name: %? shape %? dtype: %? state %i".format(name, shape, type2str(array_dtype), state));

var generatorEntry = st[name]: borrowed GeneratorSymEntry(int);
ref rng = generatorEntry.generator;

if state != 1 then rng.skipTo(state-1);
const generatorSeed = (rng.next() * 2**62):int;

const arrEntry = st[xName]: SymEntry(array_dtype, array_nd);
ref myArr = arrEntry.a;
mergeShuffle(myArr, generatorSeed);
return MsgTuple.success();
}

proc mergeShuffle(ref x: [] ?t, generatorSeed: int): int throws {
const numRounds = log2(numLocales) + 1;
const domainLows = getDomainLows(x);
const domainHighs = getDomainHighs(x);

// TODO: Update all the seed calculations to match the final version of the algorithm. Please ignore for now.
var seed = 1;
shuffleLocales(x, generatorSeed);

for m in 0..#numRounds {
const maxLocalesPerPrevChunk = 2**m;
const numNewChunks = (numLocales - 1) / (2 * maxLocalesPerPrevChunk) + 1;

forall i in 0..#numNewChunks with (ref x, ref generatorSeed){
const startLocale = 2 * i * maxLocalesPerPrevChunk;
const endLocale = min(startLocale + maxLocalesPerPrevChunk - 1, numLocales - 1 );
const startLocale2 = min(endLocale + 1, numLocales - 1 );
const endLocale2 = min(startLocale2 + maxLocalesPerPrevChunk - 1, numLocales - 1 );
if( endLocale < startLocale2 ){
const start = domainLows[startLocale];
const size1 = domainHighs[endLocale] - domainLows[startLocale] + 1;
const size2 = domainHighs[endLocale2] - domainLows[startLocale2] + 1;

const taskSeed = seed + i * 2 * numLocales; // TODO: Ignore for now, needs to be updated
merge(x, start, size1, size2, taskSeed);
}
}
seed += numNewChunks * 2 * numLocales;
}
return seed;
}

/*
Shuffles each locale of the array independently.
There should be no communication between locales for this step.
*/
proc shuffleLocales(ref x: [] ?t, const generatorSeed: int): int {
coforall loc in Locales with (ref x, const generatorSeed) do on loc{
var randStreamInt = new randomStream(int, seed=(generatorSeed + here.id));
var seed = randStreamInt.next();

const localLower = x.localSubdomain(loc=here).low;
const localUpper = x.localSubdomain(loc=here).high;
const size = localUpper - localLower + 1;

const maxFisherYatesPower = 6; // Hardcoded
const smallestChunkSize = max(size/(2**maxFisherYatesPower), min(10, size)); // Hardcoded minimum size
const numChunks = (size - 1)/smallestChunkSize + 1;

forall i in 0..#numChunks with (ref x, const seed){
const taskSeed = seed + i;
const low = localLower + i * smallestChunkSize;
const high = min(localUpper, low + smallestChunkSize - 1);
fisherYatesOnLocale(x, low , high, high, true, taskSeed);
}

seed += numChunks;

const numRounds = log2(numChunks) + 1;

for m in 0..#(numRounds) {
const prevChunkSize = smallestChunkSize * 2**m;
const newChunkSize = 2 * prevChunkSize;
const numNewChunks = (size - 1)/newChunkSize + 1;

forall i in 0..#(numNewChunks) with (ref x, const localLower, const localUpper, const newChunkSize, const prevChunkSize, const seed){

const start = localLower + i * newChunkSize;
const size1 = min(localUpper - start + 1, prevChunkSize);
const size2 = min(localUpper - start - size1 + 1, prevChunkSize);

const taskSeed = seed + i;
mergeOnLocale(x, start, size1, size2, taskSeed);

}
seed += numNewChunks;
}
}
return generatorSeed + 1;
}

/*
Shuffle the slice of array over index lower..upper.
*/
proc shuffleRange(ref x: [] ?t, lower: int, upper: int, generatorSeed: int): int {
return shuffleRange(x, lower, upper, upper, true, generatorSeed);
}

proc shuffleRange(ref x: [] ?t, const lower: int, const upper: int, const bound: int, const isUpperBound: bool, const generatorSeed: int): int {
for loc in Locales do on loc {

const localLower = max(x.localSubdomain(loc=here).low, lower);
const localUpper = min(x.localSubdomain(loc=here).high, upper);

fisherYatesOnLocale(x, localLower, localUpper, bound, isUpperBound, generatorSeed);
}
return generatorSeed + numLocales;
}

/*
Conduct partial Fisher Yates shuffle on the slice of the array over index localLower..localUpper.
Note that elements in this interval can be swapped with elements in range localLower..upper.
*/
proc fisherYatesOnLocale(ref x: [] ?t, const localLower: int, const localUpper: int, const bound: int, const isUpperBound: bool, const generatorSeed: int): int {

var randStreamInt = new randomStream(int, seed=generatorSeed);
for i in localLower..localUpper {
const idx = if isUpperBound then randStreamInt.choose(i..bound) else randStreamInt.choose(bound..i);
if (i != idx ){
x[i] <=> x[idx];
}
}
return generatorSeed + 1;
}

proc mergeOnLocale(ref x: [] ?t, s: int, n1: int, n2: int, generatorSeed: int){
var i: int = s;
var j: int = s + n1;
var n: int = s + n1 + n2 - 1;
const threshold = (n1: real)/((n1 + n2): real);

var seed = generatorSeed + here.id;
var randStream = new randomStream(real, seed=seed);

while(true){
if randStream.next() < threshold {
if (i==j){
break;
}
} else {
if (j==n) {
break;
}
x[i] <=> x[j];
j += 1;
}
i += 1;
}

fisherYatesOnLocale(x, i, n, s, false, seed);
}

/*
This version does the swaps using chapel <=> from on the local where the first element lives
*/
proc merge(ref x: [] ?t, s: int, n1: int, n2: int, generatorSeed: int): int {
var i: int = s;
var j: int = s + n1;
var n: int = s + n1 + n2 - 1;
const threshold = (n1: real)/((n1 + n2): real);

for loc in Locales do on loc {
const low = x.localSubdomain(loc=here).low;
const high = x.localSubdomain(loc=here).high;

const seed = generatorSeed + here.id;
var randStream = new randomStream(real, seed=seed);

while(true){
if (i < low) | (i > high){
break;
}

if randStream.next() < threshold {
if (i==j){
break;
}
} else {
if (j==n) {
break;
}
x[i] <=> x[j];
j += 1;
}
i += 1;
}
}
var seed = generatorSeed + numLocales;
seed = shuffleRange(x, i, n, s, false, seed);

return seed;
}

proc newSeed(seed: int): int {
var randStream = new randomStream(int, seed=seed);
return randStream.next();
}

proc getDomainLows(ref x: [] ?t): [] int {
var domainLows: [0..#numLocales] int;
coforall loc in Locales with (ref x) do on loc {
domainLows[here.id] = x.localSubdomain(loc=here).low;
}
return domainLows;
}

proc getDomainHighs(ref x: [] ?t): [] int {
var domainHighs: [0..#numLocales] int;
coforall loc in Locales with (ref x) do on loc {
domainHighs[here.id] = x.localSubdomain(loc=here).high;
}
return domainHighs;
}

proc getDomainSizes(ref x: [] ?t): [] int {
var domainSizes: [0..#numLocales] int;
coforall loc in Locales with (ref x) do on loc {
domainSizes[here.id] = x.localSubdomain(loc=here).size;
}
return domainSizes;
}

use CommandMap;
registerFunction("logisticGenerator", logisticGeneratorMsg, getModuleName());
registerFunction("segmentedSample", segmentedSampleMsg, getModuleName());
Expand Down
12 changes: 7 additions & 5 deletions tests/numpy/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def test_integers(self):
assert all(bounded_arr.to_ndarray() < 5)

@pytest.mark.parametrize("data_type", INT_FLOAT)
def test_shuffle(self, data_type):
@pytest.mark.parametrize("method", ["FisherYates", "MergeShuffle"])
@pytest.mark.parametrize("size", pytest.prob_size)
def test_shuffle(self, data_type, method, size):

# ints are checked for equality; floats are checked for closeness

Expand All @@ -71,18 +73,18 @@ def test_shuffle(self, data_type):

rng = ak.random.default_rng(18)
rnfunc = rng.integers if data_type is ak.int64 else rng.uniform
pda = rnfunc(-(2**32), 2**32, 10)
pda = rnfunc(-(2**32), 2**32, size)
pda_copy = pda[:]
rng.shuffle(pda)
rng.shuffle(pda, method=method)

assert check(ak.sort(pda), ak.sort(pda_copy), data_type)

# verify same seed gives reproducible arrays

rng = ak.random.default_rng(18)
rnfunc = rng.integers if data_type is ak.int64 else rng.uniform
pda_prime = rnfunc(-(2**32), 2**32, 10)
rng.shuffle(pda_prime)
pda_prime = rnfunc(-(2**32), 2**32, size)
rng.shuffle(pda_prime, method=method)

assert check(pda, pda_prime, data_type)

Expand Down
Loading

0 comments on commit 06edaf9

Please sign in to comment.