diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/analysis/score_viewer.py b/src/diffusion_for_multi_scale_molecular_dynamics/analysis/score_viewer.py index 7d62576..db42e1e 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/analysis/score_viewer.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/analysis/score_viewer.py @@ -153,7 +153,7 @@ def _compute_projected_scores(self, score_network: ScoreNetwork): for time, sigma in zip(self.times, self.sigmas): batch = self._get_batch(time, sigma) - sigma_normalized_scores = score_network(batch).X.detach() + sigma_normalized_scores = score_network(batch).X.detach().cpu() vectors = einops.rearrange( sigma_normalized_scores, "batch natoms space -> batch (natoms space)" )