Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apple M1/M2 Compatibility #115

Merged
merged 10 commits into from
Oct 3, 2023
18 changes: 16 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,13 @@ Download and run the dot executable for your OS:
- Download `dot.zip` from [here](https://drive.google.com/file/d/10fdnSszaEbpGdCKxxeBFXQrkagxe3-RT/view), unzip it and then run `dot.exe`
- Ubuntu:
- ToDo
- Mac:
- ToDo
- Mac (Tested on Apple M2 Sonoma 14.0):

- Download `dot-executable.app` from [here](https://drive.google.com/drive/folders/1n22mvWSFmXbSAspZWp5sChbCqXgOEKGx?usp=drive_link)
- Run `dot-executable.app`
- In case of camera reading error:
- Right click and choose `Show Package Contents`
- Execute `dot-executable` from `Contents/MacOS` folder

#### GUI Usage

Expand Down Expand Up @@ -118,6 +123,15 @@ Install the `torch` and `torchvision` dependencies based on the CUDA version ins

To check that `torch` and `torchvision` are installed correctly, run the following command: `python -c "import torch; print(torch.cuda.is_available())"`. If the output is `True`, the dependencies are installed with CUDA support.

###### With MPS Support(Apple Silicon)

```bash
conda env create -f envs/environment-apple-m2.yaml
conda activate dot
```

To check that `torch` and `torchvision` are installed correctly, run the following command: `python -c "import torch; print(torch.backends.mps.is_available())"`. If the output is `True`, the dependencies are installed with Metal programming framework support.

###### With CPU Support (slow, not recommended)

```bash
Expand Down
2 changes: 1 addition & 1 deletion configs/faceswap_cv2.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
---
swap_type: faceswap_cv2
model_path: ./saved_models/faceswap_cv/shape_predictor_68_face_landmarks.dat
model_path: saved_models/faceswap_cv/shape_predictor_68_face_landmarks.dat
2 changes: 1 addition & 1 deletion configs/fomm.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
---
swap_type: fomm
model_path: ./saved_models/fomm/vox-adv-cpk.pth.tar
model_path: saved_models/fomm/vox-adv-cpk.pth.tar
head_pose: true
6 changes: 3 additions & 3 deletions configs/simswap.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
swap_type: simswap
parsing_model_path: ./saved_models/simswap/parsing_model/checkpoint/79999_iter.pth
arcface_model_path: ./saved_models/simswap/arcface_model/arcface_checkpoint.tar
checkpoints_dir: ./saved_models/simswap/checkpoints
parsing_model_path: saved_models/simswap/parsing_model/checkpoint/79999_iter.pth
arcface_model_path: saved_models/simswap/arcface_model/arcface_checkpoint.tar
checkpoints_dir: saved_models/simswap/checkpoints
6 changes: 3 additions & 3 deletions configs/simswaphq.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
swap_type: simswap
parsing_model_path: ./saved_models/simswap/parsing_model/checkpoint/79999_iter.pth
arcface_model_path: ./saved_models/simswap/arcface_model/arcface_checkpoint.tar
checkpoints_dir: ./saved_models/simswap/checkpoints
parsing_model_path: saved_models/simswap/parsing_model/checkpoint/79999_iter.pth
arcface_model_path: saved_models/simswap/arcface_model/arcface_checkpoint.tar
checkpoints_dir: saved_models/simswap/checkpoints
crop_size: 512
10 changes: 10 additions & 0 deletions envs/environment-apple-m2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
name: dot
channels:
- conda-forge
- defaults
dependencies:
- python=3.8
- pip=21.3
- pip:
- -r ../requirements-apple-m2.txt
134 changes: 134 additions & 0 deletions requirements-apple-m2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
#
# This file is autogenerated by pip-compile with python 3.8
# To update, run:
#
# pip-compile setup.cfg
#
absl-py==1.1.0
# via mediapipe
attrs==21.4.0
# via mediapipe
certifi==2022.12.7
# via requests
chardet==4.0.0
# via requests
click==8.0.2
# via dot (setup.cfg)
cycler==0.11.0
# via matplotlib
dlib==19.19.0
# via dot (setup.cfg)
face-alignment==1.3.3
# via dot (setup.cfg)
flatbuffers==2.0
# via onnxruntime
fonttools==4.34.4
# via matplotlib
idna==2.10
# via requests
imageio==2.19.3
# via scikit-image
kiwisolver==1.4.3
# via matplotlib
kornia==0.6.5
# via dot (setup.cfg)
llvmlite==0.38.1
# via numba
matplotlib==3.5.2
# via mediapipe
mediapipe-silicon
# via dot (setup.cfg)
mediapipe==0.10.3
networkx==2.8.4
# via scikit-image
numba==0.55.2
# via face-alignment
numpy==1.22.0
# via
# dot (setup.cfg)
# face-alignment
# imageio
# matplotlib
# mediapipe
# numba
# onnxruntime
# opencv-contrib-python
# opencv-python
# pywavelets
# scikit-image
# scipy
# tifffile
# torchvision
onnxruntime==1.15.1
# via dot (setup.cfg)
opencv-contrib-python==4.5.5.62
# via
# dot (setup.cfg)
# mediapipe
opencv-python==4.5.5.62
# via
# dot (setup.cfg)
# face-alignment
packaging==21.3
# via
# kornia
# matplotlib
# scikit-image
pillow==9.3.0
# via
# dot (setup.cfg)
# imageio
# matplotlib
# scikit-image
# torchvision
protobuf==3.20.2
# via
# dot (setup.cfg)
# mediapipe
# onnxruntime
pyparsing==3.0.9
# via
# matplotlib
# packaging
python-dateutil==2.8.2
# via matplotlib
pywavelets==1.3.0
# via scikit-image
pyyaml==5.4.1
# via dot (setup.cfg)
requests==2.25.1
# via dot (setup.cfg)
scikit-image==0.19.1
# via
# dot (setup.cfg)
# face-alignment
scipy==1.10.1
# via
# dot (setup.cfg)
# face-alignment
# scikit-image
six==1.16.0
# via
# mediapipe
# python-dateutil
tifffile==2022.5.4
# via scikit-image
torch==2.0.1
# via
# dot (setup.cfg)
# face-alignment
# kornia
# torchvision
torchvision==0.15.2
# via dot (setup.cfg)
tqdm==4.64.0
# via face-alignment
typing-extensions==4.3.0
# via torch
urllib3==1.26.10
# via requests
wheel==0.38.1
# via mediapipe

# The following packages are considered to be unsafe in a requirements file:
# setuptools
8 changes: 4 additions & 4 deletions src/dot/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def run(
arcface_model_path: str = None,
checkpoints_dir: str = None,
gpen_type: str = None,
gpen_path: str = "./saved_models/gpen",
gpen_path: str = "saved_models/gpen",
crop_size: int = 224,
head_pose: bool = False,
save_folder: str = None,
Expand All @@ -42,7 +42,7 @@ def run(
arcface_model_path (str, optional): The path to the arcface model. Defaults to None.
checkpoints_dir (str, optional): The path to the checkpoints directory. Defaults to None.
gpen_type (str, optional): The type of gpen model to use. Defaults to None.
gpen_path (str, optional): The path to the gpen models. Defaults to "./saved_models/gpen".
gpen_path (str, optional): The path to the gpen models. Defaults to "saved_models/gpen".
crop_size (int, optional): The size to crop the images to. Defaults to 224.
save_folder (str, optional): The path to the save folder. Defaults to None.
show_fps (bool, optional): Pass flag to show fps value. Defaults to False.
Expand Down Expand Up @@ -130,7 +130,7 @@ def run(
@click.option(
"--gpen_path",
"gpen_path",
default="./saved_models/gpen",
default="saved_models/gpen",
help="Path to gpen models.",
)
@click.option("--crop_size", "crop_size", type=int, default=224)
Expand Down Expand Up @@ -185,7 +185,7 @@ def main(
arcface_model_path: str = None,
checkpoints_dir: str = None,
gpen_type: str = None,
gpen_path: str = "./saved_models/gpen",
gpen_path: str = "saved_models/gpen",
crop_size: int = 224,
save_folder: str = None,
show_fps: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion src/dot/commons/model_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ModelOption(ABC):
def __init__(
self,
gpen_type=None,
gpen_path="./saved_models/gpen",
gpen_path="saved_models/gpen",
use_gpu=True,
crop_size=256,
):
Expand Down
4 changes: 1 addition & 3 deletions src/dot/faceswap_cv2/swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
)

# define globals
CACHED_PREDICTOR_PATH = (
"./saved_models/faceswap_cv/shape_predictor_68_face_landmarks.dat"
)
CACHED_PREDICTOR_PATH = "saved_models/faceswap_cv/shape_predictor_68_face_landmarks.dat"


class Swap:
Expand Down
11 changes: 8 additions & 3 deletions src/dot/gpen/face_model/face_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def __init__(
self.n_mlp = 8
self.is_norm = is_norm
self.resolution = size
self.device = (
("mps" if torch.backends.mps.is_available() else "cuda")
if use_gpu
else "cpu"
)
self.load_model(
channel_multiplier=channel_multiplier, narrow=narrow, use_gpu=use_gpu
)
Expand All @@ -36,8 +41,8 @@ def load_model(self, channel_multiplier=2, narrow=1, use_gpu=True):
if use_gpu:
self.model = FullGenerator(
self.resolution, 512, self.n_mlp, channel_multiplier, narrow=narrow
).cuda()
pretrained_dict = torch.load(self.mfile)
).to(self.device)
pretrained_dict = torch.load(self.mfile, map_location=self.device)
else:
self.model = FullGenerator(
self.resolution, 512, self.n_mlp, channel_multiplier, narrow=narrow
Expand All @@ -60,7 +65,7 @@ def process(self, img, use_gpu=True):

def img2tensor(self, img, use_gpu=True):
if use_gpu:
img_t = torch.from_numpy(img).cuda() / 255.0
img_t = torch.from_numpy(img).to(self.device) / 255.0
else:
img_t = torch.from_numpy(img).cpu() / 255.0
if self.is_norm:
Expand Down
10 changes: 6 additions & 4 deletions src/dot/gpen/retinaface/layers/modules/multibox_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,14 @@ def forward(self, predictions, priors, targets):
landm_t,
idx,
)
device = "cpu"
if GPU:
loc_t = loc_t.cuda()
conf_t = conf_t.cuda()
landm_t = landm_t.cuda()
device = "mps" if torch.backends.mps.is_available() else "cuda"
loc_t = loc_t.to(device)
conf_t = conf_t.to(device)
landm_t = landm_t.to(device)

zeros = torch.tensor(0).cuda()
zeros = torch.tensor(0).to(device)
# landm Loss (Smooth L1)
# Shape: [batch,num_priors,10]
pos1 = conf_t > zeros
Expand Down
19 changes: 10 additions & 9 deletions src/dot/gpen/retinaface/retinaface_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ def __init__(self, base_dir, network="RetinaFace-R50", use_gpu=True):
cudnn.benchmark = True
self.pretrained_path = os.path.join(base_dir, "weights", network + ".pth")
if use_gpu:
self.device = torch.cuda.current_device()
self.device = "mps" if torch.backends.mps.is_available() else "cuda"
else:
self.device = "cpu"
self.cfg = cfg_re50
self.net = RetinaFace(cfg=self.cfg, phase="test")
if use_gpu:
self.load_model()
self.net = self.net.cuda()
self.net = self.net.to(self.device)
else:
self.load_model(load_to_cpu=True)
self.net = self.net.cpu()
Expand All @@ -57,9 +57,10 @@ def load_model(self, load_to_cpu=False):
self.pretrained_path, map_location=lambda storage, loc: storage
)
else:
pretrained_dict = torch.load(
self.pretrained_path, map_location=lambda storage, loc: storage.cuda()
)
# pretrained_dict = torch.load(
# self.pretrained_path, map_location=lambda storage, loc: storage.to("mps")#.cuda()
# )
pretrained_dict = torch.load(self.pretrained_path, map_location=self.device)
if "state_dict" in pretrained_dict.keys():
pretrained_dict = self.remove_prefix(
pretrained_dict["state_dict"], "module."
Expand Down Expand Up @@ -89,8 +90,8 @@ def detect(
img = img.transpose(2, 0, 1)
img = torch.from_numpy(img).unsqueeze(0)
if use_gpu:
img = img.cuda()
scale = scale.cuda()
img = img.to(self.device)
scale = scale.to(self.device)
else:
img = img.cpu()
scale = scale.cpu()
Expand All @@ -100,7 +101,7 @@ def detect(
priorbox = PriorBox(self.cfg, image_size=(im_height, im_width))
priors = priorbox.forward()
if use_gpu:
priors = priors.cuda()
priors = priors.to(self.device)
else:
priors = priors.cpu()

Expand All @@ -125,7 +126,7 @@ def detect(
]
)
if use_gpu:
scale1 = scale1.cuda()
scale1 = scale1.to(self.device)
else:
scale1 = scale1.cpu()

Expand Down
8 changes: 4 additions & 4 deletions src/dot/simswap/configs/config.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
---
analysis:
simswap:
parsing_model_path: ./saved_models/simswap/parsing_model/checkpoint/79999_iter.pth
checkpoints_dir: ./saved_models/simswap/checkpoints
arcface_model_path: ./saved_models/simswap/arcface_model/arcface_checkpoint.tar
parsing_model_path: saved_models/simswap/parsing_model/checkpoint/79999_iter.pth
checkpoints_dir: saved_models/simswap/checkpoints
arcface_model_path: saved_models/simswap/arcface_model/arcface_checkpoint.tar
detection_threshold: 0.6
det_size: [640, 640]
use_gpu: true
Expand All @@ -19,4 +19,4 @@ analysis:
opt_which_epoch: latest
opt_continue_train: store_true
gpen: gpen_256
gpen_path: ./saved_models/gpen
gpen_path: saved_models/gpen
Loading
Loading