Skip to content

Commit

Permalink
Return probs in audio tagging onnx models
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Apr 10, 2024
1 parent fa5d861 commit 0d17cad
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
10 changes: 6 additions & 4 deletions egs/audioset/AT/zipformer/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def forward(
A 1-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a tensor containing:
- logits, A 2-D tensor of shape (N, num_classes)
- probs, A 2-D tensor of shape (N, num_classes)
"""
x, x_lens = self.encoder_embed(x, x_lens)
Expand All @@ -177,7 +177,8 @@ def forward(
# Note that this is slightly different from model.py for better
# support of onnx
logits = logits.mean(dim=1)
return logits
probs = logits.sigmoid()
return probs


def export_audio_tagging_model_onnx(
Expand Down Expand Up @@ -220,15 +221,16 @@ def export_audio_tagging_model_onnx(
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"logits": {0: "N"},
"probs": {0: "N"},
},
)

meta_data = {
"model_type": "zipformer2_at",
"model_type": "zipformer2",
"version": "1",
"model_author": "k2-fsa",
"comment": "zipformer2 audio tagger",
"url": "https://github.com/k2-fsa/icefall/tree/master/egs/audioset/AT/zipformer",
}
logging.info(f"meta_data: {meta_data}")

Expand Down
21 changes: 11 additions & 10 deletions egs/audioset/AT/zipformer/onnx_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@
Usage of this script:
repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
repo_url=https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09
repo=$(basename $repo_url)
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo/exp
git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
popd
for m in model.onnx model.int8.onnx; do
python3 zipformer/onnx_pretrained.py \
--model-filename $repo/exp/model.onnx \
--label-dict $repo/data/class_labels_indices.csv \
--model-filename $repo/model.onnx \
--label-dict $repo/class_labels_indices.csv \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav \
$repo/test_wavs/3.wav \
Expand Down Expand Up @@ -125,7 +125,7 @@ def __call__(
A 2-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a Tensor:
- logits, its shape is (N, num_classes)
- probs, its shape is (N, num_classes)
"""
out = self.model.run(
[
Expand Down Expand Up @@ -208,13 +208,14 @@ def main():
)

feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
logits = model(features, feature_lengths)
probs = model(features, feature_lengths)

for filename, logit in zip(args.sound_files, logits):
topk_prob, topk_index = logit.sigmoid().topk(5)
for filename, prob in zip(args.sound_files, probs):
topk_prob, topk_index = prob.topk(5)
topk_labels = [label_dict[index.item()] for index in topk_index]
logging.info(
f"{filename}: Top 5 predicted labels are {topk_labels} with probability of {topk_prob.tolist()}"
f"{filename}: Top 5 predicted labels are {topk_labels} with "
f"probability of {topk_prob.tolist()}"
)

logging.info("Decoding Done")
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dill
onnx>=1.15.0
onnxruntime>=1.16.3
onnxoptimizer
onnxsim

# style check session:
black==22.3.0
Expand Down

0 comments on commit 0d17cad

Please sign in to comment.