Skip to content

Commit

Permalink
minor changes to basegrid.py so ngrid can integrate 3D points
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchezw17 committed Feb 5, 2024
1 parent 04fc416 commit 51e7503
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 5 deletions.
7 changes: 5 additions & 2 deletions src/grid/basegrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,23 @@ def integrate(self, *value_arrays):
The calculated integral over given integrand or function
"""

if len(value_arrays) < 1:
raise ValueError("No array is given to integrate.")

for i, array in enumerate(value_arrays):
if not isinstance(array, np.ndarray):
raise TypeError(f"Arg {i} is {type(i)}, Need Numpy Array.")
if array.shape != (self.size,):
if np.shape(array)[0] != (self.size):
raise ValueError(f"Arg {i} need to be of shape ({self.size},).")
# return np.einsum("i, ..., i", a, ..., z)
return np.einsum(
"i" + ",i" * len(value_arrays),
"i" + ",i..." * len(value_arrays),
self.weights,
*(array for array in value_arrays),
)


def get_localgrid(self, center, radius):
"""Create a grid contain points within the given radius of center.
Expand Down
2 changes: 1 addition & 1 deletion src/grid/ngrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ def _n_integrate(self, grid_list, callable, **call_kwargs):

# calculate the value of the n particle function at each point of the last grid
vals = aux_func(grid_list[-1].points)

# Integrate the function over the last grid with all the other coordinates fixed.
# The result is multiplied by the product of the weights corresponding to the other
# grids' points (stored in i[1]).
# This is equivalent to integrating the n particle function over the coordinates of
# the last particle with the other coordinates fixed.
integral += grid_list[-1].integrate(vals) * np.prod(i[1])

return integral
40 changes: 38 additions & 2 deletions src/grid/tests/test_ngrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@
from grid.onedgrid import UniformInteger
from grid.rtransform import LinearInfiniteRTransform
from grid.ngrid import Ngrid

from grid.basegrid import Grid
import numpy as np
from numpy.testing import assert_allclose


class TestNgrid(TestCase):
"""Ngrid tests class."""

Expand All @@ -44,6 +43,11 @@ def setUp(self):
)
# create a 3D grid with n equally spaced points between 0 and 1 along each axis
self.ngrid = Ngrid([self.linear_grid], 3)
ref_points = np.array([np.linspace(0, 1, 500)] * 3).T
#ref_points = np.random.uniform(0, 1, size=(100, 3))
ref_weights = np.ones(len(ref_points)) /len(ref_points)
self.grid_3d = Grid(ref_points, ref_weights)


def test_init_raises(self):
"""Assert that the init raises the correct error."""
Expand Down Expand Up @@ -115,4 +119,36 @@ def f(x, y, z):
# integrate it
result = ngrid.integrate(f)
# check that the result is correct

self.assertAlmostEqual(result, 1.0 / 8.0, places=2)

def test_3d_double_grid_integration(self):
"""Assert that the integration works as expected for two grids."""

# define a function to integrate (x**2+y**2)
def f(x, y):
return x**2+y**2

# define a Ngrid with two grids
ngrid = Ngrid(grid_list=[self.grid_3d],n=2)
# integrate it
result = ngrid.integrate(f)
result=np.sum(result)/3
# check that the result is correct
self.assertAlmostEqual(result, 2.0 / 3.0, places=2)

def test_3d_triple_grid_integration(self):
"""Assert that the integration works as expected for two grids."""

# define a function to integrate (x**2+y**2)
def f(x, y, z):
return x * y * z

# define a Ngrid with two grids
ngrid = Ngrid(grid_list=[self.grid_3d],n=3)
# integrate it
result = ngrid.integrate(f)
result=np.sum(result)/3
# check that the result is correct
self.assertAlmostEqual(result, 1.0 / 8.0, places=2)

0 comments on commit 51e7503

Please sign in to comment.