diff --git a/README.md b/README.md index 085e16b..06b7b27 100644 --- a/README.md +++ b/README.md @@ -53,8 +53,13 @@ Download and run the dot executable for your OS: - Download `dot.zip` from [here](https://drive.google.com/file/d/10fdnSszaEbpGdCKxxeBFXQrkagxe3-RT/view), unzip it and then run `dot.exe` - Ubuntu: - ToDo -- Mac: - - ToDo +- Mac (Tested on Apple M2 Sonoma 14.0): + + - Download `dot-executable.app` from [here](https://drive.google.com/drive/folders/1n22mvWSFmXbSAspZWp5sChbCqXgOEKGx?usp=drive_link) + - Run `dot-executable.app` + - In case of camera reading error: + - Right click and choose `Show Package Contents` + - Execute `dot-executable` from `Contents/MacOS` folder #### GUI Usage @@ -118,6 +123,15 @@ Install the `torch` and `torchvision` dependencies based on the CUDA version ins To check that `torch` and `torchvision` are installed correctly, run the following command: `python -c "import torch; print(torch.cuda.is_available())"`. If the output is `True`, the dependencies are installed with CUDA support. +###### With MPS Support(Apple Silicon) + +```bash +conda env create -f envs/environment-apple-m2.yaml +conda activate dot +``` + +To check that `torch` and `torchvision` are installed correctly, run the following command: `python -c "import torch; print(torch.backends.mps.is_available())"`. If the output is `True`, the dependencies are installed with Metal programming framework support. + ###### With CPU Support (slow, not recommended) ```bash diff --git a/configs/faceswap_cv2.yaml b/configs/faceswap_cv2.yaml index f350097..7ea6230 100644 --- a/configs/faceswap_cv2.yaml +++ b/configs/faceswap_cv2.yaml @@ -1,3 +1,3 @@ --- swap_type: faceswap_cv2 -model_path: ./saved_models/faceswap_cv/shape_predictor_68_face_landmarks.dat +model_path: saved_models/faceswap_cv/shape_predictor_68_face_landmarks.dat diff --git a/configs/fomm.yaml b/configs/fomm.yaml index f9c8530..ce66d6e 100644 --- a/configs/fomm.yaml +++ b/configs/fomm.yaml @@ -1,4 +1,4 @@ --- swap_type: fomm -model_path: ./saved_models/fomm/vox-adv-cpk.pth.tar +model_path: saved_models/fomm/vox-adv-cpk.pth.tar head_pose: true diff --git a/configs/simswap.yaml b/configs/simswap.yaml index db1bf89..2f2566c 100644 --- a/configs/simswap.yaml +++ b/configs/simswap.yaml @@ -1,5 +1,5 @@ --- swap_type: simswap -parsing_model_path: ./saved_models/simswap/parsing_model/checkpoint/79999_iter.pth -arcface_model_path: ./saved_models/simswap/arcface_model/arcface_checkpoint.tar -checkpoints_dir: ./saved_models/simswap/checkpoints +parsing_model_path: saved_models/simswap/parsing_model/checkpoint/79999_iter.pth +arcface_model_path: saved_models/simswap/arcface_model/arcface_checkpoint.tar +checkpoints_dir: saved_models/simswap/checkpoints diff --git a/configs/simswaphq.yaml b/configs/simswaphq.yaml index 8c523e2..714f31e 100644 --- a/configs/simswaphq.yaml +++ b/configs/simswaphq.yaml @@ -1,6 +1,6 @@ --- swap_type: simswap -parsing_model_path: ./saved_models/simswap/parsing_model/checkpoint/79999_iter.pth -arcface_model_path: ./saved_models/simswap/arcface_model/arcface_checkpoint.tar -checkpoints_dir: ./saved_models/simswap/checkpoints +parsing_model_path: saved_models/simswap/parsing_model/checkpoint/79999_iter.pth +arcface_model_path: saved_models/simswap/arcface_model/arcface_checkpoint.tar +checkpoints_dir: saved_models/simswap/checkpoints crop_size: 512 diff --git a/envs/environment-apple-m2.yaml b/envs/environment-apple-m2.yaml new file mode 100644 index 0000000..c72c291 --- /dev/null +++ b/envs/environment-apple-m2.yaml @@ -0,0 +1,10 @@ +--- +name: dot +channels: + - conda-forge + - defaults +dependencies: + - python=3.8 + - pip=21.3 + - pip: + - -r ../requirements-apple-m2.txt diff --git a/requirements-apple-m2.txt b/requirements-apple-m2.txt new file mode 100644 index 0000000..79446bb --- /dev/null +++ b/requirements-apple-m2.txt @@ -0,0 +1,134 @@ +# +# This file is autogenerated by pip-compile with python 3.8 +# To update, run: +# +# pip-compile setup.cfg +# +absl-py==1.1.0 + # via mediapipe +attrs==21.4.0 + # via mediapipe +certifi==2022.12.7 + # via requests +chardet==4.0.0 + # via requests +click==8.0.2 + # via dot (setup.cfg) +cycler==0.11.0 + # via matplotlib +dlib==19.19.0 + # via dot (setup.cfg) +face-alignment==1.3.3 + # via dot (setup.cfg) +flatbuffers==2.0 + # via onnxruntime +fonttools==4.34.4 + # via matplotlib +idna==2.10 + # via requests +imageio==2.19.3 + # via scikit-image +kiwisolver==1.4.3 + # via matplotlib +kornia==0.6.5 + # via dot (setup.cfg) +llvmlite==0.38.1 + # via numba +matplotlib==3.5.2 + # via mediapipe +mediapipe-silicon + # via dot (setup.cfg) +mediapipe==0.10.3 +networkx==2.8.4 + # via scikit-image +numba==0.55.2 + # via face-alignment +numpy==1.22.0 + # via + # dot (setup.cfg) + # face-alignment + # imageio + # matplotlib + # mediapipe + # numba + # onnxruntime + # opencv-contrib-python + # opencv-python + # pywavelets + # scikit-image + # scipy + # tifffile + # torchvision +onnxruntime==1.15.1 + # via dot (setup.cfg) +opencv-contrib-python==4.5.5.62 + # via + # dot (setup.cfg) + # mediapipe +opencv-python==4.5.5.62 + # via + # dot (setup.cfg) + # face-alignment +packaging==21.3 + # via + # kornia + # matplotlib + # scikit-image +pillow==9.3.0 + # via + # dot (setup.cfg) + # imageio + # matplotlib + # scikit-image + # torchvision +protobuf==3.20.2 + # via + # dot (setup.cfg) + # mediapipe + # onnxruntime +pyparsing==3.0.9 + # via + # matplotlib + # packaging +python-dateutil==2.8.2 + # via matplotlib +pywavelets==1.3.0 + # via scikit-image +pyyaml==5.4.1 + # via dot (setup.cfg) +requests==2.25.1 + # via dot (setup.cfg) +scikit-image==0.19.1 + # via + # dot (setup.cfg) + # face-alignment +scipy==1.10.1 + # via + # dot (setup.cfg) + # face-alignment + # scikit-image +six==1.16.0 + # via + # mediapipe + # python-dateutil +tifffile==2022.5.4 + # via scikit-image +torch==2.0.1 + # via + # dot (setup.cfg) + # face-alignment + # kornia + # torchvision +torchvision==0.15.2 + # via dot (setup.cfg) +tqdm==4.64.0 + # via face-alignment +typing-extensions==4.3.0 + # via torch +urllib3==1.26.10 + # via requests +wheel==0.38.1 + # via mediapipe + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/src/dot/__main__.py b/src/dot/__main__.py index 5c1e47f..b12c243 100644 --- a/src/dot/__main__.py +++ b/src/dot/__main__.py @@ -21,7 +21,7 @@ def run( arcface_model_path: str = None, checkpoints_dir: str = None, gpen_type: str = None, - gpen_path: str = "./saved_models/gpen", + gpen_path: str = "saved_models/gpen", crop_size: int = 224, head_pose: bool = False, save_folder: str = None, @@ -42,7 +42,7 @@ def run( arcface_model_path (str, optional): The path to the arcface model. Defaults to None. checkpoints_dir (str, optional): The path to the checkpoints directory. Defaults to None. gpen_type (str, optional): The type of gpen model to use. Defaults to None. - gpen_path (str, optional): The path to the gpen models. Defaults to "./saved_models/gpen". + gpen_path (str, optional): The path to the gpen models. Defaults to "saved_models/gpen". crop_size (int, optional): The size to crop the images to. Defaults to 224. save_folder (str, optional): The path to the save folder. Defaults to None. show_fps (bool, optional): Pass flag to show fps value. Defaults to False. @@ -130,7 +130,7 @@ def run( @click.option( "--gpen_path", "gpen_path", - default="./saved_models/gpen", + default="saved_models/gpen", help="Path to gpen models.", ) @click.option("--crop_size", "crop_size", type=int, default=224) @@ -185,7 +185,7 @@ def main( arcface_model_path: str = None, checkpoints_dir: str = None, gpen_type: str = None, - gpen_path: str = "./saved_models/gpen", + gpen_path: str = "saved_models/gpen", crop_size: int = 224, save_folder: str = None, show_fps: bool = False, diff --git a/src/dot/commons/model_option.py b/src/dot/commons/model_option.py index 86eede8..653198a 100644 --- a/src/dot/commons/model_option.py +++ b/src/dot/commons/model_option.py @@ -21,7 +21,7 @@ class ModelOption(ABC): def __init__( self, gpen_type=None, - gpen_path="./saved_models/gpen", + gpen_path="saved_models/gpen", use_gpu=True, crop_size=256, ): diff --git a/src/dot/faceswap_cv2/swap.py b/src/dot/faceswap_cv2/swap.py index 342c1dd..25d79af 100644 --- a/src/dot/faceswap_cv2/swap.py +++ b/src/dot/faceswap_cv2/swap.py @@ -17,9 +17,7 @@ ) # define globals -CACHED_PREDICTOR_PATH = ( - "./saved_models/faceswap_cv/shape_predictor_68_face_landmarks.dat" -) +CACHED_PREDICTOR_PATH = "saved_models/faceswap_cv/shape_predictor_68_face_landmarks.dat" class Swap: diff --git a/src/dot/gpen/face_model/face_gan.py b/src/dot/gpen/face_model/face_gan.py index 165db66..23993ac 100644 --- a/src/dot/gpen/face_model/face_gan.py +++ b/src/dot/gpen/face_model/face_gan.py @@ -28,6 +28,11 @@ def __init__( self.n_mlp = 8 self.is_norm = is_norm self.resolution = size + self.device = ( + ("mps" if torch.backends.mps.is_available() else "cuda") + if use_gpu + else "cpu" + ) self.load_model( channel_multiplier=channel_multiplier, narrow=narrow, use_gpu=use_gpu ) @@ -36,8 +41,8 @@ def load_model(self, channel_multiplier=2, narrow=1, use_gpu=True): if use_gpu: self.model = FullGenerator( self.resolution, 512, self.n_mlp, channel_multiplier, narrow=narrow - ).cuda() - pretrained_dict = torch.load(self.mfile) + ).to(self.device) + pretrained_dict = torch.load(self.mfile, map_location=self.device) else: self.model = FullGenerator( self.resolution, 512, self.n_mlp, channel_multiplier, narrow=narrow @@ -60,7 +65,7 @@ def process(self, img, use_gpu=True): def img2tensor(self, img, use_gpu=True): if use_gpu: - img_t = torch.from_numpy(img).cuda() / 255.0 + img_t = torch.from_numpy(img).to(self.device) / 255.0 else: img_t = torch.from_numpy(img).cpu() / 255.0 if self.is_norm: diff --git a/src/dot/gpen/retinaface/layers/modules/multibox_loss.py b/src/dot/gpen/retinaface/layers/modules/multibox_loss.py index c0f0e81..0c0c69b 100644 --- a/src/dot/gpen/retinaface/layers/modules/multibox_loss.py +++ b/src/dot/gpen/retinaface/layers/modules/multibox_loss.py @@ -94,12 +94,14 @@ def forward(self, predictions, priors, targets): landm_t, idx, ) + device = "cpu" if GPU: - loc_t = loc_t.cuda() - conf_t = conf_t.cuda() - landm_t = landm_t.cuda() + device = "mps" if torch.backends.mps.is_available() else "cuda" + loc_t = loc_t.to(device) + conf_t = conf_t.to(device) + landm_t = landm_t.to(device) - zeros = torch.tensor(0).cuda() + zeros = torch.tensor(0).to(device) # landm Loss (Smooth L1) # Shape: [batch,num_priors,10] pos1 = conf_t > zeros diff --git a/src/dot/gpen/retinaface/retinaface_detection.py b/src/dot/gpen/retinaface/retinaface_detection.py index e10d6c3..a371003 100644 --- a/src/dot/gpen/retinaface/retinaface_detection.py +++ b/src/dot/gpen/retinaface/retinaface_detection.py @@ -23,14 +23,14 @@ def __init__(self, base_dir, network="RetinaFace-R50", use_gpu=True): cudnn.benchmark = True self.pretrained_path = os.path.join(base_dir, "weights", network + ".pth") if use_gpu: - self.device = torch.cuda.current_device() + self.device = "mps" if torch.backends.mps.is_available() else "cuda" else: self.device = "cpu" self.cfg = cfg_re50 self.net = RetinaFace(cfg=self.cfg, phase="test") if use_gpu: self.load_model() - self.net = self.net.cuda() + self.net = self.net.to(self.device) else: self.load_model(load_to_cpu=True) self.net = self.net.cpu() @@ -57,9 +57,10 @@ def load_model(self, load_to_cpu=False): self.pretrained_path, map_location=lambda storage, loc: storage ) else: - pretrained_dict = torch.load( - self.pretrained_path, map_location=lambda storage, loc: storage.cuda() - ) + # pretrained_dict = torch.load( + # self.pretrained_path, map_location=lambda storage, loc: storage.to("mps")#.cuda() + # ) + pretrained_dict = torch.load(self.pretrained_path, map_location=self.device) if "state_dict" in pretrained_dict.keys(): pretrained_dict = self.remove_prefix( pretrained_dict["state_dict"], "module." @@ -89,8 +90,8 @@ def detect( img = img.transpose(2, 0, 1) img = torch.from_numpy(img).unsqueeze(0) if use_gpu: - img = img.cuda() - scale = scale.cuda() + img = img.to(self.device) + scale = scale.to(self.device) else: img = img.cpu() scale = scale.cpu() @@ -100,7 +101,7 @@ def detect( priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) priors = priorbox.forward() if use_gpu: - priors = priors.cuda() + priors = priors.to(self.device) else: priors = priors.cpu() @@ -125,7 +126,7 @@ def detect( ] ) if use_gpu: - scale1 = scale1.cuda() + scale1 = scale1.to(self.device) else: scale1 = scale1.cpu() diff --git a/src/dot/simswap/configs/config.yaml b/src/dot/simswap/configs/config.yaml index b301839..210c72e 100644 --- a/src/dot/simswap/configs/config.yaml +++ b/src/dot/simswap/configs/config.yaml @@ -1,9 +1,9 @@ --- analysis: simswap: - parsing_model_path: ./saved_models/simswap/parsing_model/checkpoint/79999_iter.pth - checkpoints_dir: ./saved_models/simswap/checkpoints - arcface_model_path: ./saved_models/simswap/arcface_model/arcface_checkpoint.tar + parsing_model_path: saved_models/simswap/parsing_model/checkpoint/79999_iter.pth + checkpoints_dir: saved_models/simswap/checkpoints + arcface_model_path: saved_models/simswap/arcface_model/arcface_checkpoint.tar detection_threshold: 0.6 det_size: [640, 640] use_gpu: true @@ -19,4 +19,4 @@ analysis: opt_which_epoch: latest opt_continue_train: store_true gpen: gpen_256 - gpen_path: ./saved_models/gpen + gpen_path: saved_models/gpen diff --git a/src/dot/simswap/configs/config_512.yaml b/src/dot/simswap/configs/config_512.yaml index 92c874d..a1124e5 100644 --- a/src/dot/simswap/configs/config_512.yaml +++ b/src/dot/simswap/configs/config_512.yaml @@ -1,9 +1,9 @@ --- analysis: simswap: - parsing_model_path: ./saved_models/simswap/parsing_model/checkpoint/79999_iter.pth - checkpoints_dir: ./saved_models/simswap/checkpoints - arcface_model_path: ./saved_models/simswap/arcface_model/arcface_checkpoint.tar + parsing_model_path: saved_models/simswap/parsing_model/checkpoint/79999_iter.pth + checkpoints_dir: saved_models/simswap/checkpoints + arcface_model_path: saved_models/simswap/arcface_model/arcface_checkpoint.tar detection_threshold: 0.6 det_size: [640, 640] use_gpu: true @@ -19,4 +19,4 @@ analysis: opt_which_epoch: '550000' opt_continue_train: store_true gpen: gpen_256 - gpen_path: ./saved_models/gpen + gpen_path: saved_models/gpen diff --git a/src/dot/simswap/fs_model.py b/src/dot/simswap/fs_model.py index b019660..3a43654 100644 --- a/src/dot/simswap/fs_model.py +++ b/src/dot/simswap/fs_model.py @@ -55,7 +55,9 @@ def initialize( torch.backends.cudnn.benchmark = True if use_gpu: - device = torch.device("cuda:0") + device = torch.device( + "mps" if torch.backends.mps.is_available() else "cuda" + ) else: device = torch.device("cpu") diff --git a/src/dot/simswap/option.py b/src/dot/simswap/option.py index 47ce756..8a11969 100644 --- a/src/dot/simswap/option.py +++ b/src/dot/simswap/option.py @@ -82,8 +82,11 @@ def create_model( # type: ignore n_classes = 19 self.net = BiSeNet(n_classes=n_classes) if self.use_gpu: - self.net.cuda() - self.net.load_state_dict(torch.load(parsing_model_path)) + device = "mps" if torch.backends.mps.is_available() else "cuda" + self.net.to(device) + self.net.load_state_dict( + torch.load(parsing_model_path, map_location=device) + ) else: self.net.cpu() self.net.load_state_dict( @@ -127,7 +130,14 @@ def change_option(self, image: np.array, **kwargs) -> None: img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2]) # convert numpy to tensor - img_id = img_id.cuda() if self.use_gpu else img_id.cpu() + if self.use_gpu: + img_id = ( + img_id.to("mps") + if torch.backends.mps.is_available() + else img_id.to("cuda") + ) + else: + img_id = img_id.cpu() # create latent id img_id_downsample = F.interpolate(img_id, size=(112, 112)) @@ -138,7 +148,9 @@ def change_option(self, image: np.array, **kwargs) -> None: ) source_image = ( - source_image.to("cuda") if self.use_gpu else source_image.to("cpu") + source_image.to("mps" if torch.backends.mps.is_available() else "cuda") + if self.use_gpu + else source_image.to("cpu") ) self.source_image = source_image @@ -165,7 +177,9 @@ def process_image(self, image: np.array, **kwargs) -> np.array: if self.use_gpu: frame_align_crop_tenor = _totensor( cv2.cvtColor(frame_align_crop, cv2.COLOR_BGR2RGB) - )[None, ...].cuda() + )[None, ...].to( + "mps" if torch.backends.mps.is_available() else "cuda" + ) else: frame_align_crop_tenor = _totensor( cv2.cvtColor(frame_align_crop, cv2.COLOR_BGR2RGB) diff --git a/src/dot/simswap/util/norm.py b/src/dot/simswap/util/norm.py index 73fd461..4d07003 100644 --- a/src/dot/simswap/util/norm.py +++ b/src/dot/simswap/util/norm.py @@ -14,7 +14,11 @@ def __init__(self, epsilon=1e-8, use_gpu=True): super(SpecificNorm, self).__init__() self.mean = np.array([0.485, 0.456, 0.406]) if use_gpu: - self.mean = torch.from_numpy(self.mean).float().cuda() + self.mean = ( + torch.from_numpy(self.mean) + .float() + .to("mps" if torch.backends.mps.is_available() else "cuda") + ) else: self.mean = torch.from_numpy(self.mean).float().cpu() @@ -22,7 +26,11 @@ def __init__(self, epsilon=1e-8, use_gpu=True): self.std = np.array([0.229, 0.224, 0.225]) if use_gpu: - self.std = torch.from_numpy(self.std).float().cuda() + self.std = ( + torch.from_numpy(self.std) + .float() + .to("mps" if torch.backends.mps.is_available() else "cuda") + ) else: self.std = torch.from_numpy(self.std).float().cpu() diff --git a/src/dot/simswap/util/reverse2original.py b/src/dot/simswap/util/reverse2original.py index fab546a..a76d0d5 100644 --- a/src/dot/simswap/util/reverse2original.py +++ b/src/dot/simswap/util/reverse2original.py @@ -113,7 +113,9 @@ def reverse2wholeimage( use_cam=True, ): - device = torch.device("cuda" if use_gpu else "cpu") + device = torch.device( + ("mps" if torch.backends.mps.is_available() else "cuda") if use_gpu else "cpu" + ) if use_mask: smooth_mask = SoftErosion(kernel_size=17, threshold=0.9, iterations=7).to( device @@ -133,12 +135,19 @@ def reverse2wholeimage( ) # invert the Affine transformation matrix - mat_rev_initial[0:2, :] = torch.tensor(mat).to(device) + if device == torch.device("mps"): + mat_rev_initial[0:2, :] = torch.tensor(mat, dtype=torch.float32).to(device) + else: + mat_rev_initial[0:2, :] = torch.tensor(mat).to(device) - if device == "cpu": + if device == torch.device("cpu"): mat_rev = torch.linalg.inv(mat_rev_initial) mat_rev = mat_rev[:2, :] mat_rev = mat_rev[None, ...] + elif device == torch.device("mps"): + mat_rev = torch.linalg.inv(mat_rev_initial) + mat_rev = mat_rev[:2, :] + mat_rev = torch.as_tensor(mat_rev[None, ...], device=device) else: import cupy as cp diff --git a/src/dot/simswap/util/util.py b/src/dot/simswap/util/util.py index 9717047..67d03b5 100644 --- a/src/dot/simswap/util/util.py +++ b/src/dot/simswap/util/util.py @@ -142,7 +142,7 @@ def load_parsing_model(path, use_mask, use_gpu): n_classes = 19 net = BiSeNet(n_classes=n_classes) if use_gpu: - net.cuda() + net.to("mps" if torch.backends.mps.is_available() else "cuda") net.load_state_dict(torch.load(path)) else: net.cpu() @@ -170,7 +170,13 @@ def crop_align( img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2]) # convert numpy to tensor - img_id = img_id.cuda() if use_gpu else img_id.cpu() + img_id = ( + img_id.to("mps") + if torch.backends.mps.is_available() + else "cuda" + if use_gpu + else img_id.cpu() + ) # create latent id img_id_downsample = F.interpolate(img_id, size=(112, 112)) diff --git a/src/dot/ui/ui.py b/src/dot/ui/ui.py index 044379a..35e2fcc 100644 --- a/src/dot/ui/ui.py +++ b/src/dot/ui/ui.py @@ -5,7 +5,9 @@ """ import os +import sys import tkinter +from pathlib import Path import click import customtkinter @@ -78,7 +80,7 @@ def __init__(self, *args, **kwargs): config_file (str): Path to the configuration file for the deepfake.\n swap_type (str): The type of swap to run.\n gpen_type (str, optional): The type of gpen model to use. Defaults to None.\n - gpen_path (str, optional): The path to the gpen models. Defaults to "./saved_models/gpen".\n + gpen_path (str, optional): The path to the gpen models. Defaults to "saved_models/gpen".\n show_fps (bool, optional): Pass flag to show fps value. Defaults to False.\n use_gpu (bool, optional): Pass flag to use GPU else use CPU. Defaults to False.\n head_pose (bool): Estimates head pose before swap. Used by fomm.\n @@ -594,6 +596,16 @@ def __init__(self): row=3, column=0, columnspan=3, padx=(180, 0), pady=(0, 20), sticky="nsew" ) + self.resources_path = "" + + # MacOS bundle has different resource directory structure + if sys.platform == "darwin": + if getattr(sys, "frozen", False): + self.resources_path = os.path.join( + str(Path(sys.executable).resolve().parents[0]).replace("MacOS", ""), + "Resources", + ) + def CreateToolTip(self, widget, text): toolTip = ToolTip(widget) @@ -719,19 +731,18 @@ def optionmenu_callback(self, choice: str): choice (str): The type of swap to run. """ - entry_list = [ - "source", - "target", + entry_list = ["source", "target", "crop_size"] + radio_list = ["swap_type", "gpen_type"] + model_list = [ "model_path", "parsing_model_path", "arcface_model_path", "checkpoints_dir", "gpen_path", - "crop_size", ] - radio_list = ["swap_type"] - config_file = f"./configs/{choice}.yaml" + config_file = os.path.join(self.resources_path, f"configs/{choice}.yaml") + if os.path.isfile(config_file): config = {} with open(config_file) as f: @@ -741,8 +752,16 @@ def optionmenu_callback(self, choice: str): if key in entry_list: self.modify_entry(eval(f"self.{key}"), config[key]) elif key in radio_list: - self.swap_type_radio_var = tkinter.StringVar(value=config[key]) + if key == "swap_type": + self.swap_type_radio_var = tkinter.StringVar(value=config[key]) + elif key == "gpen_type": + self.gpen_type_radio_var = tkinter.StringVar(value=config[key]) eval(f"self.{config[key]}_radio_button").invoke() + elif key in model_list: + self.modify_entry( + eval(f"self.{key}"), + os.path.join(self.resources_path, config[key]), + ) for entry in entry_list: if entry not in ["source", "target"]: @@ -781,7 +800,7 @@ def start_button_event(self, error_label): ), gpen_type=config.get("gpen_type", self.gpen_type_radio_var.get()), gpen_path=config.get( - "gpen_path", self.gpen_path.get() or "./saved_models/gpen" + "gpen_path", self.gpen_path.get() or "saved_models/gpen" ), crop_size=config.get( "crop_size",