diff --git a/README.md b/README.md
index 446026e3..307b0505 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,7 @@ LaMa generalizes surprisingly well to much higher resolutions (~2k❗️) than i
-- [Feature Refinement to Improve High Resolution Image Inpainting](https://arxiv.org/abs/2206.13644) / [video](https://www.youtube.com/watch?v=gEukhOheWgE) / code https://github.com/advimman/lama/pull/112 / by Geomagical Labs ([geomagical.com](geomagical.com))
+- [Feature Refinement to Improve High Resolution Image Inpainting](https://arxiv.org/abs/2206.13644) / [video](https://www.youtube.com/watch?v=gEukhOheWgE) / code https://github.com/advimman/lama/pull/112 / by Geomagical Labs ([geomagical.com](https://www.geomagical.com))
diff --git a/saicinpainting/evaluation/refinement.py b/saicinpainting/evaluation/refinement.py
index d9d3cbac..fa298bfd 100644
--- a/saicinpainting/evaluation/refinement.py
+++ b/saicinpainting/evaluation/refinement.py
@@ -83,6 +83,26 @@ def _l1_loss(
loss += torch.mean(torch.abs(pred_downscaled[mask_downscaled>=1e-8] - ref[mask_downscaled>=1e-8]))
return loss
+def feats_type_to_list(feats, feats_type):
+ """unpacks the tuple of features into a list"""
+ if feats_type == tuple:
+ feats = list(feats)
+ elif feats_type == torch.Tensor:
+ feats = [feats]
+ else:
+ raise NotImplementedError("Expected the output of forward_front to be a tuple or a tensor!")
+ return feats
+
+def list_to_feats_type(feats, feats_type):
+ """packs the list of features into the original feature type"""
+ if feats_type == tuple:
+ feats = tuple(feats)
+ elif feats_type == torch.Tensor:
+ feats = feats[0]
+ else:
+ raise NotImplementedError("Expected the output of forward_front to be a tuple or a tensor!")
+ return feats
+
def _infer(
image : torch.Tensor, mask : torch.Tensor,
forward_front : nn.Module, forward_rears : nn.Module,
@@ -125,27 +145,30 @@ def _infer(
if ref_lower_res is not None:
ref_lower_res = ref_lower_res.detach()
with torch.no_grad():
- z1,z2 = forward_front(masked_image)
+ z_feats = forward_front(masked_image)
+ z_feats_type = type(z_feats)
+ z_feats = feats_type_to_list(z_feats, z_feats_type)
# Inference
mask = mask.to(devices[-1])
ekernel = torch.from_numpy(cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(15,15)).astype(bool)).float()
ekernel = ekernel.to(devices[-1])
image = image.to(devices[-1])
- z1, z2 = z1.detach().to(devices[0]), z2.detach().to(devices[0])
- z1.requires_grad, z2.requires_grad = True, True
+ z_feats = [z_feat.detach().to(devices[0]) for z_feat in z_feats]
+ for z_feat in z_feats:
+ z_feat.requires_grad = True
- optimizer = Adam([z1,z2], lr=lr)
+ optimizer = Adam(z_feats, lr=lr)
pbar = tqdm(range(n_iters), leave=False)
for idi in pbar:
optimizer.zero_grad()
- input_feat = (z1,z2)
+ input_feat = list_to_feats_type(z_feats, z_feats_type)
for idd, forward_rear in enumerate(forward_rears):
output_feat = forward_rear(input_feat)
if idd < len(devices) - 1:
- midz1, midz2 = output_feat
- midz1, midz2 = midz1.to(devices[idd+1]), midz2.to(devices[idd+1])
- input_feat = (midz1, midz2)
+ mid_z_feats = feats_type_to_list(output_feat, z_feats_type)
+ mid_z_feats = [mid_z_feat.to(devices[idd+1]) for mid_z_feat in mid_z_feats]
+ input_feat = list_to_feats_type(mid_z_feats, z_feats_type)
else:
pred = output_feat