diff --git a/tests/test_pf2d.py b/tests/test_pf2d.py index 7a65979..36f4105 100644 --- a/tests/test_pf2d.py +++ b/tests/test_pf2d.py @@ -81,9 +81,10 @@ def run_simulation(tmp_path, tvf, solver): return data_path -def get_solution(data_path, dir, t_dimless, y_axis): +def get_solution(data_path, t_dimless, y_axis): from jax_sph.utils import sph_interpolator + dir = os.listdir(data_path)[0] cfg = OmegaConf.load(data_path / dir / "config.yaml") step_max = np.array(np.rint(cfg.solver.t_end / cfg.solver.dt), dtype=int) digits = len(str(step_max)) @@ -107,11 +108,8 @@ def test_pf2d(tvf, solver, tmp_path, setup_simulation): """Test whether the poiseuille flow simulation matches the analytical solution""" y_axis, t_dimless, ref_solutions = setup_simulation data_path = run_simulation(tmp_path, tvf, solver) - subdirs = os.listdir(data_path) # print(f"tmp_path = {tmp_path}, subdirs = {subdirs}") - solutions = get_solution(data_path, subdirs[0], t_dimless, y_axis) + solutions = get_solution(data_path, t_dimless, y_axis) # print(f"solution: {solutions[-1]} \nref_solution: {ref_solutions[-1]}") - for solution, ref_solution in zip(solutions, ref_solutions): - assert np.allclose( - solution, ref_solution, atol=1e-2 - ), "Velocity profile does not match." + for sol, ref_sol in zip(solutions, ref_solutions): + assert np.allclose(sol, ref_sol, atol=1e-2), "Velocity profile does not match."