diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network.py index 20b067c..42d2146 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network.py @@ -92,7 +92,8 @@ def _check_batch(self, batch: Dict[AnyStr, torch.Tensor]): assert ( len(relative_coordinates_shape) == 3 and relative_coordinates_shape[2] == self.spatial_dimension - ), "The relative coordinates are expected to be in a tensor of shape [batch_size, number_of_atoms, 3]" + ), ("The relative coordinates are expected to be in a tensor of " + "shape [batch_size, number_of_atoms, spatial_dimension]") assert torch.logical_and( relative_coordinates >= 0.0, relative_coordinates < 1.0