Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Keep results on CPU to avoid OOM #148

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions audio_separator/separator/architectures/mdxc_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def load_model(self):
raise ValueError("Unknown Roformer model type in the configuration.")

# Load model checkpoint
checkpoint = torch.load(self.model_path, map_location="cpu")
checkpoint = torch.load(self.model_path, map_location="cpu", weights_only=True)
self.model_run = model if not isinstance(model, torch.nn.DataParallel) else model.module
self.model_run.load_state_dict(checkpoint)
self.model_run.to(self.torch_device).eval()
Expand Down Expand Up @@ -198,8 +198,6 @@ def overlap_add(self, result, x, weights, start, length):
"""
Adds the overlapping part of the result to the result tensor.
"""
x = x.to(result.device)
weights = weights.to(result.device)
result[..., start : start + length] += x[..., :length] * weights[:length]
return result

Expand Down Expand Up @@ -246,13 +244,11 @@ def demix(self, mix: np.ndarray) -> dict:

device = next(self.model_run.parameters()).device

# Move the weighting window to the same device as the other tensors
window = window.to(device)

with torch.no_grad():
req_shape = (len(self.model_data_cfgdict.training.instruments),) + tuple(mix.shape)
result = torch.zeros(req_shape, dtype=torch.float32).to(device)
counter = torch.zeros(req_shape, dtype=torch.float32).to(device)
result = torch.zeros(req_shape, dtype=torch.float32)
counter = torch.zeros(req_shape, dtype=torch.float32)

for i in tqdm(range(0, mix.shape[1], step)):
part = mix[:, i : i + chunk_size]
Expand All @@ -262,8 +258,10 @@ def demix(self, mix: np.ndarray) -> dict:
length = chunk_size
part = part.to(device)
x = self.model_run(part.unsqueeze(0))[0]
x = x.cpu()
# Perform overlap_add on CPU
if i + chunk_size > mix.shape[1]:
# Corrigido para adicionar corretamente ao final do tensor
# Fixed to correctly add to the end of the tensor
result = self.overlap_add(result, x, window, result.shape[-1] - chunk_size, length)
counter[..., result.shape[-1] - chunk_size :] += window[:length]
else:
Expand Down Expand Up @@ -311,7 +309,6 @@ def demix(self, mix: np.ndarray) -> dict:
# It starts as a tensor of zeros and is updated in-place as the model processes each batch.
# The variable holds the combined result of all processed batches, which, after post-processing, represents the separated audio sources.
accumulated_outputs = torch.zeros(num_stems, *mix.shape) if num_stems > 1 else torch.zeros_like(mix)
accumulated_outputs = accumulated_outputs.to(self.torch_device)

with torch.no_grad():
count = 0
Expand All @@ -324,7 +321,9 @@ def demix(self, mix: np.ndarray) -> dict:
# Since single_batch_result can contain multiple output tensors (one for each piece of audio in the batch),
# individual_output is used to iterate through these tensors and accumulate them into accumulated_outputs.
for individual_output in single_batch_result:
accumulated_outputs[..., count * hop_size : count * hop_size + chunk_size] += individual_output
individual_output_cpu = individual_output.cpu()
# Accumulate outputs on CPU
accumulated_outputs[..., count * hop_size : count * hop_size + chunk_size] += individual_output_cpu
count += 1

self.logger.debug("Calculating inferenced outputs based on accumulated outputs and overlap")
Expand Down
Loading