diff --git a/HARK/tests/test_interpolation.py b/HARK/tests/test_interpolation.py index 200b49d0e..2385b0168 100644 --- a/HARK/tests/test_interpolation.py +++ b/HARK/tests/test_interpolation.py @@ -9,6 +9,7 @@ from HARK.interpolation import BilinearInterp from HARK.interpolation import CubicHermiteInterp as CubicInterp from HARK.interpolation import LinearInterp, QuadlinearInterp, TrilinearInterp +from HARK.interpolation import IdentityFunction class testsLinearInterp(unittest.TestCase): @@ -200,3 +201,46 @@ def test_same_length(self): self.f_array, self.w_array, self.x_array, self.y_array_t, self.z_array ) self.assertEqual(bilinear(1, 2, 1, 2), 6.0) + + +class test_IdentityFunction(unittest.TestCase): + """ + Tests evaluation and derivatives of IdentityFunction class. + """ + + def setUp(self): + self.IF1D = IdentityFunction() + self.IF2Da = IdentityFunction(i_dim=0, n_dims=2) + self.IF2Db = IdentityFunction(i_dim=1, n_dims=2) + self.IF3Da = IdentityFunction(i_dim=0, n_dims=3) + self.IF3Db = IdentityFunction(i_dim=2, n_dims=3) + self.X = 3*np.ones(100) + self.Y = 4*np.ones(100) + self.Z = 5*np.ones(100) + self.zero = np.zeros(100) + self.one = np.ones(100) + + def test_eval(self): + self.assertEqual(self.X, self.IF1D(self.X)) + self.assertEqual(self.X, self.IF2Da(self.X, self.Y)) + self.assertEqual(self.Y, self.IF2Db(self.X, self.Y)) + self.assertEqual(self.X, self.IF3Da(self.X, self.Y, self.Z)) + self.assertEqual(self.Z, self.IF3Db(self.X, self.Y, self.Z)) + + def test_der(self): + self.assertEqual(self.one, self.IF1D.der(self.X)) + + self.assertEqual(self.one, self.IF2Da.derX(self.X, self.Y)) + self.assertEqual(self.zero, self.IF2Da.derY(self.X, self.Y)) + + self.assertEqual(self.zero, self.IF2Db.derX(self.X, self.Y)) + self.assertEqual(self.one, self.IF2Db.derY(self.X, self.Y)) + + self.assertEqual(self.one, self.IF3Da.derX(self.X, self.Y, self.Z)) + self.assertEqual(self.zero, self.IF3Da.derY(self.X, self.Y, self.Z)) + self.assertEqual(self.zero, self.IF3Da.derZ(self.X, self.Y, self.Z)) + + self.assertEqual(self.zero, self.IF3Db.derX(self.X, self.Y, self.Z)) + self.assertEqual(self.zero, self.IF3Db.derY(self.X, self.Y, self.Z)) + self.assertEqual(self.one, self.IF3Db.derZ(self.X, self.Y, self.Z)) + \ No newline at end of file