Parallelizes InterRDF_s
tanishy7777 committed Jan 18, 2025
1 parent 2c919ea commit ea6d4ae
Showing 2 changed files with 47 additions and 108 deletions.
103 changes: 30 additions & 73 deletions package/MDAnalysis/analysis/rdf.py
@@ -316,7 +316,7 @@ def _get_aggregator(self):
         return ResultsGroup(lookup={'count': ResultsGroup.ndarray_sum,
                                     'volume_cum': ResultsGroup.ndarray_sum,
                                     'bins': ResultsGroup.ndarray_sum,
-                                    'edges': ResultsGroup.ndarray_sum})
+                                    'edges': ResultsGroup.ndarray_mean})
 
     def _conclude(self):
         norm = self.n_frames
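
The `edges` change from `ndarray_sum` to `ndarray_mean` is the key fix here: every worker histograms over the same range, so the per-worker bin edges are identical and must not be added up when the split results are merged. Averaging identical arrays returns them unchanged. A tiny illustration with made-up edge values:

import numpy as np

# Each worker reports the same (hypothetical) bin edges.
edges_per_worker = [np.array([0.0, 5.0, 10.0, 15.0])] * 2

print(np.mean(edges_per_worker, axis=0))  # [ 0.  5. 10. 15.] -- correct, unchanged
print(np.sum(edges_per_worker, axis=0))   # [ 0. 10. 20. 30.] -- scaled by n_workers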
@@ -583,17 +583,42 @@ def get_supported_backends(cls):
 
     _analysis_algorithm_is_parallelizable = True
 
+    @staticmethod
+    def func(arrs):
+        r"""Custom aggregator for nested arrays.
+
+        Parameters
+        ----------
+        arrs : list
+            List of arrays or nested lists of arrays.
+
+        Returns
+        -------
+        list
+            Two arrays: the sums of the flattened input arrays taken
+            at even and at odd indices, respectively.
+        """
+        def flatten(arr):
+            if isinstance(arr, (list, tuple)):
+                return [item for sublist in arr for item in flatten(sublist)]
+            return [arr]
+
+        flat = flatten(arrs)
+        aggregated_arr = [np.zeros_like(flat[0]), np.zeros_like(flat[1])]
+        for i in range(len(flat) // 2):
+            aggregated_arr[0] += flat[2 * i]      # even indices: 0, 2, 4, ...
+            aggregated_arr[1] += flat[2 * i + 1]  # odd indices: 1, 3, 5, ...
+        return aggregated_arr
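
To make the even/odd bookkeeping above concrete: with two workers that each return `count` as a list of two per-selection-pair histograms, flattening yields `[w1p1, w1p2, w2p1, w2p2]`, and summing the even and odd positions recombines each pair across workers. A minimal standalone sketch (hypothetical data; `aggregate_counts` re-implements `func` outside the class):

import numpy as np

def aggregate_counts(arrs):
    # Recursively flatten nested lists/tuples of arrays into one flat list.
    def flatten(arr):
        if isinstance(arr, (list, tuple)):
            return [item for sub in arr for item in flatten(sub)]
        return [arr]

    flat = flatten(arrs)
    out = [np.zeros_like(flat[0]), np.zeros_like(flat[1])]
    for i in range(len(flat) // 2):
        out[0] += flat[2 * i]      # first selection pair: indices 0, 2, ...
        out[1] += flat[2 * i + 1]  # second selection pair: indices 1, 3, ...
    return out

worker1 = [np.array([1, 0, 2]), np.array([0, 1, 1])]  # hypothetical per-worker counts
worker2 = [np.array([2, 1, 0]), np.array([1, 0, 3])]
print(aggregate_counts([worker1, worker2]))
# [array([3, 1, 2]), array([1, 1, 4])]

Note that the aggregator hard-codes exactly two selection pairs (only `aggregated_arr[0]` and `[1]` are built), which mirrors the two-pair `sels` fixture in the tests below.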

     def _get_aggregator(self):
         return ResultsGroup(
             lookup={
-                'count': self._flattened_ndarray_sum,
+                'count': self.func,
                 'volume_cum': ResultsGroup.ndarray_sum,
-                'bins': ResultsGroup.ndarray_sum,
+                'bins': ResultsGroup.ndarray_mean,
                 'edges': ResultsGroup.ndarray_mean,
             }
         )
 
     def __init__(
         self,
@@ -676,74 +701,6 @@ def arr_resize(arr):
+        else:
+            raise ValueError("Array has an invalid shape")

[Codecov annotation: added line package/MDAnalysis/analysis/rdf.py#L702 is not covered by tests]

-    # @staticmethod
-    # def custom_aggregate(combined_arr):
-    #     arr1 = combined_arr[0][0]
-    #     arr2 = combined_arr[1][0]
-    #     arr3 = combined_arr[1][1][0]
-    #     arr4 = combined_arr[1][1][1]
-
-    #     arr1 = InterRDF_s.arr_resize(arr1)
-    #     arr2 = InterRDF_s.arr_resize(arr2)
-    #     arr3 = InterRDF_s.arr_resize(arr3)
-    #     arr4 = InterRDF_s.arr_resize(arr4)
-
-    #     print(arr1.shape, arr2.shape, arr3.shape, arr4.shape)
-
-    #     arr01 = arr1 + arr2
-    #     arr02 = np.vstack((arr3, arr4))
-    #     print("New shape", arr01.shape, arr02.shape)
-
-    #     arr = [arr01, arr02]
-    #     # arr should be [(1,2,75), (2,2,75)]
-    #     return arr
-
-    # #TODO: check shapes without parallelization then emulate that in custom_aggregate
-
-    # def _get_aggregator(self):
-    #     return ResultsGroup(lookup={'count': self.custom_aggregate,
-    #                                 'volume_cum': ResultsGroup.ndarray_sum,
-    #                                 'bins': ResultsGroup.ndarray_sum,
-    #                                 'edges': ResultsGroup.ndarray_sum})
-
-    @staticmethod
-    def _flattened_ndarray_sum(arrs):
-        """Custom aggregator for nested count arrays
-
-        Parameters
-        ----------
-        arrs : list
-            List of arrays or nested lists of arrays to sum
-
-        Returns
-        -------
-        ndarray
-            Sum of all arrays after flattening nested structure
-        """
-        # Handle nested list/array structures
-        def flatten(arr):
-            if isinstance(arr, (list, tuple)):
-                return [item for sublist in arr for item in flatten(sublist)]
-            return [arr]
-
-        # Flatten and sum arrays
-        flat = flatten(arrs)
-        if not flat:
-            return None
-
-        f1 = np.zeros_like(flat[0])
-        f2 = np.zeros_like(flat[1])
-        # print(flat, "SIZE:", len(flat))
-        for i in range(len(flat)//2):
-            f1 += flat[2*i]
-            f2 += flat[2*i+1]
-        array1 = [f1, f2]
-        # print("ARRAY", array1)
-        return array1


     def _conclude(self):
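
Taken together, the rdf.py changes let `InterRDF_s` run through the split-apply-combine machinery. A usage sketch (file names and selections are placeholders; `backend`/`n_workers` follow the parallel `AnalysisBase.run()` keywords, so treat the exact values as assumptions):

import MDAnalysis as mda
from MDAnalysis.analysis.rdf import InterRDF_s

# Placeholder topology/trajectory and selections.
u = mda.Universe("topology.tpr", "trajectory.xtc")
pairs = [(u.select_atoms("name O"), u.select_atoms("name H"))]

# Split the trajectory over two processes; results are merged by _get_aggregator().
rdf = InterRDF_s(u, pairs).run(backend="multiprocessing", n_workers=2)
print(rdf.results.bins.shape)      # (75,) with the default nbins
print(rdf.results.count[0].shape)  # one count array per selection pair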
52 changes: 17 additions & 35 deletions testsuite/MDAnalysisTests/analysis/test_rdf_s.py
@@ -48,19 +48,19 @@ def sels(u):


@pytest.fixture(scope="module")
def rdf(u, sels):
return InterRDF_s(u, sels).run()
def rdf(u, sels, client_InterRDF_s):
r = InterRDF_s(u, sels).run(**client_InterRDF_s)
return r
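
The `client_InterRDF_s` fixture is supplied by the test suite's parallelism conftest and is not part of this diff; it expands to the kwargs forwarded to `.run()` for each supported backend. A hypothetical stand-in that shows the shape of what it provides (fixture body and backend list are assumptions):

import pytest

@pytest.fixture(params=[
    {"backend": "serial"},
    {"backend": "multiprocessing", "n_workers": 2},
])
def client_InterRDF_s(request):
    # Each test runs once per backend configuration.
    return request.param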


-def test_nbins(u, sels):
-    rdf = InterRDF_s(u, sels, nbins=412).run()
+def test_nbins(u, sels, client_InterRDF_s):
+    rdf = InterRDF_s(u, sels, nbins=412).run(**client_InterRDF_s)
 
     assert len(rdf.results.bins) == 412
 
 
-def test_range(u, sels):
+def test_range(u, sels, client_InterRDF_s):
     rmin, rmax = 1.0, 13.0
-    rdf = InterRDF_s(u, sels, range=(rmin, rmax)).run()
+    rdf = InterRDF_s(u, sels, range=(rmin, rmax)).run(**client_InterRDF_s)
 
     assert rdf.results.edges[0] == rmin
     assert rdf.results.edges[-1] == rmax
@@ -113,19 +113,19 @@ def test_cdf(rdf):
         (True, 0.021915460340071267),
     ],
 )
-def test_density(u, sels, density, value):
+def test_density(u, sels, density, value, client_InterRDF_s):
     kwargs = {"density": density} if density is not None else {}
-    rdf = InterRDF_s(u, sels, **kwargs).run()
+    rdf = InterRDF_s(u, sels, **kwargs).run(**client_InterRDF_s)
     assert_almost_equal(max(rdf.results.rdf[0][0][0]), value)
     if not density:
         s1 = u.select_atoms("name ZND and resid 289")
         s2 = u.select_atoms(
             "name OD1 and resid 51 and sphzone 5.0 (resid 289)"
         )
-        rdf_ref = InterRDF(s1, s2).run()
+        rdf_ref = InterRDF(s1, s2).run(**client_InterRDF_s)
         assert_almost_equal(rdf_ref.results.rdf, rdf.results.rdf[0][0][0])
 
 
 def test_overwrite_norm(u, sels):
     rdf = InterRDF_s(u, sels, norm="rdf", density=True)
     assert rdf.norm == "density"
@@ -139,23 +139,23 @@ def test_overwrite_norm(u, sels):
         ("none", 0.6),
     ],
 )
-def test_norm(u, sels, norm, value):
-    rdf = InterRDF_s(u, sels, norm=norm).run()
+def test_norm(u, sels, norm, value, client_InterRDF_s):
+    rdf = InterRDF_s(u, sels, norm=norm).run(**client_InterRDF_s)
     assert_allclose(max(rdf.results.rdf[0][0][0]), value)
     if norm == "rdf":
         s1 = u.select_atoms("name ZND and resid 289")
         s2 = u.select_atoms(
             "name OD1 and resid 51 and sphzone 5.0 (resid 289)"
         )
-        rdf_ref = InterRDF(s1, s2).run()
+        rdf_ref = InterRDF(s1, s2).run(**client_InterRDF_s)
         assert_almost_equal(rdf_ref.results.rdf, rdf.results.rdf[0][0][0])
 
 
 @pytest.mark.parametrize(
     "norm, norm_required", [("Density", "density"), (None, "none")]
 )
-def test_norm_values(u, sels, norm, norm_required):
-    rdf = InterRDF_s(u, sels, norm=norm).run()
+def test_norm_values(u, sels, norm, norm_required, client_InterRDF_s):
+    rdf = InterRDF_s(u, sels, norm=norm).run(**client_InterRDF_s)
     assert rdf.norm == norm_required
Expand All @@ -170,22 +170,4 @@ def test_rdf_attr_warning(rdf, attr):
rdf.get_cdf()
wmsg = f"The `{attr}` attribute was deprecated in MDAnalysis 2.0.0"
with pytest.warns(DeprecationWarning, match=wmsg):
getattr(rdf, attr) is rdf.results[attr]

@pytest.mark.parametrize(
"classname,is_parallelizable",
[
(mda.analysis.rdf, False),
]
)
def test_class_is_parallelizable(classname, is_parallelizable):
assert classname.InterRDF_s._analysis_algorithm_is_parallelizable == is_parallelizable

@pytest.mark.parametrize(
"classname,backends",
[
(mda.analysis.rdf, ('serial',)),
]
)
def test_supported_backends(classname, backends):
assert classname.InterRDF_s.get_supported_backends() == backends
getattr(rdf, attr) is rdf.results[attr]
