Skip to content

Commit

Permalink
fix tt dataset loading
Browse files Browse the repository at this point in the history
  • Loading branch information
LumiOwO committed May 7, 2023
1 parent 969e5d5 commit 6b32a57
Show file tree
Hide file tree
Showing 13 changed files with 91 additions and 58 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ output*/
.vs
.vscode
__pycache__/
logs*/
logs*
wandb/
libtorch
3 changes: 2 additions & 1 deletion denoiser/configs/blender.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ preload = true
nx = 10
ny = 10

use_wandb = true
use_wandb = false
i_print = 1
i_save = 100
i_test = 100
save_image = false

lr = 0.0001
epochs = 2000
Expand Down
11 changes: 6 additions & 5 deletions denoiser/configs/tt.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@ data_dir = ./data/TanksAndTemple/Barn
dataset_type = tt
spp = 4
preload = false
nx = 1
ny = 1
nx = 10
ny = 10

use_wandb = false
use_wandb = true
i_print = 1
i_save = 100
i_test = 1
i_test = 100
save_image = false

lr = 0.0001
epochs = 2000
batch_size = 1
batch_size = 32

in_channels = 8
mid_channels = 32
Expand Down
60 changes: 38 additions & 22 deletions denoiser/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import gc
import json
import imageio
import numpy as np
Expand All @@ -9,19 +10,25 @@
from torch.utils.data import Dataset, DataLoader

class DenoiserDatasetSplit(Dataset):
def __init__(self, aux_buffer, imgs_in, imgs_gt, device):
def __init__(self, aux_buffer, imgs_in, imgs_gt, device, istorch):
self.aux_buffer = aux_buffer
self.imgs_in = imgs_in
self.imgs_gt = imgs_gt
self.device = device
self.istorch = istorch

def __len__(self):
return len(self.aux_buffer)

def __getitem__(self, idx):
return self.aux_buffer[idx].to(self.device), \
self.imgs_in[idx].to(self.device), \
self.imgs_gt[idx].to(self.device)
if self.istorch:
return self.aux_buffer[idx].to(self.device), \
self.imgs_in[idx].to(self.device), \
self.imgs_gt[idx].to(self.device)
else:
return torch.from_numpy(self.aux_buffer[idx]).to(self.device), \
torch.from_numpy(self.imgs_in[idx]).to(self.device), \
torch.from_numpy(self.imgs_gt[idx]).to(self.device)

class DenoiserDataset():
def __init__(self, args, device=None):
Expand All @@ -33,24 +40,30 @@ def __init__(self, args, device=None):

aux_buffer, imgs_in, imgs_gt = self.load_images(args)

tqdm.write("From numpy to torch...")
for s in aux_buffer.keys():
self.aux_buffer[s] = torch.stack([
torch.from_numpy(x[:args.in_channels, ...]) # [C, H, W]
for x in aux_buffer[s]])
self.imgs_in[s] = torch.stack([
torch.from_numpy(x) # [H, W, 4]
for x in imgs_in[s]])
self.imgs_gt[s] = torch.stack([
torch.from_numpy(x) # [H, W, 4]
for x in imgs_gt[s]])

if args.preload:
tqdm.write("Moving to cuda...")
if self.istorch:
tqdm.write("From numpy to torch...")
for s in aux_buffer.keys():
self.aux_buffer[s] = self.aux_buffer[s].to(device)
self.imgs_in[s] = self.imgs_in[s].to(device)
self.imgs_gt[s] = self.imgs_gt[s].to(device)
self.aux_buffer[s] = torch.stack([
torch.from_numpy(x[:args.in_channels, ...]) # [C, H, W]
for x in aux_buffer[s]])
self.imgs_in[s] = torch.stack([
torch.from_numpy(x) # [H, W, 4]
for x in imgs_in[s]])
self.imgs_gt[s] = torch.stack([
torch.from_numpy(x) # [H, W, 4]
for x in imgs_gt[s]])

if args.preload:
tqdm.write("Moving to cuda...")
for s in aux_buffer.keys():
self.aux_buffer[s] = self.aux_buffer[s].to(device)
self.imgs_in[s] = self.imgs_in[s].to(device)
self.imgs_gt[s] = self.imgs_gt[s].to(device)
else:
self.aux_buffer = aux_buffer
self.imgs_in = imgs_in
self.imgs_gt = imgs_gt


def load_images(self, args):
raise NotImplementedError()
Expand Down Expand Up @@ -131,7 +144,8 @@ def valid_chunk(img_gt_chunk):

def dataloader(self, task):
dataset = DenoiserDatasetSplit(
self.aux_buffer[task], self.imgs_in[task], self.imgs_gt[task], self.device)
self.aux_buffer[task], self.imgs_in[task], self.imgs_gt[task],
self.device, self.istorch)
loader = DataLoader(dataset,
shuffle=(task == "train"),
batch_size=(self.args.batch_size if task == "train" else 1),
Expand All @@ -141,6 +155,7 @@ def dataloader(self, task):

class BlenderDataset(DenoiserDataset):
def load_images(self, args):
self.istorch = True
aux_buffers = {}
imgs_in = {}
imgs_gt = {}
Expand Down Expand Up @@ -190,6 +205,7 @@ def load_images(self, args):

class TanksAndTemplesDataset(DenoiserDataset):
def load_images(self, args):
self.istorch = False
aux_buffers = {}
imgs_in = {}
imgs_gt = {}
Expand Down
2 changes: 2 additions & 0 deletions denoiser/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def main(args):
help="frequency of weight ckpt saving")
parser.add_argument("--i_test", type=int, default=50000,
help="frequency of testset saving")
parser.add_argument("--save_image", action="store_true",
help="save test images")

# training options
parser.add_argument("--in_channels", type=int, default=8,
Expand Down
4 changes: 2 additions & 2 deletions denoiser/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def compact_and_compile(model: GuidanceNet, device=None):
B, C, H, W = 1, 8, 800, 800
aux_buffer = torch.rand((B, C, H, W)).to(device)

profile = True
profile = False
if profile:
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
with torch.no_grad():
Expand All @@ -203,7 +203,7 @@ def cast_and_forward(aux_buffer):
with torch.no_grad():
guidance_net_ts = torch.jit.trace(cast_and_forward, (aux_buffer))
# guidance_net_ts = torch.jit.optimize_for_inference(guidance_net_ts)
print(guidance_net_ts.code)
# print(guidance_net_ts.code)

if profile:
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
Expand Down
11 changes: 7 additions & 4 deletions denoiser/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,10 @@ def test(self, model, load_ckpt=True, save_dirname="test"):

def test_one_epoch(self, model, dataloader, save_dirname):
# batch_size == 1 when testing
save_dir = os.path.join(self.args.work_dir, save_dirname)
os.makedirs(save_dir, exist_ok=True)

if self.args.save_image:
save_dir = os.path.join(self.args.work_dir, save_dirname)
os.makedirs(save_dir, exist_ok=True)

# Test full model
for m in self.metrics:
Expand Down Expand Up @@ -164,8 +166,9 @@ def test_one_epoch(self, model, dataloader, save_dirname):
for m in self.metrics:
m.measure(img_out[..., :3], img_gt[..., :3])

img_out[..., -1:] = 1
self.logger.log_image(img_out, save_dir, "r", batch_idx, {"epoch": self.epoch})
if self.args.save_image:
img_out[..., -1:] = 1
self.logger.log_image(img_out, save_dir, "r", batch_idx, {"epoch": self.epoch})

avg_loss = avg_loss / len(dataloader)
compact_model_log = {
Expand Down
10 changes: 5 additions & 5 deletions render.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ set -e
# export TREE=../data/nerf_synthetic/lego/tree.npz
# export POSES=../data/nerf_synthetic/lego/transforms_test.json
# export OUT_DIR=../data/nerf_synthetic/lego/spp_4/test
# export TS_MODULE=../logs/lego/rural-armadillo-2/ts_latest.ts
# export TS_MODULE=../logs/ts_latest.ts
# export OPTIONS=../renderer/options/blender.json

export DATASET=tt
export TREE=../data/TanksAndTemple/Family/tree.npz
export POSES=../data/TanksAndTemple/Family
export OUT_DIR=../data/TanksAndTemple/Family/spp_4
export TS_MODULE=../logs/lego/rural-armadillo-2/ts_latest.ts
export TREE=../data/TanksAndTemple/Barn/tree.npz
export POSES=../data/TanksAndTemple/Barn
export OUT_DIR=../data/TanksAndTemple/Barn/spp_4
export TS_MODULE=../logs/ts_latest.ts
export OPTIONS=../renderer/options/blender.json

export FPS_ONLY=false
Expand Down
2 changes: 1 addition & 1 deletion renderer/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cmake_minimum_required( VERSION 3.3 )
cmake_minimum_required( VERSION 3.18 )

if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
Expand Down
31 changes: 15 additions & 16 deletions renderer/main_headless.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,9 +352,21 @@ int main(int argc, char *argv[])
if (!out_dir.size()) {
continue;
}

// write image
{

if (args["write_buffer"].as<bool>()) {
// write auxiliary buffer
const size_t SIZE =
sizeof(float) * RenderContext::CHANNELS * width * height;
cuda(Memcpy(
buf.data(), ctx.aux_buffer, SIZE, cudaMemcpyDeviceToHost));

auto outfile = std::ofstream(
out_dir + "/buf_" + basenames[i] + ".bin",
std::ios::out | std::ios::binary);
outfile.write((char *)buf.data(), SIZE);
outfile.close();
} else {
// write image
cuda(Memcpy2DFromArray(
buf.data(),
sizeof(float4) * width,
Expand All @@ -371,20 +383,7 @@ int main(int argc, char *argv[])
std::string fpath = out_dir + "/" + basenames[i] + ".png";
internal::write_png_file(fpath, buf_uint8.data(), width, height);
}

// write auxiliary buffer
if (args["write_buffer"].as<bool>()) {
const size_t SIZE =
sizeof(float) * RenderContext::CHANNELS * width * height;
cuda(Memcpy(
buf.data(), ctx.aux_buffer, SIZE, cudaMemcpyDeviceToHost));

auto outfile = std::ofstream(
out_dir + "/buf_" + basenames[i] + ".bin",
std::ios::out | std::ios::binary);
outfile.write((char *)buf.data(), SIZE);
outfile.close();
}
}
cudaEventRecord(stop, stream);
cudaEventSynchronize(stop);
Expand Down
1 change: 0 additions & 1 deletion renderer/options/blender.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
0,
24
],
"delta_tracking": true,
"denoise": false,
"spp": 4,
"enable_probe": false,
Expand Down
12 changes: 12 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
--extra-index-url https://download.pytorch.org/whl/cu113
torch==1.11.0+cu113
torchaudio==0.11.0+cu113
torchvision==0.12.0+cu113
cmake==3.22.1
pytorch-msssim
wandb
ninja
lpips
imageio
tqdm
configargparse
Empty file modified train.sh
100644 → 100755
Empty file.

0 comments on commit 6b32a57

Please sign in to comment.