diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py
index fafdec0b..6f5fb935 100755
--- a/models/pix2pixHD_model.py
+++ b/models/pix2pixHD_model.py
@@ -108,7 +108,7 @@ def initialize(self, opt):
             params = list(self.netD.parameters())
             self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
 
-    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
+    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
         if self.opt.label_nc == 0:
             input_label = label_map.data.cuda()
         else:
@@ -140,6 +140,40 @@ def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None,
                 inst_map = label_map.cuda()
 
         return input_label, inst_map, real_image, feat_map
+
+    def cpu_encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
+        """Build the generator inputs on the CPU (counterpart of the CUDA-based encode_input)."""
+        if self.opt.label_nc == 0:
+            input_label = label_map.data.cpu()
+        else:
+            # create one-hot vector for label map
+            size = label_map.size()
+            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
+            input_label = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
+            input_label = input_label.scatter_(1, label_map.data.long().cpu(), 1.0)
+            if self.opt.data_type == 16:
+                input_label = input_label.half()
+
+        # get edges from instance map
+        if not self.opt.no_instance:
+            inst_map = inst_map.data.cpu()
+            edge_map = self.get_edges(inst_map)
+            input_label = torch.cat((input_label, edge_map), dim=1)
+        input_label = Variable(input_label, volatile=infer)
+
+        # real images for training
+        if real_image is not None:
+            real_image = Variable(real_image.data.cpu())
+
+        # instance map for feature encoding
+        if self.use_features:
+            # get precomputed feature maps
+            if self.opt.load_features:
+                feat_map = Variable(feat_map.data.cpu())
+            if self.opt.label_feat:
+                inst_map = label_map.cpu()
+
+        return input_label, inst_map, real_image, feat_map
 
     def discriminate(self, input_label, test_image, use_pool=False):
         input_concat = torch.cat((input_label, test_image.detach()), dim=1)
@@ -216,6 +250,31 @@ def inference(self, label, inst, image=None):
             fake_image = self.netG.forward(input_concat)
         return fake_image
 
+    def cpu_inference(self, label, inst, image=None):
+        """Run the generator on CPU-encoded inputs and return the generated image."""
+        # Encode Inputs
+        image = Variable(image) if image is not None else None
+        input_label, inst_map, real_image, _ = self.cpu_encode_input(Variable(label), Variable(inst), image, infer=True)
+
+        # Fake Generation
+        if self.use_features:
+            if self.opt.use_encoded_image:
+                # encode the real image to get feature map
+                feat_map = self.netE.forward(real_image, inst_map)
+            else:
+                # sample clusters from precomputed features
+                feat_map = self.sample_features(inst_map)
+            input_concat = torch.cat((input_label, feat_map), dim=1)
+        else:
+            input_concat = input_label
+
+        if torch.__version__.startswith('0.4'):
+            with torch.no_grad():
+                fake_image = self.netG.forward(input_concat)
+        else:
+            fake_image = self.netG.forward(input_concat)
+        return fake_image
+
     def sample_features(self, inst):
         # read precomputed feature clusters
         cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
@@ -261,6 +320,7 @@ def encode_features(self, image, inst):
 
     def get_edges(self, t):
-        edge = torch.cuda.ByteTensor(t.size()).zero_()
+        # allocate the boolean mask on t's own device so the CPU-only path never touches CUDA
+        edge = torch.zeros(t.size(), dtype=torch.bool, device=t.device)
         edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])
         edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
         edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
diff --git a/test.py b/test.py
index e0b1ec33..4724d76d 100755
--- a/test.py
+++ b/test.py
@@ -55,8 +55,12 @@
         generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
     elif opt.onnx:
         generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
-    else:
-        generated = model.inference(data['label'], data['inst'], data['image'])
+    else:
+        # an empty gpu_ids list selects the CPU path
+        if not opt.gpu_ids:
+            generated = model.cpu_inference(data['label'], data['inst'], data['image'])
+        else:
+            generated = model.inference(data['label'], data['inst'], data['image'])
 
     visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                            ('synthesized_image', util.tensor2im(generated.data[0]))])