-
Notifications
You must be signed in to change notification settings - Fork 146
/
goggle_video.py
63 lines (54 loc) · 1.58 KB
/
goggle_video.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import numpy as np
import cv2
import sys
#!/usr/bin/python
import argparse
import os
from torchvision import datasets, transforms
import gym
import sys
import time
import time
import pygame
import pybullet as p
from gibson.core.render.profiler import Profiler
from gibson.learn.completion import CompletionNet
import cv2
import torch.nn as nn
import torch
from torch.autograd import Variable
from gibson import assets
# Directory of the installed `gibson.assets` package; load_model() reads the
# network weights from here.
assets_file_dir = os.path.dirname(assets.__file__)

# The input video path is the first command-line argument. Exit with a usage
# message instead of a bare IndexError when it is missing.
if len(sys.argv) < 2:
    sys.exit("usage: goggle_video.py <video_file>")
cap = cv2.VideoCapture(sys.argv[1])
def load_model():
    """Build the completion network, load its weights, and return it in eval mode.

    The network is wrapped in ``nn.DataParallel`` and moved to the GPU before
    ``unfiller_rgb.pth`` (looked up in ``assets_file_dir``) is loaded into it.
    The inner module is then unwrapped, switched to eval mode and returned.

    Requires a CUDA device, since the wrapped network is moved with ``.cuda()``.
    """
    weights_path = os.path.join(assets_file_dir, "unfiller_rgb.pth")

    # The state dict is loaded into the DataParallel wrapper, not the bare
    # network -- presumably because the checkpoint was saved from a wrapped
    # model and its keys carry the wrapper's prefix (verify against the file).
    wrapped = nn.DataParallel(CompletionNet(norm=nn.BatchNorm2d, nf=64)).cuda()
    wrapped.load_state_dict(torch.load(weights_path))

    net = wrapped.module
    net.eval()
    return net
# Load the network once; every frame below reuses it.
model = load_model()

# Reusable GPU input buffers: a 3-channel image and a 2-channel mask, 512x512.
# NOTE(review): `volatile=True` only disables autograd on PyTorch < 0.4; newer
# releases ignore it with a warning and expect `torch.no_grad()` instead. Kept
# as-is to stay compatible with the PyTorch version this file targets.
imgv = Variable(torch.zeros(1, 3, 512, 512), volatile=True).cuda()
maskv = Variable(torch.zeros(1, 2, 512, 512), volatile=True).cuda()

# ToTensor holds no per-frame state, so build it once instead of every frame.
to_tensor = transforms.ToTensor()

try:
    while cap.isOpened():
        ret, frame = cap.read()
        # Stop at end of stream or on a failed read.
        if not ret or frame is None:
            break

        # OpenCV frames are (rows, cols, channels), i.e. (height, width, ch).
        height, width = frame.shape[:2]

        # Swap the two spatial axes; the frame is now (width, height, ch).
        frame = frame.transpose(1, 0, 2)

        # Take a centred square before resizing so the picture is not
        # stretched. Crop along whichever axis is longer; the landscape branch
        # selects exactly the rows the original code did, and the portrait
        # branch avoids the negative start index the original produced.
        if width >= height:
            start = width // 2 - height // 2
            stop = width // 2 + height // 2
            frame = frame[start:stop, :]
        else:
            start = height // 2 - width // 2
            stop = height // 2 + width // 2
            frame = frame[:, start:stop]
        frame = cv2.resize(frame, (512, 512))

        # Convert the resized frame to a tensor and copy it into the
        # preallocated image buffer.
        source = to_tensor(frame)
        imgv.data.copy_(source)

        # Mask channel 0 is filled with 0.05 and channel 1 with 1. Refilled on
        # every frame, as before: whether the model modifies its mask input in
        # place is not visible from this file.
        maskv[:, 0, :, :].data.fill_(0.05)
        maskv[:, 1, :, :].data.fill_(1)

        recon = model(imgv, maskv)

        # Clamp to [0, 1], move to host, take the single batch item, reorder
        # to channels-last and scale to 8-bit for display.
        goggle_img = (recon.data.clamp(0, 1).cpu().numpy()[0].transpose(1, 2, 0) * 255).astype(np.uint8)
        cv2.imshow('frame', goggle_img)

        # Wait up to 16 ms for a key press; 'q' quits.
        if cv2.waitKey(16) & 0xFF == ord('q'):
            break
finally:
    # Release the capture and close the window even if a frame raised.
    cap.release()
    cv2.destroyAllWindows()