Skip to content

Commit

Permalink
Add tests for IdentityFunction
Browse files Browse the repository at this point in the history
This is basic.
  • Loading branch information
mnwhite committed Dec 10, 2024
1 parent de873dc commit 20e00a3
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions HARK/tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from HARK.interpolation import BilinearInterp
from HARK.interpolation import CubicHermiteInterp as CubicInterp
from HARK.interpolation import LinearInterp, QuadlinearInterp, TrilinearInterp
from HARK.interpolation import IdentityFunction


class testsLinearInterp(unittest.TestCase):
Expand Down Expand Up @@ -200,3 +201,46 @@ def test_same_length(self):
self.f_array, self.w_array, self.x_array, self.y_array_t, self.z_array
)
self.assertEqual(bilinear(1, 2, 1, 2), 6.0)


class test_IdentityFunction(unittest.TestCase):
"""
Tests evaluation and derivatives of IdentityFunction class.
"""

def setUp(self):
self.IF1D = IdentityFunction()
self.IF2Da = IdentityFunction(i_dim=0, n_dims=2)
self.IF2Db = IdentityFunction(i_dim=1, n_dims=2)
self.IF3Da = IdentityFunction(i_dim=0, n_dims=3)
self.IF3Db = IdentityFunction(i_dim=2, n_dims=3)
self.X = 3*np.ones(100)
self.Y = 4*np.ones(100)
self.Z = 5*np.ones(100)
self.zero = np.zeros(100)
self.one = np.ones(100)

def test_eval(self):
self.assertEqual(self.X, self.IF1D(self.X))
self.assertEqual(self.X, self.IF2Da(self.X, self.Y))
self.assertEqual(self.Y, self.IF2Db(self.X, self.Y))
self.assertEqual(self.X, self.IF3Da(self.X, self.Y, self.Z))
self.assertEqual(self.Z, self.IF3Db(self.X, self.Y, self.Z))

def test_der(self):
self.assertEqual(self.one, self.IF1D.der(self.X))

self.assertEqual(self.one, self.IF2Da.derX(self.X, self.Y))
self.assertEqual(self.zero, self.IF2Da.derY(self.X, self.Y))

self.assertEqual(self.zero, self.IF2Db.derX(self.X, self.Y))
self.assertEqual(self.one, self.IF2Db.derY(self.X, self.Y))

self.assertEqual(self.one, self.IF3Da.derX(self.X, self.Y, self.Z))
self.assertEqual(self.zero, self.IF3Da.derY(self.X, self.Y, self.Z))
self.assertEqual(self.zero, self.IF3Da.derZ(self.X, self.Y, self.Z))

self.assertEqual(self.zero, self.IF3Db.derX(self.X, self.Y, self.Z))
self.assertEqual(self.zero, self.IF3Db.derY(self.X, self.Y, self.Z))
self.assertEqual(self.one, self.IF3Db.derZ(self.X, self.Y, self.Z))

0 comments on commit 20e00a3

Please sign in to comment.