diff --git a/tests/test_pulse_shape.py b/tests/test_pulse_shape.py index a79a948..1d077ab 100644 --- a/tests/test_pulse_shape.py +++ b/tests/test_pulse_shape.py @@ -55,16 +55,25 @@ def test__get_double_exponential_shape(): def test__get_secant_shape(): theta = np.array([1, 2, 3]) expected_result = np.array([0.20628208, 0.08460748, 0.03161706]) - assert np.allclose(BlobShapeImpl._get_secant_shape(theta), expected_result) + + ps = BlobShapeImpl("secant", "secant") + values = ps.get_blob_shape_perp(theta) + assert np.max(np.abs(values - expected_result)) < 1e-5, "Wrong shape" def test__get_lorentz_shape(): theta = np.array([1, 2, 3]) expected_result = np.array([0.15915494, 0.06366198, 0.03183099]) - assert np.allclose(BlobShapeImpl._get_lorentz_shape(theta), expected_result) + + ps = BlobShapeImpl("lorentz", "lorentz") + values = ps.get_blob_shape_perp(theta) + assert np.max(np.abs(values - expected_result)) < 1e-5, "Wrong shape" def test__get_dipole_shape(): theta = np.array([1, 2, 3]) expected_result = np.array([-0.48394145, -0.21596387, -0.02659109]) - assert np.allclose(BlobShapeImpl._get_dipole_shape(theta), expected_result) + + ps = BlobShapeImpl("dipole", "dipole") + values = ps.get_blob_shape_perp(theta) + assert np.max(np.abs(values - expected_result)) < 1e-5, "Wrong shape"