Skip to content

Commit

Permalink
fixing unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
sblackburn-mila committed Jun 10, 2024
1 parent 18cb6a2 commit e74cbf2
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions tests/models/test_diffusion_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from crystal_diffusion.models.diffusion_mace import (DiffusionMACE,
LinearVectorReadoutBlock,
input_to_diffusion_mace)
from crystal_diffusion.namespace import (NOISE, NOISY_CARTESIAN_POSITIONS,
from crystal_diffusion.namespace import (CARTESIAN_FORCES, NOISE,
NOISY_CARTESIAN_POSITIONS,
NOISY_RELATIVE_COORDINATES, TIME,
UNIT_CELL)
from crystal_diffusion.utils.basis_transformations import (
Expand Down Expand Up @@ -88,12 +89,17 @@ def noises(self, batch_size):
return 0.5 * torch.rand(batch_size, 1)

@pytest.fixture(scope='class')
def batch(self, relative_coordinates, cartesian_positions, basis_vectors, times, noises):
def forces(self, batch_size, spatial_dimension):
return 0.5 * torch.rand(batch_size, spatial_dimension)

@pytest.fixture(scope='class')
def batch(self, relative_coordinates, cartesian_positions, basis_vectors, times, noises, forces):
batch = {NOISY_RELATIVE_COORDINATES: relative_coordinates,
NOISY_CARTESIAN_POSITIONS: cartesian_positions,
TIME: times,
NOISE: noises,
UNIT_CELL: basis_vectors}
UNIT_CELL: basis_vectors,
CARTESIAN_FORCES: forces}
return batch

@pytest.fixture(scope='class')
Expand Down

0 comments on commit e74cbf2

Please sign in to comment.