Skip to content

Commit

Permalink
(#19) Save latent (encoded) tensor.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexandru-dinu committed Jul 30, 2021
1 parent 7530464 commit adb17d1
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 34 deletions.
1 change: 1 addition & 0 deletions configs/test.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
checkpoint: /home/alex/workspace/ml-data/cae/model_yt_small_final.state
exp_name: testing
batch_size: 1
batch_every: 100
shuffle: false
dataset_path: /home/alex/workspace/ml-data/cae/datasets/testing
Expand Down
1 change: 1 addition & 0 deletions src/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __getitem__(self, index: int) -> Tuple[T.Tensor, np.ndarray, str]:
img = np.transpose(img, (2, 0, 1))
img = T.from_numpy(img).float()

# channels x 6 x 10 x 128 x 128 (6x10 128x128 patches)
patches = np.reshape(img, (3, 6, 128, 10, 128))
patches = np.transpose(patches, (0, 1, 3, 2, 4))

Expand Down
51 changes: 30 additions & 21 deletions src/test.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
import os
import yaml
import argparse
import os
from pathlib import Path

import numpy as np
import torch as T
import torch.nn as nn
import yaml
from torch.utils.data import DataLoader

from cae import CAE
from data_loader import ImageFolder720p
from utils import save_imgs

from namespace import Namespace
from logger import Logger

from models.cae_32x32x32_zero_pad_bin import CAE
from namespace import Namespace
from utils import save_imgs

ROOT_EXP_DIR = Path(__file__).resolve().parents[1] / "experiments"

Expand All @@ -38,7 +36,9 @@ def test(cfg: Namespace) -> None:
logger.info(f"[model={cfg.checkpoint}] on {cfg.device}")

dataloader = DataLoader(
dataset=ImageFolder720p(cfg.dataset_path), batch_size=1, shuffle=cfg.shuffle
dataset=ImageFolder720p(cfg.dataset_path),
batch_size=cfg.batch_size,
shuffle=cfg.shuffle,
)
logger.info(f"[dataset={cfg.dataset_path}]")

Expand All @@ -49,33 +49,42 @@ def test(cfg: Namespace) -> None:
if cfg.device == "cuda":
patches = patches.cuda()

if batch_idx % cfg.batch_every == 0:
pass

out = T.zeros(6, 10, 3, 128, 128)
"""
img: 6 x 10 x 3 x 128 x 128
latent: 6 x 10 x 32 x 32 x 32
"""
reconstructed_img = T.zeros(6, 10, 3, 128, 128)
encoded_data = T.zeros(6, 10, *model.encoded_shape)
avg_loss = 0

for i in range(6):
for j in range(10):
x = patches[:, :, i, j, :, :].cuda()
y = model(x)
out[i, j] = y.data

loss = loss_criterion(y, x)
y_enc = model.encode(x)
y_dec = model.decode(y_enc)

encoded_data[i, j] = y_enc.data
reconstructed_img[i, j] = y_dec.data

loss = loss_criterion(y_dec, x)
avg_loss += (1 / 60) * loss.item()

logger.debug("[%5d/%5d] avg_loss: %f", batch_idx, len(dataloader), avg_loss)

# save output
out = np.transpose(out, (0, 3, 1, 4, 2))
out = np.reshape(out, (768, 1280, 3))
out = np.transpose(out, (2, 0, 1))
reconstructed_img = np.transpose(reconstructed_img, (0, 3, 1, 4, 2))
reconstructed_img = np.reshape(reconstructed_img, (768, 1280, 3))
reconstructed_img = np.transpose(reconstructed_img, (2, 0, 1))

# TODO: make custom file-type (header, packing etc.)
T.save(encoded_data, exp_dir / f"out/enc_{batch_idx}.pt")

y = T.cat((img[0], out), dim=2)
both = T.cat((img[0], reconstructed_img), dim=2)
save_imgs(
imgs=y.unsqueeze(0),
imgs=both.unsqueeze(0),
to_size=(3, 768, 2 * 1280),
name=exp_dir / f"out/test_{batch_idx}.png",
path=exp_dir / f"out/test_{batch_idx}.png",
)


Expand Down
14 changes: 6 additions & 8 deletions src/train.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
import os
import yaml
import argparse
import os
from pathlib import Path

import numpy as np
import torch as T
import torch.nn as nn
import torch.optim as optim
import yaml
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from cae import CAE
from data_loader import ImageFolder720p
from utils import save_imgs

from namespace import Namespace
from logger import Logger

from models.cae_32x32x32_zero_pad_bin import CAE
from namespace import Namespace
from utils import save_imgs

logger = Logger(__name__, colorize=True)

Expand Down Expand Up @@ -122,7 +120,7 @@ def train(cfg: Namespace) -> None:
save_imgs(
imgs=y,
to_size=(3, 768, 2 * 1280),
name=exp_dir / f"out/{epoch_idx}_{batch_idx}.png",
path=exp_dir / f"out/{epoch_idx}_{batch_idx}.png",
)
# -- end save every
# -- end batches
Expand Down
13 changes: 8 additions & 5 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import struct
from os import PathLike
from typing import Tuple, Union

import numpy as np
from torchvision.utils import save_image
import torch as T
import torchvision


def save_imgs(imgs, to_size, name) -> None:
def save_imgs(imgs: T.Tensor, to_size: Tuple[int], path: Union[str, PathLike]) -> None:
# x = np.array(x)
# x = np.transpose(x, (1, 2, 0)) * 255
# x = x.astype(np.uint8)
Expand All @@ -12,9 +16,8 @@ def save_imgs(imgs, to_size, name) -> None:
# x = 0.5 * (x + 1)

# to_size = (C, H, W)
imgs = imgs.clamp(0, 1)
imgs = imgs.view(imgs.size(0), *to_size)
save_image(imgs, name)
imgs = imgs.clamp(0, 1).view(imgs.size(0), *to_size)
torchvision.utils.save_image(imgs, path)


def save_encoded(enc: np.ndarray, fname: str) -> None:
Expand Down

0 comments on commit adb17d1

Please sign in to comment.