Skip to content

Commit

Permalink
Apple M1/M2 Compatibility (#115)
Browse files Browse the repository at this point in the history
* 👷 metal performance shader

* update readme

* added m2 requirements

* bump up torch and torchvision version

* 👷 fix issue with paths and built bundle

* 🔗 dot-macos.zip link

* construction_worker: update README instructions

* Update README.md

* remove comment

---------

Co-authored-by: Giorgio Patrini <[email protected]>
  • Loading branch information
vassilispapadop and giorgiop authored Oct 3, 2023
1 parent 09830a7 commit 06c764e
Show file tree
Hide file tree
Showing 21 changed files with 286 additions and 64 deletions.
18 changes: 16 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,13 @@ Download and run the dot executable for your OS:
- Download `dot.zip` from [here](https://drive.google.com/file/d/10fdnSszaEbpGdCKxxeBFXQrkagxe3-RT/view), unzip it and then run `dot.exe`
- Ubuntu:
- ToDo
- Mac:
- ToDo
- Mac (Tested on Apple M2 Sonoma 14.0):

- Download `dot-executable.app` from [here](https://drive.google.com/drive/folders/1n22mvWSFmXbSAspZWp5sChbCqXgOEKGx?usp=drive_link)
- Run `dot-executable.app`
- In case of camera reading error:
- Right click and choose `Show Package Contents`
- Execute `dot-executable` from `Contents/MacOS` folder

#### GUI Usage

Expand Down Expand Up @@ -118,6 +123,15 @@ Install the `torch` and `torchvision` dependencies based on the CUDA version ins
To check that `torch` and `torchvision` are installed correctly, run the following command: `python -c "import torch; print(torch.cuda.is_available())"`. If the output is `True`, the dependencies are installed with CUDA support.
###### With MPS Support(Apple Silicon)
```bash
conda env create -f envs/environment-apple-m2.yaml
conda activate dot
```
To check that `torch` and `torchvision` are installed correctly, run the following command: `python -c "import torch; print(torch.backends.mps.is_available())"`. If the output is `True`, the dependencies are installed with Metal programming framework support.
###### With CPU Support (slow, not recommended)
```bash
Expand Down
2 changes: 1 addition & 1 deletion configs/faceswap_cv2.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
---
swap_type: faceswap_cv2
model_path: ./saved_models/faceswap_cv/shape_predictor_68_face_landmarks.dat
model_path: saved_models/faceswap_cv/shape_predictor_68_face_landmarks.dat
2 changes: 1 addition & 1 deletion configs/fomm.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
---
swap_type: fomm
model_path: ./saved_models/fomm/vox-adv-cpk.pth.tar
model_path: saved_models/fomm/vox-adv-cpk.pth.tar
head_pose: true
6 changes: 3 additions & 3 deletions configs/simswap.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
swap_type: simswap
parsing_model_path: ./saved_models/simswap/parsing_model/checkpoint/79999_iter.pth
arcface_model_path: ./saved_models/simswap/arcface_model/arcface_checkpoint.tar
checkpoints_dir: ./saved_models/simswap/checkpoints
parsing_model_path: saved_models/simswap/parsing_model/checkpoint/79999_iter.pth
arcface_model_path: saved_models/simswap/arcface_model/arcface_checkpoint.tar
checkpoints_dir: saved_models/simswap/checkpoints
6 changes: 3 additions & 3 deletions configs/simswaphq.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
---
swap_type: simswap
parsing_model_path: ./saved_models/simswap/parsing_model/checkpoint/79999_iter.pth
arcface_model_path: ./saved_models/simswap/arcface_model/arcface_checkpoint.tar
checkpoints_dir: ./saved_models/simswap/checkpoints
parsing_model_path: saved_models/simswap/parsing_model/checkpoint/79999_iter.pth
arcface_model_path: saved_models/simswap/arcface_model/arcface_checkpoint.tar
checkpoints_dir: saved_models/simswap/checkpoints
crop_size: 512
10 changes: 10 additions & 0 deletions envs/environment-apple-m2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
name: dot
channels:
- conda-forge
- defaults
dependencies:
- python=3.8
- pip=21.3
- pip:
- -r ../requirements-apple-m2.txt
134 changes: 134 additions & 0 deletions requirements-apple-m2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
#
# This file is autogenerated by pip-compile with python 3.8
# To update, run:
#
# pip-compile setup.cfg
#
absl-py==1.1.0
# via mediapipe
attrs==21.4.0
# via mediapipe
certifi==2022.12.7
# via requests
chardet==4.0.0
# via requests
click==8.0.2
# via dot (setup.cfg)
cycler==0.11.0
# via matplotlib
dlib==19.19.0
# via dot (setup.cfg)
face-alignment==1.3.3
# via dot (setup.cfg)
flatbuffers==2.0
# via onnxruntime
fonttools==4.34.4
# via matplotlib
idna==2.10
# via requests
imageio==2.19.3
# via scikit-image
kiwisolver==1.4.3
# via matplotlib
kornia==0.6.5
# via dot (setup.cfg)
llvmlite==0.38.1
# via numba
matplotlib==3.5.2
# via mediapipe
mediapipe-silicon
# via dot (setup.cfg)
mediapipe==0.10.3
networkx==2.8.4
# via scikit-image
numba==0.55.2
# via face-alignment
numpy==1.22.0
# via
# dot (setup.cfg)
# face-alignment
# imageio
# matplotlib
# mediapipe
# numba
# onnxruntime
# opencv-contrib-python
# opencv-python
# pywavelets
# scikit-image
# scipy
# tifffile
# torchvision
onnxruntime==1.15.1
# via dot (setup.cfg)
opencv-contrib-python==4.5.5.62
# via
# dot (setup.cfg)
# mediapipe
opencv-python==4.5.5.62
# via
# dot (setup.cfg)
# face-alignment
packaging==21.3
# via
# kornia
# matplotlib
# scikit-image
pillow==9.3.0
# via
# dot (setup.cfg)
# imageio
# matplotlib
# scikit-image
# torchvision
protobuf==3.20.2
# via
# dot (setup.cfg)
# mediapipe
# onnxruntime
pyparsing==3.0.9
# via
# matplotlib
# packaging
python-dateutil==2.8.2
# via matplotlib
pywavelets==1.3.0
# via scikit-image
pyyaml==5.4.1
# via dot (setup.cfg)
requests==2.25.1
# via dot (setup.cfg)
scikit-image==0.19.1
# via
# dot (setup.cfg)
# face-alignment
scipy==1.10.1
# via
# dot (setup.cfg)
# face-alignment
# scikit-image
six==1.16.0
# via
# mediapipe
# python-dateutil
tifffile==2022.5.4
# via scikit-image
torch==2.0.1
# via
# dot (setup.cfg)
# face-alignment
# kornia
# torchvision
torchvision==0.15.2
# via dot (setup.cfg)
tqdm==4.64.0
# via face-alignment
typing-extensions==4.3.0
# via torch
urllib3==1.26.10
# via requests
wheel==0.38.1
# via mediapipe

# The following packages are considered to be unsafe in a requirements file:
# setuptools
8 changes: 4 additions & 4 deletions src/dot/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def run(
arcface_model_path: str = None,
checkpoints_dir: str = None,
gpen_type: str = None,
gpen_path: str = "./saved_models/gpen",
gpen_path: str = "saved_models/gpen",
crop_size: int = 224,
head_pose: bool = False,
save_folder: str = None,
Expand All @@ -42,7 +42,7 @@ def run(
arcface_model_path (str, optional): The path to the arcface model. Defaults to None.
checkpoints_dir (str, optional): The path to the checkpoints directory. Defaults to None.
gpen_type (str, optional): The type of gpen model to use. Defaults to None.
gpen_path (str, optional): The path to the gpen models. Defaults to "./saved_models/gpen".
gpen_path (str, optional): The path to the gpen models. Defaults to "saved_models/gpen".
crop_size (int, optional): The size to crop the images to. Defaults to 224.
save_folder (str, optional): The path to the save folder. Defaults to None.
show_fps (bool, optional): Pass flag to show fps value. Defaults to False.
Expand Down Expand Up @@ -130,7 +130,7 @@ def run(
@click.option(
"--gpen_path",
"gpen_path",
default="./saved_models/gpen",
default="saved_models/gpen",
help="Path to gpen models.",
)
@click.option("--crop_size", "crop_size", type=int, default=224)
Expand Down Expand Up @@ -185,7 +185,7 @@ def main(
arcface_model_path: str = None,
checkpoints_dir: str = None,
gpen_type: str = None,
gpen_path: str = "./saved_models/gpen",
gpen_path: str = "saved_models/gpen",
crop_size: int = 224,
save_folder: str = None,
show_fps: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion src/dot/commons/model_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ModelOption(ABC):
def __init__(
self,
gpen_type=None,
gpen_path="./saved_models/gpen",
gpen_path="saved_models/gpen",
use_gpu=True,
crop_size=256,
):
Expand Down
4 changes: 1 addition & 3 deletions src/dot/faceswap_cv2/swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
)

# define globals
CACHED_PREDICTOR_PATH = (
"./saved_models/faceswap_cv/shape_predictor_68_face_landmarks.dat"
)
CACHED_PREDICTOR_PATH = "saved_models/faceswap_cv/shape_predictor_68_face_landmarks.dat"


class Swap:
Expand Down
11 changes: 8 additions & 3 deletions src/dot/gpen/face_model/face_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def __init__(
self.n_mlp = 8
self.is_norm = is_norm
self.resolution = size
self.device = (
("mps" if torch.backends.mps.is_available() else "cuda")
if use_gpu
else "cpu"
)
self.load_model(
channel_multiplier=channel_multiplier, narrow=narrow, use_gpu=use_gpu
)
Expand All @@ -36,8 +41,8 @@ def load_model(self, channel_multiplier=2, narrow=1, use_gpu=True):
if use_gpu:
self.model = FullGenerator(
self.resolution, 512, self.n_mlp, channel_multiplier, narrow=narrow
).cuda()
pretrained_dict = torch.load(self.mfile)
).to(self.device)
pretrained_dict = torch.load(self.mfile, map_location=self.device)
else:
self.model = FullGenerator(
self.resolution, 512, self.n_mlp, channel_multiplier, narrow=narrow
Expand All @@ -60,7 +65,7 @@ def process(self, img, use_gpu=True):

def img2tensor(self, img, use_gpu=True):
if use_gpu:
img_t = torch.from_numpy(img).cuda() / 255.0
img_t = torch.from_numpy(img).to(self.device) / 255.0
else:
img_t = torch.from_numpy(img).cpu() / 255.0
if self.is_norm:
Expand Down
10 changes: 6 additions & 4 deletions src/dot/gpen/retinaface/layers/modules/multibox_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,14 @@ def forward(self, predictions, priors, targets):
landm_t,
idx,
)
device = "cpu"
if GPU:
loc_t = loc_t.cuda()
conf_t = conf_t.cuda()
landm_t = landm_t.cuda()
device = "mps" if torch.backends.mps.is_available() else "cuda"
loc_t = loc_t.to(device)
conf_t = conf_t.to(device)
landm_t = landm_t.to(device)

zeros = torch.tensor(0).cuda()
zeros = torch.tensor(0).to(device)
# landm Loss (Smooth L1)
# Shape: [batch,num_priors,10]
pos1 = conf_t > zeros
Expand Down
19 changes: 10 additions & 9 deletions src/dot/gpen/retinaface/retinaface_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ def __init__(self, base_dir, network="RetinaFace-R50", use_gpu=True):
cudnn.benchmark = True
self.pretrained_path = os.path.join(base_dir, "weights", network + ".pth")
if use_gpu:
self.device = torch.cuda.current_device()
self.device = "mps" if torch.backends.mps.is_available() else "cuda"
else:
self.device = "cpu"
self.cfg = cfg_re50
self.net = RetinaFace(cfg=self.cfg, phase="test")
if use_gpu:
self.load_model()
self.net = self.net.cuda()
self.net = self.net.to(self.device)
else:
self.load_model(load_to_cpu=True)
self.net = self.net.cpu()
Expand All @@ -57,9 +57,10 @@ def load_model(self, load_to_cpu=False):
self.pretrained_path, map_location=lambda storage, loc: storage
)
else:
pretrained_dict = torch.load(
self.pretrained_path, map_location=lambda storage, loc: storage.cuda()
)
# pretrained_dict = torch.load(
# self.pretrained_path, map_location=lambda storage, loc: storage.to("mps")#.cuda()
# )
pretrained_dict = torch.load(self.pretrained_path, map_location=self.device)
if "state_dict" in pretrained_dict.keys():
pretrained_dict = self.remove_prefix(
pretrained_dict["state_dict"], "module."
Expand Down Expand Up @@ -89,8 +90,8 @@ def detect(
img = img.transpose(2, 0, 1)
img = torch.from_numpy(img).unsqueeze(0)
if use_gpu:
img = img.cuda()
scale = scale.cuda()
img = img.to(self.device)
scale = scale.to(self.device)
else:
img = img.cpu()
scale = scale.cpu()
Expand All @@ -100,7 +101,7 @@ def detect(
priorbox = PriorBox(self.cfg, image_size=(im_height, im_width))
priors = priorbox.forward()
if use_gpu:
priors = priors.cuda()
priors = priors.to(self.device)
else:
priors = priors.cpu()

Expand All @@ -125,7 +126,7 @@ def detect(
]
)
if use_gpu:
scale1 = scale1.cuda()
scale1 = scale1.to(self.device)
else:
scale1 = scale1.cpu()

Expand Down
8 changes: 4 additions & 4 deletions src/dot/simswap/configs/config.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
---
analysis:
simswap:
parsing_model_path: ./saved_models/simswap/parsing_model/checkpoint/79999_iter.pth
checkpoints_dir: ./saved_models/simswap/checkpoints
arcface_model_path: ./saved_models/simswap/arcface_model/arcface_checkpoint.tar
parsing_model_path: saved_models/simswap/parsing_model/checkpoint/79999_iter.pth
checkpoints_dir: saved_models/simswap/checkpoints
arcface_model_path: saved_models/simswap/arcface_model/arcface_checkpoint.tar
detection_threshold: 0.6
det_size: [640, 640]
use_gpu: true
Expand All @@ -19,4 +19,4 @@ analysis:
opt_which_epoch: latest
opt_continue_train: store_true
gpen: gpen_256
gpen_path: ./saved_models/gpen
gpen_path: saved_models/gpen
Loading

0 comments on commit 06c764e

Please sign in to comment.