diff --git a/src/qusi/infer_session.py b/src/qusi/infer_session.py index fe5decbe..78bd26c0 100644 --- a/src/qusi/infer_session.py +++ b/src/qusi/infer_session.py @@ -33,6 +33,7 @@ def get_device() -> Device: def infer_phase(dataloader, model: Module, device: Device): batch_count = 0 batches_of_predicted_targets = [] + model = model.to(device=device) model.eval() with torch.no_grad(): for input_features in dataloader: