Skip to content

Commit

Permalink
solving warnings and minor errors (#114)
Browse files Browse the repository at this point in the history
* solving warnings and minor errors

* Reformat files

---------

Co-authored-by: Ghassen-Chaabouni <[email protected]>
  • Loading branch information
giorgiop and Ghassen-Chaabouni authored Sep 18, 2023
1 parent e13d4b0 commit 2bb1ca6
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/dot/fomm/modules/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,8 @@ def __init__(self, channels, scale):
# gaussian function of each dimension.
kernel = 1
meshgrids = torch.meshgrid(
[torch.arange(size, dtype=torch.float32) for size in kernel_size]
[torch.arange(size, dtype=torch.float32) for size in kernel_size],
indexing="xy",
)
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
Expand Down
2 changes: 1 addition & 1 deletion src/dot/simswap/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def load_network(self, network, network_label, epoch_label, save_dir=""):
raise ("Generator must exist!")
else:
try:
network.load_state_dict(torch.load(save_path, strict=False))
network.load_state_dict(torch.load(save_path), strict=False)
except Exception as e:
print(e)
pretrained_dict = torch.load(save_path)
Expand Down
4 changes: 3 additions & 1 deletion src/dot/simswap/util/reverse2original.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def __init__(self, kernel_size=15, threshold=0.6, iterations=1):

# Create kernel
y_indices, x_indices = torch.meshgrid(
torch.arange(0.0, kernel_size), torch.arange(0.0, kernel_size)
torch.arange(0.0, kernel_size),
torch.arange(0.0, kernel_size),
indexing="xy",
)
dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
kernel = dist.max() - dist
Expand Down

0 comments on commit 2bb1ca6

Please sign in to comment.