Skip to content

Commit

Permalink
Better implementations of preprocess functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
hqucms committed Jan 25, 2024
1 parent a16973f commit 7be2906
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 16 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
install_requires.append(line)

setup(name="weaver-core",
version='0.4.9',
version='0.4.10',
description="A streamlined deep-learning framework for high energy physics",
long_description_content_type="text/markdown",
author="H. Qu, C. Li",
Expand Down
21 changes: 6 additions & 15 deletions weaver/utils/data/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,26 +83,17 @@ def _batch_knn(supports, queries, k, maxlen_s, maxlen_q=None, n_jobs=1):
return batch_knn_idx


def _batch_permute_indices(array, maxlen):
batch_permute_idx = np.tile(np.arange(maxlen), (len(array), 1))
for i, a in enumerate(array):
batch_permute_idx[i, :len(a)] = np.random.permutation(len(a[:maxlen]))
return batch_permute_idx
def _batch_permute_indices(array):
random_array = ak.unflatten(np.random.rand(ak.count(array)), ak.num(array))
return ak.argsort(random_array)


def _batch_argsort(array, maxlen):
batch_argsort_idx = np.tile(np.arange(maxlen), (len(array), 1))
for i, a in enumerate(array):
batch_argsort_idx[i, :len(a)] = np.argsort(a[:maxlen])
return batch_argsort_idx
def _batch_argsort(array):
return ak.argsort(array)


def _batch_gather(array, indices):
out = array.zeros_like()
for i, (a, idx) in enumerate(zip(array, indices)):
maxlen = min(len(a), len(idx))
out[i][:maxlen] = a[idx[:maxlen]]
return out
return array[indices]


def _p4_from_pxpypze(px, py, pz, energy):
Expand Down

0 comments on commit 7be2906

Please sign in to comment.