Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v2.1 dotnet #361

Merged
merged 7 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .github/workflows/dotnet-demos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ jobs:
- name: Package restore
run: dotnet restore

# ************** REMOVE AFTER RELEASE ********************
- name: Build Local Packages
run: dotnet build && dotnet pack -c Release
working-directory: binding/dotnet/Cheetah

- name: Install Local Packages
run: dotnet add package -s ../../../binding/dotnet/Cheetah/bin/Release Picovoice.Cheetah
# ********************************************************

- name: Dotnet build micdemo
run: dotnet build -c MicDemo.Release

Expand All @@ -60,6 +69,15 @@ jobs:
- name: Package restore
run: dotnet restore

# ************** REMOVE AFTER RELEASE ********************
- name: Build Local Packages
run: dotnet build && dotnet pack -c Release
working-directory: binding/dotnet/Cheetah

- name: Install Local Packages
run: dotnet add package -s ../../../binding/dotnet/Cheetah/bin/Release Picovoice.Cheetah
# ********************************************************

- name: Dotnet build micdemo
run: dotnet build -c MicDemo.Release

Expand Down
4 changes: 2 additions & 2 deletions binding/dotnet/Cheetah/Cheetah.csproj
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFrameworks>net8.0;net6.0;netcoreapp3.0;netstandard2.0</TargetFrameworks>
<Version>2.0.2</Version>
<TargetFrameworks>net8.0;net6.0;netcoreapp3.0;netstandard2.0;</TargetFrameworks>
<Version>2.1.0</Version>
<Authors>Picovoice</Authors>
<Company />
<Product>Cheetah Speech-to-Text Engine</Product>
Expand Down
2 changes: 1 addition & 1 deletion binding/dotnet/CheetahTest/CheetahTest.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Fastenshtein" Version="1.0.0.8" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.1.0" />
<PackageReference Include="MSTest.TestAdapter" Version="2.2.8" />
<PackageReference Include="MSTest.TestFramework" Version="2.2.8" />
<PackageReference Include="coverlet.collector" Version="3.1.2" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.2" />
</ItemGroup>

<ItemGroup>
Expand Down
203 changes: 116 additions & 87 deletions binding/dotnet/CheetahTest/MainTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,76 +12,133 @@ specific language governing permissions and limitations under the License.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;

using Fastenshtein;

using Microsoft.VisualStudio.TestTools.UnitTesting;

using Newtonsoft.Json.Linq;

using Pv;

namespace CheetahTest
{
[TestClass]
public class MainTest
{
private static string ACCESS_KEY;
private static string _accessKey;
private static readonly string ROOT_DIR = Path.Combine(AppContext.BaseDirectory, "../../../../../..");

private static readonly string _relativeDir = AppContext.BaseDirectory;
[ClassInitialize]
public static void ClassInitialize(TestContext _)
{
_accessKey = Environment.GetEnvironmentVariable("ACCESS_KEY");
}

private List<short> GetPcmFromFile(string audioFilePath, int expectedSampleRate)
[Serializable]
private class LanguageTestJson
{
List<short> data = new List<short>();
using (BinaryReader reader = new BinaryReader(File.Open(audioFilePath, FileMode.Open)))
{
reader.ReadBytes(24); // skip over part of the header
Assert.AreEqual(reader.ReadInt32(), expectedSampleRate, "Specified sample rate did not match test file.");
reader.ReadBytes(16); // skip over the rest of the header
public string language { get; set; }
public string audio_file { get; set; }
public string transcript { get; set; }

while (reader.BaseStream.Position != reader.BaseStream.Length)
{
data.Add(reader.ReadInt16());
}
}
public string[] punctuations { get; set; }
public float error_rate { get; set; }
}

return data;
/// <summary>
/// Reads the shared test-data file located at
/// <c>resources/.test/test_data.json</c> under <c>ROOT_DIR</c> and parses it.
/// </summary>
/// <returns>The parsed JSON document as a <see cref="JObject"/>.</returns>
private static JObject LoadJsonTestData()
{
    string testDataPath = Path.Combine(ROOT_DIR, "resources/.test/test_data.json");
    return JObject.Parse(File.ReadAllText(testDataPath));
}

public static IEnumerable<object[]> TestParameters
private static IEnumerable<object[]> LanguageTestParameters
{
get
{
List<object[]> testParameters = new List<object[]>();
JObject testDataJson = LoadJsonTestData();
IList<LanguageTestJson> languageTestJson = ((JArray)testDataJson["tests"]["language_tests"]).ToObject<IList<LanguageTestJson>>();
return languageTestJson
.Select(x => new object[] {
x.language,
x.audio_file,
x.transcript,
x.punctuations,
x.error_rate,
});
}
}

/// <summary>
/// Builds a language-specific name from a base name.
/// </summary>
/// <param name="s">Base name.</param>
/// <param name="language">Language code; "en" leaves the base name unchanged.</param>
/// <returns>
/// <paramref name="s"/> itself when <paramref name="language"/> is "en",
/// otherwise <paramref name="s"/> followed by an underscore and the language code.
/// </returns>
private static string AppendLanguage(string s, string language)
{
    if (language == "en")
    {
        return s;
    }

    return s + "_" + language;
}

private static int LevenshteinDistance(string[] transcriptWords, string[] referenceWords)
{
int referenceWordsLen = referenceWords.Length;
int transcriptWordsLen = transcriptWords.Length;

string transcript = "Mr quilter is the apostle of the middle classes and we are glad to welcome his gospel";
string transcriptWithPunctuation = "Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.";
int[,] dp = new int[referenceWordsLen + 1, transcriptWordsLen + 1];

testParameters.Add(new object[]
for (int i = 0; i <= referenceWordsLen; i++) dp[i, 0] = i;
for (int j = 0; j <= transcriptWordsLen; j++) dp[0, j] = j;

for (int i = 1; i <= referenceWordsLen; i++)
{
for (int j = 1; j <= transcriptWordsLen; j++)
{
"en",
"test.wav",
transcript,
transcriptWithPunctuation,
0.025f
});

return testParameters;
int cost = referenceWords[i - 1].ToUpper() == transcriptWords[j - 1].ToUpper() ? 0 : 1;

dp[i, j] = Math.Min(
Math.Min(dp[i - 1, j] + 1, dp[i, j - 1] + 1),
dp[i - 1, j - 1] + cost
);
}
}

return dp[referenceWordsLen, transcriptWordsLen];
}

static float GetErrorRate(string transcript, string referenceTranscript)
=> Levenshtein.Distance(transcript, referenceTranscript) / (float)referenceTranscript.Length;
private static double GetErrorRate(string transcript, string referenceTranscript)
{
string[] transcriptWords = transcript.Split(' ');
string[] referenceTranscriptWords = referenceTranscript.Split(' ');

[ClassInitialize]
public static void ClassInitialize(TestContext _)
int editDistance = LevenshteinDistance(transcriptWords, referenceTranscriptWords);
return (double)editDistance / referenceTranscriptWords.Length;
}

private static string GetModelPath(string language)
{
ACCESS_KEY = Environment.GetEnvironmentVariable("ACCESS_KEY");
return Path.Combine(
ROOT_DIR,
"lib/common",
$"{AppendLanguage("cheetah_params", language)}.pv");
}

/// <summary>
/// Reads the audio samples of a test file as 16-bit integers after checking
/// the sample rate stored in its header.
/// </summary>
/// <remarks>
/// The method skips 24 bytes, reads a 32-bit sample rate, skips 16 more bytes
/// and treats everything that follows as 16-bit samples.
/// NOTE(review): this presumably matches a canonical 44-byte WAV header - confirm
/// against the fixture files.
/// </remarks>
/// <param name="audioFilePath">Path of the audio file to read.</param>
/// <param name="expectedSampleRate">Sample rate the file header must declare.</param>
/// <returns>All samples found after the header, in file order.</returns>
private List<short> GetPcmFromFile(string audioFilePath, int expectedSampleRate)
{
    List<short> data = new List<short>();

    // File.OpenRead asks for read-only access with read sharing. The previous
    // File.Open(path, FileMode.Open) asked for read/write access with no sharing,
    // which fails on read-only fixtures and blocks concurrent readers.
    using (BinaryReader reader = new BinaryReader(File.OpenRead(audioFilePath)))
    {
        reader.ReadBytes(24); // skip over part of the header

        // MSTest expects (expected, actual); the swapped order produced a
        // misleading failure message.
        Assert.AreEqual(expectedSampleRate, reader.ReadInt32(), "Specified sample rate did not match test file.");
        reader.ReadBytes(16); // skip over the rest of the header

        while (reader.BaseStream.Position != reader.BaseStream.Length)
        {
            data.Add(reader.ReadInt16());
        }
    }

    return data;
}

[TestMethod]
public void TestVersion()
{
using (Cheetah cheetah = Cheetah.Create(ACCESS_KEY))
using (Cheetah cheetah = Cheetah.Create(_accessKey))
{
Assert.IsFalse(string.IsNullOrWhiteSpace(cheetah?.Version), "Cheetah did not return a valid version number.");
}
Expand All @@ -90,7 +147,7 @@ public void TestVersion()
[TestMethod]
public void TestSampleRate()
{
using (Cheetah cheetah = Cheetah.Create(ACCESS_KEY))
using (Cheetah cheetah = Cheetah.Create(_accessKey))
{
int num = 0;
Assert.IsTrue(int.TryParse(cheetah.SampleRate.ToString(), out num), "Cheetah did not return a valid sample rate.");
Expand All @@ -100,28 +157,29 @@ public void TestSampleRate()
[TestMethod]
public void TestFrameLength()
{
using (Cheetah cheetah = Cheetah.Create(ACCESS_KEY))
using (Cheetah cheetah = Cheetah.Create(_accessKey))
{
int num = 0;
Assert.IsTrue(int.TryParse(cheetah.FrameLength.ToString(), out num), "Cheetah did not return a valid frame length.");
}
}

[TestMethod]
[DynamicData(nameof(TestParameters))]
[DynamicData(nameof(LanguageTestParameters))]
public void TestProcess(
string language,
string testAudioFile,
string referenceTranscript,
string _,
string[] punctuations,
float targetErrorRate)
{
using (Cheetah cheetah = Cheetah.Create(
accessKey: ACCESS_KEY,
accessKey: _accessKey,
modelPath: GetModelPath(language),
endpointDurationSec: 0.2f,
enableAutomaticPunctuation: false))
{
string testAudioPath = Path.Combine(_relativeDir, "resources/audio_samples", testAudioFile);
string testAudioPath = Path.Combine(ROOT_DIR, "resources/audio_samples", testAudioFile);
List<short> pcm = GetPcmFromFile(testAudioPath, cheetah.SampleRate);

int frameLen = cheetah.FrameLength;
Expand All @@ -140,25 +198,32 @@ public void TestProcess(
CheetahTranscript finalTranscriptObj = cheetah.Flush();
transcript += finalTranscriptObj.Transcript;

Assert.IsTrue(GetErrorRate(transcript, referenceTranscript) < targetErrorRate);
string normalizedTranscript = referenceTranscript;
foreach (string punctuation in punctuations)
{
normalizedTranscript = normalizedTranscript.Replace(punctuation, "");
}

Assert.IsTrue(GetErrorRate(transcript, normalizedTranscript) <= targetErrorRate);
}
}

[TestMethod]
[DynamicData(nameof(TestParameters))]
[DynamicData(nameof(LanguageTestParameters))]
public void TestProcessWithPunctuation(
string language,
string testAudioFile,
string _,
string referenceTranscript,
string[] _,
float targetErrorRate)
{
using (Cheetah cheetah = Cheetah.Create(
accessKey: ACCESS_KEY,
accessKey: _accessKey,
modelPath: GetModelPath(language),
endpointDurationSec: 0.2f,
enableAutomaticPunctuation: true))
{
string testAudioPath = Path.Combine(_relativeDir, "resources/audio_samples", testAudioFile);
string testAudioPath = Path.Combine(ROOT_DIR, "resources/audio_samples", testAudioFile);
List<short> pcm = GetPcmFromFile(testAudioPath, cheetah.SampleRate);

int frameLen = cheetah.FrameLength;
Expand All @@ -177,50 +242,14 @@ public void TestProcessWithPunctuation(
CheetahTranscript finalTranscriptObj = cheetah.Flush();
transcript += finalTranscriptObj.Transcript;

Assert.IsTrue(GetErrorRate(transcript, referenceTranscript) < targetErrorRate);
}
}

[TestMethod]
[DynamicData(nameof(TestParameters))]
public void TestCustomModel(
string language,
string testAudioFile,
string referenceTranscript,
string _,
float targetErrorRate)
{
string testModelPath = Path.Combine(_relativeDir, "lib/common/cheetah_params.pv");
using (Cheetah cheetah = Cheetah.Create(
accessKey: ACCESS_KEY,
modelPath: testModelPath,
enableAutomaticPunctuation: false))
{
string testAudioPath = Path.Combine(_relativeDir, "resources/audio_samples", testAudioFile);
List<short> pcm = GetPcmFromFile(testAudioPath, cheetah.SampleRate);

int frameLen = cheetah.FrameLength;
int framecount = (int)Math.Floor((float)(pcm.Count / frameLen));

string transcript = "";
for (int i = 0; i < framecount; i++)
{
int start = i * cheetah.FrameLength;
List<short> frame = pcm.GetRange(start, frameLen);
CheetahTranscript transcriptObj = cheetah.Process(frame.ToArray());
transcript += transcriptObj.Transcript;
}
CheetahTranscript finalTranscriptObj = cheetah.Flush();
transcript += finalTranscriptObj.Transcript;

Assert.IsTrue(GetErrorRate(transcript, referenceTranscript) < targetErrorRate);
Assert.IsTrue(GetErrorRate(transcript, referenceTranscript) <= targetErrorRate);
}
}

[TestMethod]
public void TestMessageStack()
{
string modelPath = Path.Combine(_relativeDir, "lib/common/cheetah_params.pv");
string modelPath = GetModelPath("en");

Cheetah c;
string[] messageList = new string[] { };
Expand Down Expand Up @@ -263,10 +292,10 @@ public void TestMessageStack()
[TestMethod]
public void TestProcessFlushMessageStack()
{
string modelPath = Path.Combine(_relativeDir, "lib/common/cheetah_params.pv");
string modelPath = GetModelPath("en");

Cheetah c = Cheetah.Create(
accessKey: ACCESS_KEY,
accessKey: _accessKey,
modelPath: modelPath,
enableAutomaticPunctuation: false);
short[] testPcm = new short[c.FrameLength];
Expand Down
8 changes: 7 additions & 1 deletion binding/dotnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,13 @@ using(Cheetah handle = Cheetah.Create(accessKey))
}
```

The model file contains the parameters for the Cheetah engine. You may create bespoke language models using [Picovoice Console](https://console.picovoice.ai/) and then pass in the relevant file.
### Language Model

The Cheetah .NET SDK comes preloaded with a default English language model (`.pv` file).
Default models for other supported languages can be found in [lib/common](../../lib/common).

Create custom language models using the [Picovoice Console](https://console.picovoice.ai/). Here you can train
language models with custom vocabulary and boost words in the existing vocabulary.

```csharp
using Pv;
Expand Down
1 change: 1 addition & 0 deletions resources/.lint/spell-check/dict.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ LPWSTR
Makefiles
micdemo
NETCOREAPP
Newtonsoft
NumChans
Okhttp
pathbuf
Expand Down