From 51eddc7992a4c39430a4f0ae3499e1d8efb4963d Mon Sep 17 00:00:00 2001 From: Eric Mikulin Date: Wed, 4 Dec 2024 09:50:37 -0800 Subject: [PATCH] v2.1 Rust (#367) --- binding/rust/Cargo.toml | 9 +-- binding/rust/README.md | 17 +++++ binding/rust/tests/cheetah_tests.rs | 105 ++++++++++++++++++---------- 3 files changed, 92 insertions(+), 39 deletions(-) diff --git a/binding/rust/Cargo.toml b/binding/rust/Cargo.toml index 935cf002..16057dfa 100644 --- a/binding/rust/Cargo.toml +++ b/binding/rust/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pv_cheetah" -version = "2.0.3" +version = "2.1.0" edition = "2018" description = "The Rust bindings for Picovoice's Cheetah library" license = "Apache-2.0" @@ -27,10 +27,11 @@ crate_type = ["lib"] [dependencies] libc = "0.2" -libloading = "0.7" +libloading = "0.8" [dev-dependencies] distance = "0.4.0" -itertools = "0.10" -rodio = "0.15" +itertools = "0.11" +rodio = "0.17" serde_json = "1.0" +serde = { version = "1.0", features = ["derive"] } diff --git a/binding/rust/README.md b/binding/rust/README.md index 87b8d2f8..91e19bd2 100644 --- a/binding/rust/README.md +++ b/binding/rust/README.md @@ -71,6 +71,23 @@ Replace `${ACCESS_KEY}` with yours obtained from [Picovoice Console](https://con The model file contains the parameters for the Cheetah engine. You may create bespoke language models using [Picovoice Console](https://console.picovoice.ai/) and then pass in the relevant file. +### Language Model + +The Cheetah Rust SDK comes preloaded with a default English language model (`.pv` file). +Default models for other supported languages can be found in [lib/common](../../lib/common). + +Create custom language models using the [Picovoice Console](https://console.picovoice.ai/). Here you can train +language models with custom vocabulary and boost words in the existing vocabulary. 
+ +Pass in the `.model_path()` Builder argument: +```rust +let cheetah: Cheetah = CheetahBuilder::new() + .access_key("${ACCESS_KEY}") + .model_path("${MODEL_FILE_PATH}") + .init() + .expect("Unable to create Cheetah"); +``` + ## Demos The [Cheetah Rust demo project](https://github.com/Picovoice/cheetah/tree/master/demo/rust) is a Rust console app that allows for processing real-time audio (i.e. microphone) and files using Cheetah. diff --git a/binding/rust/tests/cheetah_tests.rs b/binding/rust/tests/cheetah_tests.rs index c1727d34..30545a95 100644 --- a/binding/rust/tests/cheetah_tests.rs +++ b/binding/rust/tests/cheetah_tests.rs @@ -1,5 +1,5 @@ /* - Copyright 2022-2023 Picovoice Inc. + Copyright 2022-2024 Picovoice Inc. You may not use this file except in compliance with the license. A copy of the license is located in the "LICENSE" file accompanying this source. @@ -14,22 +14,59 @@ mod tests { use distance::*; use itertools::Itertools; use rodio::{source::Source, Decoder}; - use serde_json::{json, Value}; + use serde::Deserialize; use std::env; - use std::fs::File; + use std::fs::{read_to_string, File}; use std::io::BufReader; use cheetah::CheetahBuilder; - fn load_test_data() -> Value { - let test_json: Value = json!([{ - "language": "en", - "transcript": "Mr quilter is the apostle of the middle classes and we are glad to welcome his gospel", - "transcript_with_punctuation": "Mr. 
Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", - "error_rate": 0.025, - "audio_file": "test.wav" - }]); - test_json + #[derive(Debug, Deserialize)] + struct LanguageTestJson { + language: String, + audio_file: String, + transcript: String, + punctuations: Vec<String>, + error_rate: f32, + } + + #[derive(Debug, Deserialize)] + struct TestsJson { + language_tests: Vec<LanguageTestJson>, + } + + #[derive(Debug, Deserialize)] + struct RootJson { + tests: TestsJson, + } + + fn load_test_data() -> TestsJson { + let test_json_path = format!( + "{}{}", + env!("CARGO_MANIFEST_DIR"), + "/../../resources/.test/test_data.json" + ); + let contents: String = + read_to_string(test_json_path).expect("Unable to read test_data.json"); + let root: RootJson = serde_json::from_str(&contents).expect("Failed to parse JSON"); + root.tests + } + + fn append_lang(path: &str, language: &str) -> String { + if language == "en" { + String::from(path) + } else { + format!("{}_{}", path, language) + } + } + + fn model_path_by_language(language: &str) -> String { + format!( + "{}{}{}", + env!("CARGO_MANIFEST_DIR"), + append_lang("/../../lib/common/cheetah_params", language), + ".pv" + ) } fn character_error_rate(transcript: &str, expected_transcript: &str) -> f32 { @@ -38,7 +75,7 @@ mod tests { } fn run_test_process( - _: &str, + language: &str, transcript: &str, test_punctuation: bool, error_rate: f32, @@ -47,6 +84,8 @@ mod tests { let access_key = env::var("PV_ACCESS_KEY") .expect("Pass the AccessKey in using the PV_ACCESS_KEY env variable"); + let model_path = model_path_by_language(language); + let audio_path = format!( "{}{}{}", env!("CARGO_MANIFEST_DIR"), @@ -59,6 +98,7 @@ mod tests { let cheetah = CheetahBuilder::new() .access_key(access_key) + .model_path(model_path) .enable_automatic_punctuation(test_punctuation) .init() .expect("Unable to create Cheetah"); @@ -82,42 +122,37 @@ mod tests { #[test] fn test_process() { - let test_json: Value = load_test_data(); - - for t in 
test_json.as_array().unwrap() { - let language = t["language"].as_str().unwrap(); - let transcript = t["transcript"].as_str().unwrap(); - let error_rate = t["error_rate"].as_f64().unwrap() as f32; + let test_json: TestsJson = load_test_data(); - let test_audio = t["audio_file"].as_str().unwrap(); + for t in test_json.language_tests { + let mut transcript = t.transcript; + for p in t.punctuations { + transcript = transcript.replace(&p, "") + } run_test_process( - language, - transcript, + &t.language, + &transcript, false, - error_rate, - &test_audio, + t.error_rate, + &t.audio_file, ); } } #[test] fn test_process_punctuation() { - let test_json: Value = load_test_data(); - - for t in test_json.as_array().unwrap() { - let language = t["language"].as_str().unwrap(); - let transcript_with_punctuation = t["transcript_with_punctuation"].as_str().unwrap(); - let error_rate = t["error_rate"].as_f64().unwrap() as f32; + let test_json: TestsJson = load_test_data(); - let test_audio = t["audio_file"].as_str().unwrap(); + for t in test_json.language_tests { + let transcript = t.transcript; run_test_process( - language, - transcript_with_punctuation, + &t.language, + &transcript, true, - error_rate, - &test_audio, + t.error_rate, + &t.audio_file, ); } }