Skip to content

Commit

Permalink
add fix for inference
Browse files Browse the repository at this point in the history
  • Loading branch information
Jun Gao committed Oct 16, 2022
1 parent 18c35f1 commit 6fda9cd
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions training/inference_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,22 @@ def inference(
conv2d_gradfix.enabled = True # Improves training speed.
grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe.


common_kwargs = dict(
c_dim=0, img_resolution=1024, img_channels=3)
c_dim=0, img_resolution=training_set_kwargs['resolution'] if 'resolution' in training_set_kwargs else 1024, img_channels=3)
G_kwargs['device'] = device

G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(
device) # subclass of torch.nn.Module
D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(
device) # subclass of torch.nn.Module
# D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(
# device) # subclass of torch.nn.Module
G_ema = copy.deepcopy(G).eval() # deepcopy can make sure they are correct.
if resume_pretrain is not None and (rank == 0):
print('==> resume from pretrained path %s' % (resume_pretrain))
model_state_dict = torch.load(resume_pretrain, map_location=device)
G.load_state_dict(model_state_dict['G'], strict=True)
G_ema.load_state_dict(model_state_dict['G_ema'], strict=True)
D.load_state_dict(model_state_dict['D'], strict=True)
# D.load_state_dict(model_state_dict['D'], strict=True)
grid_size = (5, 5)
n_shape = grid_size[0] * grid_size[1]
grid_z = torch.randn([n_shape, G.z_dim], device=device).split(1) # random code for geometry
Expand Down

0 comments on commit 6fda9cd

Please sign in to comment.