* Fix an issue when the providers setting is set to CPU.
* Lay the groundwork for supporting new providers.
* Optimize CLIPSwitch inference to work entirely on CUDA tensors or CPU tensors.
Alucard24 committed Sep 23, 2024
1 parent da0660b commit 1e92736
Showing 8 changed files with 534 additions and 367 deletions.
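The provider groundwork boils down to a small pattern: pick an execution-provider priority list from the user-facing mode, keeping only what the local ONNX Runtime build actually offers. A minimal sketch under stated assumptions, not code from this commit: MODE_TO_PROVIDERS and pick_providers are hypothetical names; onnxruntime.get_available_providers() and torch.cuda.is_available() are real APIs.

import onnxruntime
import torch

# Hypothetical mapping from the UI mode labels used in rope/Dicts.py
# ('CUDA', 'TensorRT', 'CPU', ...) to ONNX Runtime execution providers.
MODE_TO_PROVIDERS = {
    'CUDA': ['CUDAExecutionProvider', 'CPUExecutionProvider'],
    'TensorRT': ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'],
    'CPU': ['CPUExecutionProvider'],
}

def pick_providers(mode):
    # Keep only providers present in this onnxruntime build, falling
    # back to CPU when nothing on the priority list is available.
    available = set(onnxruntime.get_available_providers())
    chosen = [p for p in MODE_TO_PROVIDERS.get(mode, []) if p in available]
    return chosen or ['CPUExecutionProvider']

device = 'cuda' if torch.cuda.is_available() else 'cpu'
providers = pick_providers('CUDA' if device == 'cuda' else 'CPU')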
27 changes: 10 additions & 17 deletions rope/Coordinator.py
@@ -2,12 +2,11 @@
 import time
 import torch
-from torchvision import transforms
 
 import rope.GUI as GUI
 import rope.VideoManager as VM
 import rope.Models as Models
-from rope.external.clipseg import CLIPDensePredT
+from rope.Dicts import DEFAULT_DATA
 
 resize_delay = 1
 mem_delay = 1
@@ -104,10 +103,6 @@ def coordinator():
             action.pop(0)
 
         elif action [0][0] == "parameters":
-            if action[0][1]["CLIPSwitch"]:
-                if not vm.clip_session:
-                    vm.clip_session = load_clip_model()
-
             vm.parameters = action[0][1]
             action.pop(0)
 
@@ -196,20 +191,18 @@ def coordinator():
     gui.after(1, coordinator)
     # print(time.time() - start)
 
-def load_clip_model():
-    # https://github.com/timojl/clipseg
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    clip_session = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
-    # clip_session = CLIPDensePredTMasked(version='ViT-B/16', reduce_dim=64)
-    clip_session.eval();
-    clip_session.load_state_dict(torch.load('./models/rd64-uni-refined.pth'), strict=False)
-    clip_session.to(device)
-    return clip_session
-
 def run():
     global gui, vm, action, frame, r_frame, resize_delay, mem_delay
 
-    models = Models.Models()
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    if device == "cuda":
+        DEFAULT_DATA['ProvidersPriorityTextSelMode'] = 'CUDA'
+        DEFAULT_DATA['ProvidersPriorityTextSelModes'] = ['CUDA', 'TensorRT', 'TensorRT-Engine', 'CPU']
+    else:
+        DEFAULT_DATA['ProvidersPriorityTextSelMode'] = 'CPU'
+        DEFAULT_DATA['ProvidersPriorityTextSelModes'] = ['CPU']
+
+    models = Models.Models(device=device)
     gui = GUI.GUI(models)
     vm = VM.VideoManager(models)
 
20 changes: 13 additions & 7 deletions rope/DFMModel.py
@@ -7,10 +7,13 @@
 onnxruntime.log_verbosity_level = -1
 
 class DFMModel:
-    def __init__(self, model_path: str, providers, device=None):
+    def __init__(self, model_path: str, providers, device='cuda'):
 
         self._model_path = model_path
         self.providers = providers
+        self.device = device
+        self.syncvec = torch.empty((1, 1), dtype=torch.float32, device=device)
+
         sess = self._sess = onnxruntime.InferenceSession(str(model_path), providers=self.providers)
         inputs = sess.get_inputs()
 
@@ -84,12 +87,12 @@ def convert(self, img, morph_factor=0.75, rct=False):
         io_binding = self._sess.io_binding()
 
         # Bind input image tensor
-        io_binding.bind_input(name='in_face:0', device_type='cuda', device_id=0, element_type=np.float32, shape=img.shape, buffer_ptr=img.data_ptr())
+        io_binding.bind_input(name='in_face:0', device_type=self.device, device_id=0, element_type=np.float32, shape=img.shape, buffer_ptr=img.data_ptr())
 
         # Bind morph factor if the model supports it
         if self._model_type == 2:
-            morph_factor_t = torch.tensor([morph_factor], dtype=torch.float32, device='cuda')
-            io_binding.bind_input(name='morph_value:0', device_type='cuda', device_id=0, element_type=np.float32, shape=morph_factor_t.shape, buffer_ptr=morph_factor_t.data_ptr())
+            morph_factor_t = torch.tensor([morph_factor], dtype=torch.float32, device=self.device)
+            io_binding.bind_input(name='morph_value:0', device_type=self.device, device_id=0, element_type=np.float32, shape=morph_factor_t.shape, buffer_ptr=morph_factor_t.data_ptr())
 
         # Prepare output tensors and bind them
         outputs = self._sess.get_outputs()
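The hunk above generalizes to a reusable idiom: a torch tensor already on the inference device is handed to ONNX Runtime by raw pointer, so no host copy happens. A minimal sketch assuming a contiguous float32 tensor and an existing onnxruntime.InferenceSession; bind_torch_input is an illustrative helper, not part of the repository.

import numpy as np
import torch

def bind_torch_input(sess, name, tensor, device):
    # device_type must match where `tensor` actually lives ('cuda' or 'cpu');
    # buffer_ptr hands ONNX Runtime a view into the torch-owned buffer.
    io = sess.io_binding()
    io.bind_input(
        name=name,
        device_type=device,
        device_id=0,
        element_type=np.float32,
        shape=tuple(tensor.shape),
        buffer_ptr=tensor.data_ptr(),
    )
    return io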
@@ -101,23 +104,26 @@ def convert(self, img, morph_factor=0.75, rct=False):
 
             # Create a torch tensor with the shape and dtype of the output
             torch_dtype = self.onnx_to_torch_dtype[output.type]
-            tensor_output = torch.empty(shape, dtype=torch_dtype, device='cuda').contiguous()
+            tensor_output = torch.empty(shape, dtype=torch_dtype, device=self.device).contiguous()
 
             # Append the tensor to the list
             binding_outputs.append(tensor_output)
 
             # Bind the output using ONNX Runtime's io_binding
             io_binding.bind_output(
                 name=output.name,
-                device_type='cuda',
+                device_type=self.device,
                 device_id=0,
                 element_type=self.onnx_to_numpy_dtype[output.type],  # Use NumPy dtype for element_type
                 shape=shape,
                 buffer_ptr=binding_outputs[idx].data_ptr()
             )
 
         # Run the model
-        torch.cuda.synchronize()
+        if self.device == "cuda":
+            torch.cuda.synchronize()
+        elif self.device != "cpu":
+            self.syncvec.cpu()
         self._sess.run_with_iobinding(io_binding)
 
         # Process outputs (resize, clip channels, and convert back to original dtype)
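The synchronization branch matters because run_with_iobinding consumes raw pointers: any queued GPU work that still has to produce the bound input must complete first. A standalone sketch of the same pattern, assuming syncvec is a tiny tensor pre-allocated on the inference device (as in __init__ above); sync_device is an illustrative name.

import torch

def sync_device(device, syncvec):
    if device == 'cuda':
        # Native CUDA path: block until all queued kernels finish.
        torch.cuda.synchronize()
    elif device != 'cpu':
        # Backends with no explicit synchronize call: copying a tiny tensor
        # to the host forces pending work on its device to complete first.
        syncvec.cpu()
    # Pure CPU execution is already synchronous; nothing to do.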
13 changes: 10 additions & 3 deletions rope/EngineBuilder.py
@@ -3,16 +3,23 @@
 import logging
 import platform
 import ctypes
-import tensorrt as trt
 import numpy as np
 from pathlib import Path
 
+try:
+    import tensorrt as trt
+except ModuleNotFoundError:
+    pass
+
 logging.basicConfig(level=logging.INFO)
 logging.getLogger("EngineBuilder").setLevel(logging.INFO)
 log = logging.getLogger("EngineBuilder")
 
-# Create a global instance of the TensorRT logger
-TRT_LOGGER = trt.Logger(trt.Logger.INFO)
+if 'trt' in globals():
+    # Create a global instance of the TensorRT logger
+    TRT_LOGGER = trt.Logger(trt.Logger.INFO)
+else:
+    TRT_LOGGER = {}
 
 # imported from https://github.com/warmshao/FasterLivePortrait/blob/master/scripts/onnx2trt.py
 # adjusted to work with TensorRT 10.3.0
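An equivalent way to write the optional-import guard above is to record availability in a flag instead of probing globals(). A sketch under the assumption that a None logger is acceptable to downstream checks; TRT_AVAILABLE is an illustrative name.

import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("EngineBuilder")

try:
    import tensorrt as trt
    TRT_AVAILABLE = True
except ModuleNotFoundError:
    trt = None
    TRT_AVAILABLE = False

# Create the global TensorRT logger only when the module imported.
TRT_LOGGER = trt.Logger(trt.Logger.INFO) if TRT_AVAILABLE else None

if not TRT_AVAILABLE:
    log.info("TensorRT not installed; engine building is disabled.")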
8 changes: 5 additions & 3 deletions rope/FaceUtil.py
@@ -53,8 +53,10 @@
     [41.5493, 92.3655],
     [70.7299, 92.2041]],
     dtype=torch.float32,
-    device='cuda'
 
 ) # Shape: (5, 2)
+if torch.cuda.is_available():
+    arcface_src_cuda = arcface_src_cuda.to('cuda')
+
 def pad_image_by_size(img, image_size):
     # If image_size is not a tuple, create a tuple with equal height and width
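The hunk above applies a general rule for module-level constants: allocate on the CPU so the module imports cleanly on CPU-only machines, then migrate to CUDA only when it exists. A minimal sketch with an illustrative two-point tensor:

import torch

# Build the constant on the default CPU device first...
template = torch.tensor([[41.5493, 92.3655],
                         [70.7299, 92.2041]], dtype=torch.float32)
# ...then move it to the GPU only if one is actually present.
if torch.cuda.is_available():
    template = template.to('cuda')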
@@ -300,7 +302,7 @@ def align_crop(img, lmk, image_size, mode='arcfacemap', interpolation=v2.Interpo
         borderMode=cv2.BORDER_REPLICATE,
     )
     '''
-    warped = warp_affine_torchvision(img, matrix, (image_size, image_size), rotation_ratio=57.2958, border_value=0.0, border_mode='replicate', interpolation_value=v2.functional.InterpolationMode.NEAREST, device='cuda')
+    warped = warp_affine_torchvision(img, matrix, (image_size, image_size), rotation_ratio=57.2958, border_value=0.0, border_mode='replicate', interpolation_value=v2.functional.InterpolationMode.NEAREST, device=img.device)
 
     return warped, matrix
 
@@ -460,7 +462,7 @@ def warp_face_by_bounding_box_for_landmark_68(img, bbox, input_size):
     if torch.mean(crop_image.to(dtype=torch.float32)[0, :, :]) < 30:
         crop_image = cv2.cvtColor(crop_image.permute(1, 2, 0).to('cpu').numpy(), cv2.COLOR_RGB2Lab)
         crop_image[:, :, 0] = cv2.createCLAHE(clipLimit = 2).apply(crop_image[:, :, 0])
-        crop_image = torch.from_numpy(cv2.cvtColor(crop_image, cv2.COLOR_Lab2RGB)).to('cuda').permute(2, 0, 1)
+        crop_image = torch.from_numpy(cv2.cvtColor(crop_image, cv2.COLOR_Lab2RGB)).to(img.device).permute(2, 0, 1)
 
     return crop_image, affine_matrix
 
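Both FaceUtil hunks replace a hard-coded 'cuda' with the input tensor's own device, so intermediate results land wherever the caller's data already lives. A sketch of the idiom; to_same_device is an illustrative helper, not repository code.

import numpy as np
import torch

def to_same_device(result_np, like):
    # Place a NumPy result on the same device as the reference tensor.
    return torch.from_numpy(result_np).to(like.device)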
9 changes: 3 additions & 6 deletions rope/GUI.py
@@ -2109,13 +2109,13 @@ def load_input_faces(self):
             img = cv2.imread(file)
 
             if img is not None:
-                img = torch.from_numpy(img.astype('uint8')).to('cuda')
+                img = torch.from_numpy(img.astype('uint8')).to(self.models.device)
 
                 pad_scale = 0.2
                 padded_width = int(img.size()[1]*(1.+pad_scale))
                 padded_height = int(img.size()[0]*(1.+pad_scale))
 
-                padding = torch.zeros((padded_height, padded_width, 3), dtype=torch.uint8, device='cuda:0')
+                padding = torch.zeros((padded_height, padded_width, 3), dtype=torch.uint8, device=self.models.device)
 
                 width_start = int(img.size()[1]*pad_scale/2)
                 width_end = width_start+int(img.size()[1])
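These GUI changes rely on one convention: the Models instance owns the device string chosen in run(), and callers read self.models.device instead of assuming CUDA. An illustrative stub of that convention, not the repository's Models class:

class Models:
    def __init__(self, device='cpu'):
        # Single source of truth for where tensors are allocated.
        self.device = device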
@@ -2164,7 +2164,7 @@ def load_input_faces(self):
 
     def find_faces(self):
         try:
-            img = torch.from_numpy(self.video_image).to('cuda')
+            img = torch.from_numpy(self.video_image).to(self.models.device)
             img = img.permute(2,0,1)
             if self.parameters["AutoRotationSwitch"]:
                 rotation_angles = [0, 90, 180, 270]
@@ -2362,9 +2362,6 @@ def select_input_faces(self, event, button):
         self.add_action("target_faces", self.target_faces)
         self.add_action('get_requested_video_frame', self.video_slider.get())
 
-        # latent = torch.from_numpy(self.models.calc_swapper_latent(self.source_faces[button]['Embedding'])).float().to('cuda')
-        # face['ptrdata'] = self.models.run_swap_stg1(latent)
-
     def populate_target_videos(self):
         videos = []
         #Webcam setup
[Diffs for the remaining 3 changed files were not loaded.]
