diff --git a/train.py b/train.py index 78e70d6..81b7888 100644 --- a/train.py +++ b/train.py @@ -169,7 +169,7 @@ def calculate_psnr(original, reconstructed): img = img.data.squeeze().cpu() ax = fig.add_subplot(spec[0]) - ax.imshow(img.permute(1, 2, 0), cmap="gray", vmin=0, vmax=1) + ax.imshow(img.squeeze(), cmap="gray", vmin=0, vmax=1) ax.axis("off") ax.set_title("Input") @@ -177,7 +177,7 @@ def calculate_psnr(original, reconstructed): img = img.data.squeeze().cpu() ax = fig.add_subplot(spec[1]) - ax.imshow(img.permute(1, 2, 0), cmap="gray", vmin=0, vmax=1) + ax.imshow(img.squeeze(), cmap="gray", vmin=0, vmax=1) ax.axis("off") ax.set_title("Reconstructed") @@ -185,7 +185,7 @@ def calculate_psnr(original, reconstructed): img = img.data.squeeze().cpu() ax = fig.add_subplot(spec[2]) - ax.imshow(img.permute(1, 2, 0), cmap="gray", vmin=0, vmax=1) + ax.imshow(img.squeeze(), cmap="gray", vmin=0, vmax=1) ax.axis("off") ax.set_title("Ground Truth")