Skip to content

Commit

Permalink
Fixed bug with VR51 models, added script to calculate model hashes and params for all model files in a directory
Browse files Browse the repository at this point in the history
  • Loading branch information
beveradb committed Feb 5, 2024
1 parent 7144165 commit ba492d8
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 4 deletions.
4 changes: 2 additions & 2 deletions audio_separator/separator/uvr_lib_v5/vr_network/layers_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,10 @@ def forward(self, input_tensor):
# Extract features and prepare for LSTM
hidden = self.conv(input_tensor)[:, 0] # N, nbins, nframes
hidden = hidden.permute(2, 0, 1) # nframes, N, nbins
h, _ = self.lstm(h)
hidden, _ = self.lstm(hidden)

# Apply dense layer and reshape to match expected output format
hidden = self.dense(h.reshape(-1, hidden.size()[-1])) # nframes * N, nbins
hidden = self.dense(hidden.reshape(-1, hidden.size()[-1])) # nframes * N, nbins
hidden = hidden.reshape(nframes, N, 1, nbins)
hidden = hidden.permute(1, 2, 3, 0)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def forward(self, input_tensor):
aux2 = torch.cat([l2, h2], dim=2)

# Prepare input for the third stage by concatenating all previous outputs with the original input.
f3_in = torch.cat([x, aux1, aux2], dim=1)
f3_in = torch.cat([input_tensor, aux1, aux2], dim=1)

# Process through the third stage network.
f3 = self.stg3_full_band_net(f3_in)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "audio-separator"
version = "0.14.1"
version = "0.14.2"
description = "Easy to use vocal separation, using MDX-Net models from UVR trained by @Anjok07"
authors = ["Andrew Beveridge <[email protected]>"]
license = "MIT"
Expand Down
98 changes: 98 additions & 0 deletions tools/calculate-model-hashes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#!/usr/bin/env python3

import os
import sys
import json
import hashlib
import requests

MODEL_CACHE_PATH = "/tmp/audio-separator-models"
VR_MODEL_DATA_LOCAL_PATH = f"{MODEL_CACHE_PATH}/vr_model_data.json"
MDX_MODEL_DATA_LOCAL_PATH = f"{MODEL_CACHE_PATH}/mdx_model_data.json"

MODEL_DATA_URL_PREFIX = "https://raw.githubusercontent.com/TRvlvr/application_data/main"
VR_MODEL_DATA_URL = f"{MODEL_DATA_URL_PREFIX}/vr_model_data/model_data_new.json"
MDX_MODEL_DATA_URL = f"{MODEL_DATA_URL_PREFIX}/mdx_model_data/model_data_new.json"

OUTPUT_PATH = f"{MODEL_CACHE_PATH}/model_hashes.json"


def get_model_hash(model_path):
    """
    Return the MD5 hex digest identifying a model file.

    Large models are fingerprinted by hashing only their final
    10000 * 1024 bytes (just under 10 MiB), matching the hashing scheme
    used by the upstream UVR model registry; files smaller than that
    are hashed in full.

    :param model_path: Path to the model file (e.g. .pth or .onnx).
    :return: MD5 hex digest string.
    """
    with open(model_path, "rb") as f:
        try:
            # Seek to 10000 * 1024 bytes before EOF; raises OSError
            # (IOError is an alias) when the file is smaller than that.
            f.seek(-10000 * 1024, 2)
        except OSError:
            # Small file: hash the whole thing from the start.
            f.seek(0)
        return hashlib.md5(f.read()).hexdigest()


def download_file_if_missing(url, local_path):
    """
    Download a file from a URL to local_path unless it already exists.

    Creates the destination's parent directory if needed, streams the
    response in 8 KiB chunks, and raises requests.HTTPError on a bad
    status code.

    :param url: Source URL to fetch.
    :param local_path: Destination file path on disk.
    """
    print(f"Checking if {local_path} needs to be downloaded from {url}")
    if os.path.exists(local_path):
        print(f"{local_path} already exists. Skipping download.")
        return

    # Ensure the destination directory exists before opening the file;
    # without this the open() below fails on a fresh machine.
    os.makedirs(os.path.dirname(local_path) or ".", exist_ok=True)

    print(f"Downloading {url} to {local_path}")
    with requests.get(url, stream=True, timeout=10) as r:
        r.raise_for_status()
        with open(local_path, "wb") as f:
            # Stream in chunks to avoid holding the whole file in memory.
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
    print(f"Downloaded {url} to {local_path}")


def load_json_data(file_path):
    """
    Read and parse a JSON file, exiting the process if it is missing.

    :param file_path: Path of the JSON document to parse.
    :return: The deserialized JSON content.
    """
    print(f"Loading JSON data from {file_path}")
    try:
        with open(file_path, "r", encoding="utf-8") as handle:
            parsed = json.load(handle)
    except FileNotFoundError:
        # A missing model-data file is fatal for this tool.
        print(f"{file_path} not found.")
        sys.exit(1)
    print(f"Loaded JSON data successfully from {file_path}")
    return parsed


def iterate_and_hash(directory):
    """
    Hash every .pth/.onnx model under a directory and write the results.

    Fetches the upstream VR and MDX model-data JSON files if they are
    not cached locally, looks up each model's parameters by its hash,
    and writes a JSON list of {file, hash, params} records to
    OUTPUT_PATH.

    :param directory: Root directory to scan recursively for models.
    """
    print(f"Iterating through directory {directory} to hash model files")
    model_files = []
    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith((".pth", ".onnx")):
                model_files.append((file, os.path.join(root, file)))

    download_file_if_missing(VR_MODEL_DATA_URL, VR_MODEL_DATA_LOCAL_PATH)
    download_file_if_missing(MDX_MODEL_DATA_URL, MDX_MODEL_DATA_LOCAL_PATH)

    # Merge both registries into a single hash -> params lookup table;
    # MDX entries win on (unlikely) duplicate hashes, as in the original.
    combined_model_params = {
        **load_json_data(VR_MODEL_DATA_LOCAL_PATH),
        **load_json_data(MDX_MODEL_DATA_LOCAL_PATH),
    }

    model_info_list = []
    for file, file_path in sorted(model_files):
        file_hash = get_model_hash(file_path)
        model_info_list.append(
            {"file": file, "hash": file_hash, "params": combined_model_params.get(file_hash, "Parameters not found")}
        )

    print(f"Writing model info list to {OUTPUT_PATH}")
    with open(OUTPUT_PATH, "w", encoding="utf-8") as json_file:
        json.dump(model_info_list, json_file, indent=4)
    print(f"Successfully wrote model info list to {OUTPUT_PATH}")


if __name__ == "__main__":
    # Script entry point: hash every model in the shared cache directory
    # (downloads the upstream model-data JSON files on first run).
    iterate_and_hash(MODEL_CACHE_PATH)

0 comments on commit ba492d8

Please sign in to comment.