diff --git a/src/dot/fomm/modules/util.py b/src/dot/fomm/modules/util.py index 86ac63e..f0d7074 100644 --- a/src/dot/fomm/modules/util.py +++ b/src/dot/fomm/modules/util.py @@ -264,7 +264,8 @@ def __init__(self, channels, scale): # gaussian function of each dimension. kernel = 1 meshgrids = torch.meshgrid( - [torch.arange(size, dtype=torch.float32) for size in kernel_size] + [torch.arange(size, dtype=torch.float32) for size in kernel_size], + indexing="xy", ) for size, std, mgrid in zip(kernel_size, sigma, meshgrids): mean = (size - 1) / 2 diff --git a/src/dot/simswap/models/base_model.py b/src/dot/simswap/models/base_model.py index 18f3301..2ad552d 100644 --- a/src/dot/simswap/models/base_model.py +++ b/src/dot/simswap/models/base_model.py @@ -61,7 +61,7 @@ def load_network(self, network, network_label, epoch_label, save_dir=""): raise ("Generator must exist!") else: try: - network.load_state_dict(torch.load(save_path, strict=False)) + network.load_state_dict(torch.load(save_path), strict=False) except Exception as e: print(e) pretrained_dict = torch.load(save_path) diff --git a/src/dot/simswap/util/reverse2original.py b/src/dot/simswap/util/reverse2original.py index e081a77..fab546a 100644 --- a/src/dot/simswap/util/reverse2original.py +++ b/src/dot/simswap/util/reverse2original.py @@ -56,7 +56,9 @@ def __init__(self, kernel_size=15, threshold=0.6, iterations=1): # Create kernel y_indices, x_indices = torch.meshgrid( - torch.arange(0.0, kernel_size), torch.arange(0.0, kernel_size) + torch.arange(0.0, kernel_size), + torch.arange(0.0, kernel_size), + indexing="xy", ) dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2) kernel = dist.max() - dist