diff --git a/tests/models/test_diffusion_mace.py b/tests/models/test_diffusion_mace.py index e03113ae..f6e7ef46 100644 --- a/tests/models/test_diffusion_mace.py +++ b/tests/models/test_diffusion_mace.py @@ -6,7 +6,8 @@ from crystal_diffusion.models.diffusion_mace import (DiffusionMACE, LinearVectorReadoutBlock, input_to_diffusion_mace) -from crystal_diffusion.namespace import (NOISE, NOISY_CARTESIAN_POSITIONS, +from crystal_diffusion.namespace import (CARTESIAN_FORCES, NOISE, + NOISY_CARTESIAN_POSITIONS, NOISY_RELATIVE_COORDINATES, TIME, UNIT_CELL) from crystal_diffusion.utils.basis_transformations import ( @@ -88,12 +89,17 @@ def noises(self, batch_size): return 0.5 * torch.rand(batch_size, 1) @pytest.fixture(scope='class') - def batch(self, relative_coordinates, cartesian_positions, basis_vectors, times, noises): + def forces(self, batch_size, spatial_dimension): + return 0.5 * torch.rand(batch_size, spatial_dimension) + + @pytest.fixture(scope='class') + def batch(self, relative_coordinates, cartesian_positions, basis_vectors, times, noises, forces): batch = {NOISY_RELATIVE_COORDINATES: relative_coordinates, NOISY_CARTESIAN_POSITIONS: cartesian_positions, TIME: times, NOISE: noises, - UNIT_CELL: basis_vectors} + UNIT_CELL: basis_vectors, + CARTESIAN_FORCES: forces} return batch @pytest.fixture(scope='class')