Skip to content

feat: support PyTorch mps backend on macOS #146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

*.pkl
output/*
train_log/*
# train_log/*
*.mp4

test/
.idea/
*.npz

*.zip

.DS_Store
2 changes: 1 addition & 1 deletion inference_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import warnings
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
torch.set_grad_enabled(False)
if torch.cuda.is_available():
torch.backends.cudnn.enabled = True
Expand Down
2 changes: 1 addition & 1 deletion inference_img_SR.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import warnings
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
torch.set_grad_enabled(False)
if torch.cuda.is_available():
torch.backends.cudnn.enabled = True
Expand Down
2 changes: 1 addition & 1 deletion inference_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def transferAudio(sourceVideo, targetVideo):
if not args.img is None:
args.png = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
torch.set_grad_enabled(False)
if torch.cuda.is_available():
torch.backends.cudnn.enabled = True
Expand Down
2 changes: 1 addition & 1 deletion inference_video_enhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def transferAudio(sourceVideo, targetVideo):
if not args.img is None:
args.png = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
torch.set_grad_enabled(False)
if torch.cuda.is_available():
torch.backends.cudnn.enabled = True
Expand Down
2 changes: 1 addition & 1 deletion model/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch.nn.functional as F
import torchvision.models as models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")


class EPE(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion model/pytorch_msssim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from math import exp
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
Expand Down
12 changes: 10 additions & 2 deletions model/warplayer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
backwarp_tenGrid = {}


Expand All @@ -19,4 +19,12 @@ def warp(tenInput, tenFlow):
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)

g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)

pd = 'border'

# MPS does not support the 'border' padding mode: use 'zeros' and clamp the grid to [-1, 1] instead
if tenInput.device.type == "mps":
pd = 'zeros'
g = g.clamp(-1, 1)

return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode=pd, align_corners=True)
169 changes: 169 additions & 0 deletions train_log/IFNet_HDv3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.warplayer import warp
# from train_log.refine import *

# Select the compute device once at import time: CUDA first, then Apple's
# Metal (MPS) backend, then CPU.
# PyTorch builds that predate the MPS backend have no `torch.backends.mps`
# attribute, so probe for it instead of dereferencing it unconditionally
# (which would raise AttributeError on import).
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Build a biased 2-D convolution followed by an in-place LeakyReLU(0.2).

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Padding applied by the convolution.
        dilation: Convolution dilation.

    Returns:
        An ``nn.Sequential`` holding the convolution (index 0) and the
        activation (index 1).
    """
    convolution = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    return nn.Sequential(convolution, activation)

def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Build a bias-free 2-D convolution, BatchNorm2d and in-place LeakyReLU(0.2).

    The convolution carries no bias term; it is immediately followed by a
    batch-normalisation layer over ``out_planes`` channels.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Padding applied by the convolution.
        dilation: Convolution dilation.

    Returns:
        An ``nn.Sequential`` of convolution, batch norm and activation.
    """
    layers = [
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        ),
        nn.BatchNorm2d(out_planes),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
    ]
    return nn.Sequential(*layers)

class Head(nn.Module):
    """Small feature encoder: three 3x3 convolutions and one transposed convolution.

    ``cnn0`` uses stride 2 and ``cnn3`` is a stride-2 transposed convolution;
    every layer outputs 16 channels and the input must have 3 channels.
    """

    def __init__(self):
        super().__init__()
        self.cnn0 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
        self.cnn1 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.cnn2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.cnn3 = nn.ConvTranspose2d(16, 16, kernel_size=4, stride=2, padding=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, feat=False):
        """Encode ``x``.

        Args:
            x: Input tensor with 3 channels.
            feat: When truthy, return the outputs of all four layers.

        Returns:
            The output of ``cnn3``, or the list ``[x0, x1, x2, x3]`` when
            ``feat`` is truthy. Because the activation runs in place, the
            first three entries hold the post-activation values.
        """
        stages = []
        for layer in (self.cnn0, self.cnn1, self.cnn2):
            pre_activation = layer(x)
            stages.append(pre_activation)
            # In-place: ``pre_activation`` and ``x`` share storage afterwards.
            x = self.relu(pre_activation)
        x3 = self.cnn3(x)
        if not feat:
            return x3
        return stages + [x3]

class ResConv(nn.Module):
    """Residual 3x3 convolution with a learnable per-channel scale ``beta``.

    Computes ``LeakyReLU(conv(x) * beta + x)``. Padding equals the dilation,
    so the spatial size of the input is preserved.
    """

    def __init__(self, c, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(
            c, c, kernel_size=3, stride=1, padding=dilation,
            dilation=dilation, groups=1,
        )
        # One scale per channel, initialised to 1 and broadcast over N, H, W.
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Return the activated sum of the scaled convolution and the input."""
        scaled = self.conv(x) * self.beta
        summed = scaled + x
        return self.relu(summed)

class IFBlock(nn.Module):
    """One refinement stage producing a flow, a mask and a feature tensor.

    The block downsamples with two stride-2 ``conv`` layers, applies eight
    ``ResConv`` layers, and upsamples with a transposed convolution followed
    by ``PixelShuffle(2)``, which leaves 13 output channels.
    """

    def __init__(self, in_planes, c=64):
        super().__init__()
        half_width = c // 2
        self.conv0 = nn.Sequential(
            conv(in_planes, half_width, 3, 2, 1),
            conv(half_width, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(*[ResConv(c) for _ in range(8)])
        self.lastconv = nn.Sequential(
            # 4 * 13 channels before the shuffle -> 13 channels after it.
            nn.ConvTranspose2d(c, 4 * 13, 4, 2, 1),
            nn.PixelShuffle(2),
        )

    def forward(self, x, flow=None, scale=1):
        """Run the block at ``1 / scale`` resolution and resize the result back.

        Args:
            x: Input tensor; resized by ``1 / scale`` before processing.
            flow: Optional flow tensor. When given it is resized and divided
                by ``scale``, then concatenated to ``x`` along dim 1.
            scale: Working-resolution divisor.

        Returns:
            ``(flow, mask, feat)``: channels ``[:4]`` multiplied by ``scale``,
            channel ``[4:5]``, and channels ``[5:]`` of the resized output.
        """
        shrink = 1. / scale
        x = F.interpolate(x, scale_factor=shrink, mode="bilinear", align_corners=False)
        if flow is not None:
            resized_flow = F.interpolate(
                flow, scale_factor=shrink, mode="bilinear", align_corners=False
            )
            # Flow magnitudes shrink together with the resolution.
            flow = resized_flow / scale
            x = torch.cat((x, flow), 1)
        hidden = self.convblock(self.conv0(x))
        out = self.lastconv(hidden)
        out = F.interpolate(out, scale_factor=scale, mode="bilinear", align_corners=False)
        flow, mask, feat = out[:, :4] * scale, out[:, 4:5], out[:, 5:]
        return flow, mask, feat

class IFNet(nn.Module):
    """Coarse-to-fine intermediate-flow network built from five ``IFBlock`` stages.

    Each stage refines a 4-channel flow (two channels per input frame) and a
    1-channel blending mask. ``Head`` encodes each frame into features that
    are fed to every stage. The teacher / ``caltime`` branches of earlier
    versions are not part of this inference-oriented module.
    """

    def __init__(self):
        super().__init__()
        # 7 = img0 (3) + img1 (3) + timestep (1); 32 = two 16-channel encodings.
        self.block0 = IFBlock(7 + 32, c=192)
        # Later stages additionally receive mask (1), feat (8) and flow (4).
        self.block1 = IFBlock(8 + 4 + 8 + 32, c=128)
        self.block2 = IFBlock(8 + 4 + 8 + 32, c=96)
        self.block3 = IFBlock(8 + 4 + 8 + 32, c=64)
        self.block4 = IFBlock(8 + 4 + 8 + 32, c=32)
        self.encode = Head()

    def forward(self, x, timestep=0.5, scale_list=None, training=False, fastmode=True, ensemble=False):
        """Estimate flows and the blended intermediate frame.

        Args:
            x: Frames concatenated along dim 1. In inference mode the channel
                axis is split in half into ``img0`` / ``img1``. In training
                mode the first six channels are used as ``img0`` / ``img1``
                (NOTE(review): this follows the ``cat((imgs, gt), 1)`` layout
                used by ``Model.update`` -- confirm against the trainer).
            timestep: Scalar, or a tensor that is tiled to the frame size
                (presumably shaped ``(N, 1, 1, 1)`` -- TODO confirm).
            scale_list: One working scale per stage (five entries are
                indexed). Defaults to ``(16, 8, 4, 2, 1)``. The previous
                default was a shared mutable list with only four entries,
                which always raised ``IndexError`` at the fifth stage.
            training: Selects how ``x`` is split (see above).
            fastmode: When false, only prints that contextnet was removed.
            ensemble: Unsupported; prints a warning for every stage.

        Returns:
            ``(flow_list, mask, merged)`` where ``flow_list`` holds the flow
            after each stage, ``mask`` is the last stage's pre-sigmoid mask,
            and ``merged`` holds ``(warped_img0, warped_img1)`` tuples for the
            first four stages and the blended frame as its last element.
        """
        if scale_list is None:
            scale_list = (16, 8, 4, 2, 1)
        if not training:
            channel = x.shape[1] // 2
            img0 = x[:, :channel]
            img1 = x[:, channel:]
        else:
            # Previously img0/img1 were never bound on this path (NameError).
            img0 = x[:, :3]
            img1 = x[:, 3:6]
        if not torch.is_tensor(timestep):
            # Constant plane with the batch/spatial shape, dtype and device of x.
            timestep = (x[:, :1].clone() * 0 + 1) * timestep
        else:
            timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])
        f0 = self.encode(img0[:, :3])
        f1 = self.encode(img1[:, :3])
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        blocks = (self.block0, self.block1, self.block2, self.block3, self.block4)
        for i, block in enumerate(blocks):
            if flow is None:
                # First stage: no previous flow, so feed the raw frames.
                stage_input = torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1)
                flow, mask, feat = block(stage_input, None, scale=scale_list[i])
                if ensemble:
                    print("warning: ensemble is not supported since RIFEv4.21")
            else:
                # Later stages see frames and features warped by the current flow.
                wf0 = warp(f0, flow[:, :2])
                wf1 = warp(f1, flow[:, 2:4])
                stage_input = torch.cat(
                    (warped_img0[:, :3], warped_img1[:, :3], wf0, wf1, timestep, mask, feat), 1
                )
                fd, m0, feat = block(stage_input, flow, scale=scale_list[i])
                if ensemble:
                    # NOTE(review): as in the original, the mask is not
                    # refreshed when ensemble is requested -- confirm intent.
                    print("warning: ensemble is not supported since RIFEv4.21")
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        # Replace the final warped pair with the mask-weighted blend.
        merged[-1] = (warped_img0 * mask + warped_img1 * (1 - mask))
        if not fastmode:
            print('contextnet is removed')
        return flow_list, mask_list[-1], merged
89 changes: 89 additions & 0 deletions train_log/RIFE_HDv3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
from model.warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from train_log.IFNet_HDv3 import *
import torch.nn.functional as F
from model.loss import *

# Select the compute device once at import time: CUDA first, then Apple's
# Metal (MPS) backend, then CPU.
# PyTorch builds that predate the MPS backend have no `torch.backends.mps`
# attribute, so probe for it instead of dereferencing it unconditionally
# (which would raise AttributeError on import).
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

class Model:
    """Wrapper bundling the ``IFNet`` flow network with its optimizer and losses.

    Args:
        local_rank: When not ``-1``, the network is wrapped in
            ``DistributedDataParallel`` bound to that device id.
    """

    def __init__(self, local_rank=-1):
        self.flownet = IFNet()
        self.device()
        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
        self.epe = EPE()
        self.version = 4.25
        # self.vgg = VGGPerceptualLoss().to(device)
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)

    def train(self):
        """Put the flow network into training mode."""
        self.flownet.train()

    def eval(self):
        """Put the flow network into evaluation mode."""
        self.flownet.eval()

    def device(self):
        """Move the flow network to the module-level ``device``."""
        self.flownet.to(device)

    def load_model(self, path, rank=0):
        """Load ``{path}/flownet.pkl`` into the flow network (non-strict).

        Args:
            path: Directory containing ``flownet.pkl``.
            rank: Nothing is loaded when ``rank > 0``. With ``rank == -1``
                only keys containing ``"module."`` are kept and that prefix
                is stripped.
        """
        def convert(param):
            if rank == -1:
                return {
                    k.replace("module.", ""): v
                    for k, v in param.items()
                    if "module." in k
                }
            return param

        if rank > 0:
            return
        # Without CUDA, remap storages to CPU; load_state_dict then copies the
        # values into the parameters wherever they live (e.g. on MPS).
        map_location = None if torch.cuda.is_available() else 'cpu'
        # NOTE(security): torch.load unpickles the file -- only load
        # checkpoints from trusted sources.
        state = torch.load('{}/flownet.pkl'.format(path), map_location=map_location)
        self.flownet.load_state_dict(convert(state), False)

    def save_model(self, path, rank=0):
        """Save the flow network's state dict to ``{path}/flownet.pkl`` on rank 0."""
        if rank == 0:
            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))

    def inference(self, img0, img1, timestep=0.5, scale=1.0):
        """Interpolate a frame between ``img0`` and ``img1``.

        Args:
            img0: First frame tensor.
            img1: Second frame tensor, concatenated to ``img0`` along dim 1.
            timestep: Passed through to the flow network.
            scale: Divides the per-stage scales ``16, 8, 4, 2, 1``.

        Returns:
            The last element of the network's ``merged`` output.
        """
        imgs = torch.cat((img0, img1), 1)
        scale_list = [base / scale for base in (16, 8, 4, 2, 1)]
        flow, mask, merged = self.flownet(imgs, timestep, scale_list)
        return merged[-1]

    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        """Run one training or evaluation step.

        Args:
            imgs: Input frames concatenated along dim 1.
            gt: Ground-truth intermediate frame.
            learning_rate: Learning rate written into every optimizer group.
            mul: Unused; kept for interface compatibility.
            training: When true, back-propagate and step the optimizer.
            flow_gt: Unused; kept for interface compatibility.

        Returns:
            ``(prediction, info)`` where ``info`` carries ``mask``, the first
            two flow channels, and the ``loss_l1`` / ``loss_cons`` /
            ``loss_smooth`` values.
        """
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        if training:
            self.train()
        else:
            self.eval()
        scale = [16, 8, 4, 2, 1]
        # The network's keyword is ``scale_list``; the previous ``scale=``
        # keyword raised TypeError.
        flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale_list=scale, training=training)
        loss_l1 = (merged[-1] - gt).abs().mean()
        loss_smooth = self.sobel(flow[-1], flow[-1] * 0).mean()
        # The network returns no consistency loss, and ``loss_cons`` used to
        # be an undefined name here; keep a zero placeholder so the reported
        # dictionary keeps its shape.
        loss_cons = loss_l1.new_zeros(())
        # loss_vgg = self.vgg(merged[-1], gt)
        if training:
            self.optimG.zero_grad()
            loss_G = loss_l1 + loss_cons + loss_smooth * 0.1
            loss_G.backward()
            self.optimG.step()
        return merged[-1], {
            'mask': mask,
            'flow': flow[-1][:, :2],
            'loss_l1': loss_l1,
            'loss_cons': loss_cons,
            'loss_smooth': loss_smooth,
        }
Loading