Skip to content

Commit

Permalink
device placement fix
Browse files Browse the repository at this point in the history
  • Loading branch information
CatEek committed Feb 7, 2025
1 parent 02ee665 commit a496942
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions src/careamics/lvae_training/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,15 @@ def get_first_index(bin_count, quantile):
return None


def get_device():
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
return "mps"
else:
return "cpu"


def show_for_one(
idx,
val_dset,
Expand Down Expand Up @@ -553,6 +562,8 @@ def get_single_file_predictions(
if tile_size and grid_size:
dset.set_img_sz(tile_size, grid_size)

device = get_device()

dloader = DataLoader(
dset,
pin_memory=False,
Expand All @@ -561,14 +572,14 @@ def get_single_file_predictions(
batch_size=batch_size,
)
model.eval()
model.cuda()
model.to(device)
tiles = []
logvar_arr = []
with torch.no_grad():
for batch in tqdm(dloader, desc="Predicting tiles"):
inp, tar = batch
inp = inp.cuda()
tar = tar.cuda()
inp = inp.to(device)
tar = tar.to(device)

# get model output
rec, _ = model(inp)
Expand Down Expand Up @@ -597,6 +608,8 @@ def get_single_file_mmse(
num_workers: int = 4,
) -> tuple[np.ndarray, np.ndarray]:
"""Get patch-wise predictions from a model for a single file dataset."""
device = get_device()

dloader = DataLoader(
dset,
pin_memory=False,
Expand All @@ -608,15 +621,15 @@ def get_single_file_mmse(
dset.set_img_sz(tile_size, grid_size)

model.eval()
model.cuda()
model.to(device)
tile_mmse = []
tile_stds = []
logvar_arr = []
with torch.no_grad():
for batch in tqdm(dloader, desc="Predicting tiles"):
inp, tar = batch
inp = inp.cuda()
tar = tar.cuda()
inp = inp.to(device)
tar = tar.to(device)

rec_img_list = []
for _ in range(mmse_count):
Expand Down

0 comments on commit a496942

Please sign in to comment.