Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Bebra777228 committed Dec 10, 2024
1 parent 8daef34 commit 1cc8c8d
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 18 deletions.
19 changes: 9 additions & 10 deletions PolUVR/separator/architectures/mdxc_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def load_model(self):
raise ValueError("Unknown Roformer model type in the configuration.")

# Load model checkpoint
checkpoint = torch.load(self.model_path, map_location="cpu")
checkpoint = torch.load(self.model_path, map_location="cpu", weights_only=True)
self.model_run = model if not isinstance(model, torch.nn.DataParallel) else model.module
self.model_run.load_state_dict(checkpoint)
self.model_run.to(self.torch_device).eval()
Expand Down Expand Up @@ -224,8 +224,6 @@ def overlap_add(self, result, x, weights, start, length):
"""
Adds the overlapping part of the result to the result tensor.
"""
x = x.to(result.device)
weights = weights.to(result.device)
result[..., start : start + length] += x[..., :length] * weights[:length]
return result

Expand Down Expand Up @@ -272,13 +270,11 @@ def demix(self, mix: np.ndarray) -> dict:

device = next(self.model_run.parameters()).device

# Transfer to the weighting plate for the same device as the other tensors
window = window.to(device)

with torch.no_grad():
req_shape = (len(self.model_data_cfgdict.training.instruments),) + tuple(mix.shape)
result = torch.zeros(req_shape, dtype=torch.float32).to(device)
counter = torch.zeros(req_shape, dtype=torch.float32).to(device)
result = torch.zeros(req_shape, dtype=torch.float32)
counter = torch.zeros(req_shape, dtype=torch.float32)

for i in tqdm(range(0, mix.shape[1], step)):
part = mix[:, i : i + chunk_size]
Expand All @@ -288,8 +284,10 @@ def demix(self, mix: np.ndarray) -> dict:
length = chunk_size
part = part.to(device)
x = self.model_run(part.unsqueeze(0))[0]
x = x.cpu()
# Perform overlap_add on CPU
if i + chunk_size > mix.shape[1]:
# Corrigido para adicionar corretamente ao final do tensor
# Fixed to correctly add to the end of the tensor
result = self.overlap_add(result, x, window, result.shape[-1] - chunk_size, length)
counter[..., result.shape[-1] - chunk_size :] += window[:length]
else:
Expand Down Expand Up @@ -337,7 +335,6 @@ def demix(self, mix: np.ndarray) -> dict:
# It starts as a tensor of zeros and is updated in-place as the model processes each batch.
# The variable holds the combined result of all processed batches, which, after post-processing, represents the separated audio sources.
accumulated_outputs = torch.zeros(num_stems, *mix.shape) if num_stems > 1 else torch.zeros_like(mix)
accumulated_outputs = accumulated_outputs.to(self.torch_device)

with torch.no_grad():
count = 0
Expand All @@ -350,7 +347,9 @@ def demix(self, mix: np.ndarray) -> dict:
# Since single_batch_result can contain multiple output tensors (one for each piece of audio in the batch),
# individual_output is used to iterate through these tensors and accumulate them into accumulated_outputs.
for individual_output in single_batch_result:
accumulated_outputs[..., count * hop_size : count * hop_size + chunk_size] += individual_output
individual_output_cpu = individual_output.cpu()
# Accumulate outputs on CPU
accumulated_outputs[..., count * hop_size : count * hop_size + chunk_size] += individual_output_cpu
count += 1

self.logger.debug("Calculating inferenced outputs based on accumulated outputs and overlap")
Expand Down
27 changes: 20 additions & 7 deletions PolUVR/utils/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,14 +460,22 @@ def update_stems(model):
batch_size = gr.Slider(minimum=1, maximum=16, step=1, value=1, label="Batch Size", info="Larger consumes more RAM but may process slightly faster.")

with gr.Accordion("Rename Stems", open=False):
gr.HTML("<h3> NAME - Input File Name </h3>")
gr.HTML("<h3> STEM - Stem Name (e.g., Vocals, Instrumental) </h3>")
gr.HTML("<h3> MODEL - Model Name (e.g., BS-Roformer-Viperx-1297) </h3>")
gr.HTML("<h3> Usage Example: NAME_(STEM)_MODEL </h3>")
gr.HTML("<h3> Output File Name: Music_(Vocals)_BS-Roformer-Viperx-1297 </h3>")
gr.Markdown(
"""
Keys for automatic determination of input file names, stems, and models to simplify the construction of output file names.
Keys:
* **NAME** - Input File Name
* **STEM** - Stem Name (e.g., Vocals, Instrumental)
* **MODEL** - Model Name (e.g., BS-Roformer-Viperx-1297)
Example:
* Usage: NAME_(STEM)_MODEL
* Output File Name: Music_(Vocals)_BS-Roformer-Viperx-1297
"""
)
with gr.Row():
vocals_stem = gr.Textbox(value="NAME_(STEM)_MODEL", label="Vocals Stem", info="Output example: Music_(Vocals)_BS-Roformer-Viperx-1297", placeholder="NAME_(STEM)_MODEL")
instrumental_stem = gr.Textbox(value="NAME_(STEM)_MODEL", label="Instrumental Stem", info="Пример вывода: Music_(Instrumental)_BS-Roformer-Viperx-1297", placeholder="NAME_(STEM)_MODEL")
instrumental_stem = gr.Textbox(value="NAME_(STEM)_MODEL", label="Instrumental Stem", info="Output example: Music_(Instrumental)_BS-Roformer-Viperx-1297", placeholder="NAME_(STEM)_MODEL")
other_stem = gr.Textbox(value="NAME_(STEM)_MODEL", label="Other Stem", info="Output example: Music_(Other)_BS-Roformer-Viperx-1297", placeholder="NAME_(STEM)_MODEL")
with gr.Row():
drums_stem = gr.Textbox(value="NAME_(STEM)_MODEL", label="Drums Stem", info="Output example: Music_(Drums)_BS-Roformer-Viperx-1297", placeholder="NAME_(STEM)_MODEL")
Expand Down Expand Up @@ -606,7 +614,12 @@ def update_stems(model):
)

def main():
app.launch(share=True, debug=True)
app.queue().launch(
share="--share" in sys.argv,
inbrowser="--open" in sys.argv,
debug=True,
show_error=True,
)

if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ numpy = ">=1.23,<2"
librosa = ">=0.10"
samplerate = "0.1.0"
six = ">=1.16"
torch = "*"
torch = ">=2.3,<2.5"
tqdm = "*"
pydub = ">=0.25"
onnx = ">=1.14"
Expand Down

0 comments on commit 1cc8c8d

Please sign in to comment.