Skip to content

Commit

Permalink
Minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
tanishy7777 committed Jan 18, 2025
1 parent ea6d4ae commit a897ca9
Showing 1 changed file with 32 additions and 29 deletions.
61 changes: 32 additions & 29 deletions package/MDAnalysis/analysis/rdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,11 @@ class InterRDF(AnalysisBase):

@classmethod
def get_supported_backends(cls):
return ('serial', 'multiprocessing', 'dask',)
return (
"serial",
"multiprocessing",
"dask",
)

_analysis_algorithm_is_parallelizable = True

Expand Down Expand Up @@ -313,10 +317,14 @@ def _single_frame(self):
self.volume_cum += self._ts.volume

def _get_aggregator(self):
return ResultsGroup(lookup={'count': ResultsGroup.ndarray_sum,
'volume_cum': ResultsGroup.ndarray_sum,
'bins': ResultsGroup.ndarray_sum,
'edges': ResultsGroup.ndarray_mean})
return ResultsGroup(
lookup={
"count": ResultsGroup.ndarray_sum,
"volume_cum": ResultsGroup.ndarray_sum,
"bins": ResultsGroup.ndarray_sum,
"edges": ResultsGroup.ndarray_mean,
}
)

def _conclude(self):
norm = self.n_frames
Expand Down Expand Up @@ -577,49 +585,55 @@ class InterRDF_s(AnalysisBase):
.. deprecated:: 2.3.0
The `universe` parameter is superflous.
"""

@classmethod
def get_supported_backends(cls):
return ('serial', 'multiprocessing', 'dask',)
return (
"serial",
"multiprocessing",
"dask",
)

_analysis_algorithm_is_parallelizable = True

@staticmethod
def func(arrs):
r"""Custom aggregator for nested arrays
Parameters
----------
arrs : list
List of arrays or nested lists of arrays
Returns
-------
ndarray
Sums flattened arrays at alternate index
Sums flattened arrays at alternate index
in the list and returns a list of two arrays
"""

def flatten(arr):
if isinstance(arr, (list, tuple)):
return [item for sublist in arr for item in flatten(sublist)]
return [arr]

flat = flatten(arrs)
aggregated_arr = [np.zeros_like(flat[0]), np.zeros_like(flat[1])]
for i in range(len(flat)//2):
aggregated_arr[0] += flat[2*i] # 0, 2, 4, ...
aggregated_arr[1] += flat[2*i+1] # 1, 3, 5, ...
for i in range(len(flat) // 2):
aggregated_arr[0] += flat[2 * i] # 0, 2, 4, ...
aggregated_arr[1] += flat[2 * i + 1] # 1, 3, 5, ...
return aggregated_arr

def _get_aggregator(self):
return ResultsGroup(
lookup={
'count': self.func,
'volume_cum': ResultsGroup.ndarray_sum,
'bins': ResultsGroup.ndarray_mean,
'edges': ResultsGroup.ndarray_mean,
"count": self.func,
"volume_cum": ResultsGroup.ndarray_sum,
"bins": ResultsGroup.ndarray_mean,
"edges": ResultsGroup.ndarray_mean,
}
)

def __init__(
self,
u,
Expand Down Expand Up @@ -692,17 +706,6 @@ def _single_frame(self):
self.results.volume_cum += self._ts.volume
self.volume_cum += self._ts.volume

@staticmethod
def arr_resize(arr):
if arr.ndim == 2: # If shape is (x, y)
return arr[np.newaxis, ...] # Add a new axis at the beginning
elif arr.ndim == 3 and arr.shape[0] == 1: # If shape is already (1, x, y)
return arr
else:
raise ValueError("Array has an invalid shape")



def _conclude(self):
norm = self.n_frames
if self.norm in ["rdf", "density"]:
Expand Down

0 comments on commit a897ca9

Please sign in to comment.