diff --git a/setup.py b/setup.py index 45c157da..d98e9376 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ install_requires.append(line) setup(name="weaver-core", - version='0.4.9', + version='0.4.10', description="A streamlined deep-learning framework for high energy physics", long_description_content_type="text/markdown", author="H. Qu, C. Li", diff --git a/weaver/utils/data/tools.py b/weaver/utils/data/tools.py index ab308efa..5c57bd30 100644 --- a/weaver/utils/data/tools.py +++ b/weaver/utils/data/tools.py @@ -83,26 +83,17 @@ def _batch_knn(supports, queries, k, maxlen_s, maxlen_q=None, n_jobs=1): return batch_knn_idx -def _batch_permute_indices(array, maxlen): - batch_permute_idx = np.tile(np.arange(maxlen), (len(array), 1)) - for i, a in enumerate(array): - batch_permute_idx[i, :len(a)] = np.random.permutation(len(a[:maxlen])) - return batch_permute_idx +def _batch_permute_indices(array): + random_array = ak.unflatten(np.random.rand(ak.count(array)), ak.num(array)) + return ak.argsort(random_array) -def _batch_argsort(array, maxlen): - batch_argsort_idx = np.tile(np.arange(maxlen), (len(array), 1)) - for i, a in enumerate(array): - batch_argsort_idx[i, :len(a)] = np.argsort(a[:maxlen]) - return batch_argsort_idx +def _batch_argsort(array): + return ak.argsort(array) def _batch_gather(array, indices): - out = array.zeros_like() - for i, (a, idx) in enumerate(zip(array, indices)): - maxlen = min(len(a), len(idx)) - out[i][:maxlen] = a[idx[:maxlen]] - return out + return array[indices] def _p4_from_pxpypze(px, py, pz, energy):