From ea6d4ae284147c1db0878dca483a8f1719319c5d Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Sat, 18 Jan 2025 22:20:12 +0530 Subject: [PATCH] Parallizes InterRDF_s --- package/MDAnalysis/analysis/rdf.py | 103 +++++------------- .../MDAnalysisTests/analysis/test_rdf_s.py | 52 +++------ 2 files changed, 47 insertions(+), 108 deletions(-) diff --git a/package/MDAnalysis/analysis/rdf.py b/package/MDAnalysis/analysis/rdf.py index 511e0495f1..4a13152738 100644 --- a/package/MDAnalysis/analysis/rdf.py +++ b/package/MDAnalysis/analysis/rdf.py @@ -316,7 +316,7 @@ def _get_aggregator(self): return ResultsGroup(lookup={'count': ResultsGroup.ndarray_sum, 'volume_cum': ResultsGroup.ndarray_sum, 'bins': ResultsGroup.ndarray_sum, - 'edges': ResultsGroup.ndarray_sum}) + 'edges': ResultsGroup.ndarray_mean}) def _conclude(self): norm = self.n_frames @@ -583,17 +583,42 @@ def get_supported_backends(cls): _analysis_algorithm_is_parallelizable = True - + @staticmethod + def func(arrs): + r"""Custom aggregator for nested arrays + + Parameters + ---------- + arrs : list + List of arrays or nested lists of arrays + + Returns + ------- + ndarray + Sums flattened arrays at alternate index + in the list and returns a list of two arrays + """ + def flatten(arr): + if isinstance(arr, (list, tuple)): + return [item for sublist in arr for item in flatten(sublist)] + return [arr] + + flat = flatten(arrs) + aggregated_arr = [np.zeros_like(flat[0]), np.zeros_like(flat[1])] + for i in range(len(flat)//2): + aggregated_arr[0] += flat[2*i] # 0, 2, 4, ... + aggregated_arr[1] += flat[2*i+1] # 1, 3, 5, ... + return aggregated_arr + def _get_aggregator(self): return ResultsGroup( lookup={ - 'count': self._flattened_ndarray_sum, + 'count': self.func, 'volume_cum': ResultsGroup.ndarray_sum, - 'bins': ResultsGroup.ndarray_sum, + 'bins': ResultsGroup.ndarray_mean, 'edges': ResultsGroup.ndarray_mean, } ) - def __init__( self, @@ -676,74 +701,6 @@ def arr_resize(arr): else: raise ValueError("Array has an invalid shape") - # @staticmethod - # def custom_aggregate(combined_arr): - # arr1 = combined_arr[0][0] - # arr2 = combined_arr[1][0] - # arr3 = combined_arr[1][1][0] - # arr4 = combined_arr[1][1][1] - - # arr1 = InterRDF_s.arr_resize(arr1) - # arr2 = InterRDF_s.arr_resize(arr2) - # arr3 = InterRDF_s.arr_resize(arr3) - # arr4 = InterRDF_s.arr_resize(arr4) - - - # print(arr1.shape, arr2.shape, arr3.shape, arr4.shape) - - - - # arr01 = arr1 + arr2 - # arr02 = np.vstack((arr3, arr4)) - # print("New shape", arr01.shape, arr02.shape) - - # arr = [arr01, arr02] - # # arr should be [(1,2,75), (2,2,75)] - # return arr - - - # #TODO: check shapes without parallelization then emulate that in custom_aggregate - - # def _get_aggregator(self): - # return ResultsGroup(lookup={'count': self.custom_aggregate, - # 'volume_cum': ResultsGroup.ndarray_sum, - # 'bins': ResultsGroup.ndarray_sum, - # 'edges': ResultsGroup.ndarray_sum}) - - @staticmethod - def _flattened_ndarray_sum(arrs): - """Custom aggregator for nested count arrays - - Parameters - ---------- - arrs : list - List of arrays or nested lists of arrays to sum - - Returns - ------- - ndarray - Sum of all arrays after flattening nested structure - """ - # Handle nested list/array structures - def flatten(arr): - if isinstance(arr, (list, tuple)): - return [item for sublist in arr for item in flatten(sublist)] - return [arr] - - # Flatten and sum arrays - flat = flatten(arrs) - if not flat: - return None - - f1 = np.zeros_like(flat[0]) - f2 = np.zeros_like(flat[1]) - # print(flat, "SIZE:", len(flat)) - for i in range(len(flat)//2): - f1 += flat[2*i] - f2 += flat[2*i+1] - array1 = [f1, f2] - # print("ARRAY", array1) - return array1 def _conclude(self): diff --git a/testsuite/MDAnalysisTests/analysis/test_rdf_s.py b/testsuite/MDAnalysisTests/analysis/test_rdf_s.py index bdce3eed64..e209ea2102 100644 --- a/testsuite/MDAnalysisTests/analysis/test_rdf_s.py +++ b/testsuite/MDAnalysisTests/analysis/test_rdf_s.py @@ -48,19 +48,19 @@ def sels(u): @pytest.fixture(scope="module") -def rdf(u, sels): - return InterRDF_s(u, sels).run() +def rdf(u, sels, client_InterRDF_s): + r = InterRDF_s(u, sels).run(**client_InterRDF_s) + return r -def test_nbins(u, sels): - rdf = InterRDF_s(u, sels, nbins=412).run() +def test_nbins(u, sels, client_InterRDF_s): + rdf = InterRDF_s(u, sels, nbins=412).run(**client_InterRDF_s) assert len(rdf.results.bins) == 412 - -def test_range(u, sels): +def test_range(u, sels, client_InterRDF_s): rmin, rmax = 1.0, 13.0 - rdf = InterRDF_s(u, sels, range=(rmin, rmax)).run() + rdf = InterRDF_s(u, sels, range=(rmin, rmax)).run(**client_InterRDF_s) assert rdf.results.edges[0] == rmin assert rdf.results.edges[-1] == rmax @@ -113,19 +113,19 @@ def test_cdf(rdf): (True, 0.021915460340071267), ], ) -def test_density(u, sels, density, value): +def test_density(u, sels, density, value, client_InterRDF_s): kwargs = {"density": density} if density is not None else {} - rdf = InterRDF_s(u, sels, **kwargs).run() + rdf = InterRDF_s(u, sels, **kwargs).run(**client_InterRDF_s) + print(rdf.results.rdf[0][0][0], "RDF") assert_almost_equal(max(rdf.results.rdf[0][0][0]), value) if not density: s1 = u.select_atoms("name ZND and resid 289") s2 = u.select_atoms( "name OD1 and resid 51 and sphzone 5.0 (resid 289)" ) - rdf_ref = InterRDF(s1, s2).run() + rdf_ref = InterRDF(s1, s2).run(**client_InterRDF_s) assert_almost_equal(rdf_ref.results.rdf, rdf.results.rdf[0][0][0]) - def test_overwrite_norm(u, sels): rdf = InterRDF_s(u, sels, norm="rdf", density=True) assert rdf.norm == "density" @@ -139,23 +139,23 @@ def test_overwrite_norm(u, sels): ("none", 0.6), ], ) -def test_norm(u, sels, norm, value): - rdf = InterRDF_s(u, sels, norm=norm).run() +def test_norm(u, sels, norm, value, client_InterRDF_s): + rdf = InterRDF_s(u, sels, norm=norm).run(**client_InterRDF_s) assert_allclose(max(rdf.results.rdf[0][0][0]), value) if norm == "rdf": s1 = u.select_atoms("name ZND and resid 289") s2 = u.select_atoms( "name OD1 and resid 51 and sphzone 5.0 (resid 289)" ) - rdf_ref = InterRDF(s1, s2).run() + rdf_ref = InterRDF(s1, s2).run(**client_InterRDF_s) assert_almost_equal(rdf_ref.results.rdf, rdf.results.rdf[0][0][0]) @pytest.mark.parametrize( "norm, norm_required", [("Density", "density"), (None, "none")] ) -def test_norm_values(u, sels, norm, norm_required): - rdf = InterRDF_s(u, sels, norm=norm).run() +def test_norm_values(u, sels, norm, norm_required, client_InterRDF_s): + rdf = InterRDF_s(u, sels, norm=norm).run(**client_InterRDF_s) assert rdf.norm == norm_required @@ -170,22 +170,4 @@ def test_rdf_attr_warning(rdf, attr): rdf.get_cdf() wmsg = f"The `{attr}` attribute was deprecated in MDAnalysis 2.0.0" with pytest.warns(DeprecationWarning, match=wmsg): - getattr(rdf, attr) is rdf.results[attr] - -@pytest.mark.parametrize( - "classname,is_parallelizable", - [ - (mda.analysis.rdf, False), - ] -) -def test_class_is_parallelizable(classname, is_parallelizable): - assert classname.InterRDF_s._analysis_algorithm_is_parallelizable == is_parallelizable - -@pytest.mark.parametrize( - "classname,backends", - [ - (mda.analysis.rdf, ('serial',)), - ] -) -def test_supported_backends(classname, backends): - assert classname.InterRDF_s.get_supported_backends() == backends + getattr(rdf, attr) is rdf.results[attr] \ No newline at end of file