diff --git a/src/core/operators/audio_cnn_model/config.py b/src/core/operators/audio_cnn_model/config.py index 00acb029..f7d93d6b 100644 --- a/src/core/operators/audio_cnn_model/config.py +++ b/src/core/operators/audio_cnn_model/config.py @@ -3,6 +3,7 @@ # import numpy as np import csv from pathlib import Path +import wget sample_rate = 32000 @@ -11,11 +12,13 @@ # Download labels if not exist if not os.path.isfile(labels_csv_path): os.makedirs(os.path.dirname(labels_csv_path), exist_ok=True) - os.system( - 'wget -O "{}" "http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv"'.format( - labels_csv_path - ) - ) + # os.system( + # 'wget -O "{}" "http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv"'.format( + # labels_csv_path + # ) + # ) + dl_path = "http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv" + wget.download(dl_path, out=labels_csv_path) # Load label with open(labels_csv_path, "r") as f: diff --git a/src/core/operators/audio_cnn_model/inference.py b/src/core/operators/audio_cnn_model/inference.py index baaa2541..fbfac149 100644 --- a/src/core/operators/audio_cnn_model/inference.py +++ b/src/core/operators/audio_cnn_model/inference.py @@ -6,6 +6,7 @@ # import matplotlib.pyplot as plt import torch from pathlib import Path +import wget from .pytorch_utils import move_data_to_device from .models import Cnn14 # , Cnn14_DecisionLevelMax @@ -39,7 +40,8 @@ def __init__(self, model=None, checkpoint_path=None, device="cuda"): ): create_folder(os.path.dirname(checkpoint_path)) zenodo_path = "https://github.com/tattle-made/feluda/releases/download/third-party-models/Cnn14_mAP.0.431.pth" - os.system('wget -O "{}" "{}"'.format(checkpoint_path, zenodo_path)) + # os.system('wget -O "{}" "{}"'.format(checkpoint_path, zenodo_path)) + wget.download(zenodo_path, out=checkpoint_path) # script_dir = os.path.dirname(os.path.abspath(__file__)) # checkpoint_path = os.path.join(script_dir, 'panns_data', 'Cnn14_mAP=0.431.pth')