From a581da750a61e5ab25f70cc026bd296df61944c8 Mon Sep 17 00:00:00 2001 From: Andrew Beveridge Date: Sat, 25 May 2024 00:48:26 -0400 Subject: [PATCH] Fixed GPU mapping for roformer models --- audio_separator/separator/architectures/mdxc_separator.py | 6 +++--- pyproject.toml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/audio_separator/separator/architectures/mdxc_separator.py b/audio_separator/separator/architectures/mdxc_separator.py index 505578a..fe2efa4 100644 --- a/audio_separator/separator/architectures/mdxc_separator.py +++ b/audio_separator/separator/architectures/mdxc_separator.py @@ -94,7 +94,7 @@ def load_model(self): checkpoint = torch.load(self.model_path, map_location="cpu") self.model_run = model if not isinstance(model, torch.nn.DataParallel) else model.module self.model_run.load_state_dict(checkpoint) - self.model_run.to(self.torch_device_cpu).eval() + self.model_run.to(self.torch_device).eval() else: self.logger.debug("Loading TFC_TDF_net model...") @@ -183,8 +183,8 @@ def overlap_add(self, result, x, weights, start, length): """ Adds the overlapping part of the result to the result tensor. """ - if self.torch_device == "mps": - x = x.to(self.torch_device_cpu) + x = x.to(result.device) + weights = weights.to(result.device) result[..., start : start + length] += x[..., :length] * weights[:length] return result diff --git a/pyproject.toml b/pyproject.toml index 031afdd..c816818 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "audio-separator" -version = "0.17.1" +version = "0.17.2" description = "Easy to use audio stem separation, using various models from UVR trained primarily by @Anjok07" authors = ["Andrew Beveridge "] license = "MIT"