diff --git a/gfpgan/utils.py b/gfpgan/utils.py index 74ee5a83..f8971d1b 100644 --- a/gfpgan/utils.py +++ b/gfpgan/utils.py @@ -76,6 +76,7 @@ def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg from gfpgan.archs.restoreformer_arch import RestoreFormer self.gfpgan = RestoreFormer() # initialize face helper + model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights') self.face_helper = FaceRestoreHelper( upscale, face_size=512, @@ -84,11 +85,11 @@ def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg save_ext='png', use_parse=True, device=self.device, - model_rootpath='gfpgan/weights') + model_rootpath=model_dir) if model_path.startswith('https://'): model_path = load_file_from_url( - url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None) + url=model_path, model_dir=model_dir, progress=True, file_name=None) loadnet = torch.load(model_path) if 'params_ema' in loadnet: keyname = 'params_ema'