Skip to content

Commit

Permalink
Merge pull request #175 from tattle-made/hotfix
Browse files Browse the repository at this point in the history
Hotfix
  • Loading branch information
duggalsu authored Mar 13, 2024
2 parents b6814b8 + 6b98b8f commit fc508fa
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
13 changes: 8 additions & 5 deletions src/core/operators/audio_cnn_model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# import numpy as np
import csv
from pathlib import Path
import wget

sample_rate = 32000

Expand All @@ -11,11 +12,13 @@
# Download labels if not exist
if not os.path.isfile(labels_csv_path):
os.makedirs(os.path.dirname(labels_csv_path), exist_ok=True)
os.system(
'wget -O "{}" "http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv"'.format(
labels_csv_path
)
)
# os.system(
# 'wget -O "{}" "http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv"'.format(
# labels_csv_path
# )
# )
dl_path = "http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv"
wget.download(dl_path, out=labels_csv_path)

# Load label
with open(labels_csv_path, "r") as f:
Expand Down
4 changes: 3 additions & 1 deletion src/core/operators/audio_cnn_model/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# import matplotlib.pyplot as plt
import torch
from pathlib import Path
import wget

from .pytorch_utils import move_data_to_device
from .models import Cnn14 # , Cnn14_DecisionLevelMax
Expand Down Expand Up @@ -39,7 +40,8 @@ def __init__(self, model=None, checkpoint_path=None, device="cuda"):
):
create_folder(os.path.dirname(checkpoint_path))
zenodo_path = "https://github.com/tattle-made/feluda/releases/download/third-party-models/Cnn14_mAP.0.431.pth"
os.system('wget -O "{}" "{}"'.format(checkpoint_path, zenodo_path))
# os.system('wget -O "{}" "{}"'.format(checkpoint_path, zenodo_path))
wget.download(zenodo_path, out=checkpoint_path)

# script_dir = os.path.dirname(os.path.abspath(__file__))
# checkpoint_path = os.path.join(script_dir, 'panns_data', 'Cnn14_mAP=0.431.pth')
Expand Down

0 comments on commit fc508fa

Please sign in to comment.