Skip to content

Commit

Permalink
v2.1 dotnet (#361)
Browse files Browse the repository at this point in the history
  • Loading branch information
ksyeo1010 authored and ErisMik committed Dec 5, 2024
1 parent 1621887 commit 2c53a7a
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 91 deletions.
18 changes: 18 additions & 0 deletions .github/workflows/dotnet-demos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ jobs:
- name: Package restore
run: dotnet restore

# ************** REMOVE AFTER RELEASE ********************
- name: Build Local Packages
run: dotnet build && dotnet pack -c Release
working-directory: binding/dotnet/Cheetah

- name: Install Local Packages
run: dotnet add package -s ../../../binding/dotnet/Cheetah/bin/Release Picovoice.Cheetah
# ********************************************************

- name: Dotnet build micdemo
run: dotnet build -c MicDemo.Release

Expand All @@ -60,6 +69,15 @@ jobs:
- name: Package restore
run: dotnet restore

# ************** REMOVE AFTER RELEASE ********************
- name: Build Local Packages
run: dotnet build && dotnet pack -c Release
working-directory: binding/dotnet/Cheetah

- name: Install Local Packages
run: dotnet add package -s ../../../binding/dotnet/Cheetah/bin/Release Picovoice.Cheetah
# ********************************************************

- name: Dotnet build micdemo
run: dotnet build -c MicDemo.Release

Expand Down
4 changes: 2 additions & 2 deletions binding/dotnet/Cheetah/Cheetah.csproj
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFrameworks>net8.0;net6.0;netcoreapp3.0;netstandard2.0</TargetFrameworks>
<Version>2.0.2</Version>
<TargetFrameworks>net8.0;net6.0;netcoreapp3.0;netstandard2.0;</TargetFrameworks>
<Version>2.1.0</Version>
<Authors>Picovoice</Authors>
<Company />
<Product>Cheetah Speech-to-Text Engine</Product>
Expand Down
2 changes: 1 addition & 1 deletion binding/dotnet/CheetahTest/CheetahTest.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Fastenshtein" Version="1.0.0.8" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.1.0" />
<PackageReference Include="MSTest.TestAdapter" Version="2.2.8" />
<PackageReference Include="MSTest.TestFramework" Version="2.2.8" />
<PackageReference Include="coverlet.collector" Version="3.1.2" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.2" />
</ItemGroup>

<ItemGroup>
Expand Down
203 changes: 116 additions & 87 deletions binding/dotnet/CheetahTest/MainTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,76 +12,133 @@ specific language governing permissions and limitations under the License.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;

using Fastenshtein;

using Microsoft.VisualStudio.TestTools.UnitTesting;

using Newtonsoft.Json.Linq;

using Pv;

namespace CheetahTest
{
[TestClass]
public class MainTest
{
private static string ACCESS_KEY;
private static string _accessKey;
private static readonly string ROOT_DIR = Path.Combine(AppContext.BaseDirectory, "../../../../../..");

private static readonly string _relativeDir = AppContext.BaseDirectory;
[ClassInitialize]
public static void ClassInitialize(TestContext _)
{
_accessKey = Environment.GetEnvironmentVariable("ACCESS_KEY");
}

private List<short> GetPcmFromFile(string audioFilePath, int expectedSampleRate)
[Serializable]
private class LanguageTestJson
{
List<short> data = new List<short>();
using (BinaryReader reader = new BinaryReader(File.Open(audioFilePath, FileMode.Open)))
{
reader.ReadBytes(24); // skip over part of the header
Assert.AreEqual(reader.ReadInt32(), expectedSampleRate, "Specified sample rate did not match test file.");
reader.ReadBytes(16); // skip over the rest of the header
public string language { get; set; }
public string audio_file { get; set; }
public string transcript { get; set; }

while (reader.BaseStream.Position != reader.BaseStream.Length)
{
data.Add(reader.ReadInt16());
}
}
public string[] punctuations { get; set; }
public float error_rate { get; set; }
}

return data;
private static JObject LoadJsonTestData()
{
string content = File.ReadAllText(Path.Combine(ROOT_DIR, "resources/.test/test_data.json"));
return JObject.Parse(content);
}

public static IEnumerable<object[]> TestParameters
private static IEnumerable<object[]> LanguageTestParameters
{
get
{
List<object[]> testParameters = new List<object[]>();
JObject testDataJson = LoadJsonTestData();
IList<LanguageTestJson> languageTestJson = ((JArray)testDataJson["tests"]["language_tests"]).ToObject<IList<LanguageTestJson>>();
return languageTestJson
.Select(x => new object[] {
x.language,
x.audio_file,
x.transcript,
x.punctuations,
x.error_rate,
});
}
}

private static string AppendLanguage(string s, string language)
{
return language == "en" ? s : $"{s}_{language}";
}

private static int LevenshteinDistance(string[] transcriptWords, string[] referenceWords)
{
int referenceWordsLen = referenceWords.Length;
int transcriptWordsLen = transcriptWords.Length;

string transcript = "Mr quilter is the apostle of the middle classes and we are glad to welcome his gospel";
string transcriptWithPunctuation = "Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.";
int[,] dp = new int[referenceWordsLen + 1, transcriptWordsLen + 1];

testParameters.Add(new object[]
for (int i = 0; i <= referenceWordsLen; i++) dp[i, 0] = i;
for (int j = 0; j <= transcriptWordsLen; j++) dp[0, j] = j;

for (int i = 1; i <= referenceWordsLen; i++)
{
for (int j = 1; j <= transcriptWordsLen; j++)
{
"en",
"test.wav",
transcript,
transcriptWithPunctuation,
0.025f
});

return testParameters;
int cost = referenceWords[i - 1].ToUpper() == transcriptWords[j - 1].ToUpper() ? 0 : 1;

dp[i, j] = Math.Min(
Math.Min(dp[i - 1, j] + 1, dp[i, j - 1] + 1),
dp[i - 1, j - 1] + cost
);
}
}

return dp[referenceWordsLen, transcriptWordsLen];
}

static float GetErrorRate(string transcript, string referenceTranscript)
=> Levenshtein.Distance(transcript, referenceTranscript) / (float)referenceTranscript.Length;
private static double GetErrorRate(string transcript, string referenceTranscript)
{
string[] transcriptWords = transcript.Split(' ');
string[] referenceTranscriptWords = referenceTranscript.Split(' ');

[ClassInitialize]
public static void ClassInitialize(TestContext _)
int editDistance = LevenshteinDistance(transcriptWords, referenceTranscriptWords);
return (double)editDistance / referenceTranscriptWords.Length;
}

private static string GetModelPath(string language)
{
ACCESS_KEY = Environment.GetEnvironmentVariable("ACCESS_KEY");
return Path.Combine(
ROOT_DIR,
"lib/common",
$"{AppendLanguage("cheetah_params", language)}.pv");
}

private List<short> GetPcmFromFile(string audioFilePath, int expectedSampleRate)
{
List<short> data = new List<short>();
using (BinaryReader reader = new BinaryReader(File.Open(audioFilePath, FileMode.Open)))
{
reader.ReadBytes(24); // skip over part of the header
Assert.AreEqual(reader.ReadInt32(), expectedSampleRate, "Specified sample rate did not match test file.");
reader.ReadBytes(16); // skip over the rest of the header

while (reader.BaseStream.Position != reader.BaseStream.Length)
{
data.Add(reader.ReadInt16());
}
}

return data;
}

[TestMethod]
public void TestVersion()
{
using (Cheetah cheetah = Cheetah.Create(ACCESS_KEY))
using (Cheetah cheetah = Cheetah.Create(_accessKey))
{
Assert.IsFalse(string.IsNullOrWhiteSpace(cheetah?.Version), "Cheetah did not return a valid version number.");
}
Expand All @@ -90,7 +147,7 @@ public void TestVersion()
[TestMethod]
public void TestSampleRate()
{
using (Cheetah cheetah = Cheetah.Create(ACCESS_KEY))
using (Cheetah cheetah = Cheetah.Create(_accessKey))
{
int num = 0;
Assert.IsTrue(int.TryParse(cheetah.SampleRate.ToString(), out num), "Cheetah did not return a valid sample rate.");
Expand All @@ -100,28 +157,29 @@ public void TestSampleRate()
[TestMethod]
public void TestFrameLength()
{
using (Cheetah cheetah = Cheetah.Create(ACCESS_KEY))
using (Cheetah cheetah = Cheetah.Create(_accessKey))
{
int num = 0;
Assert.IsTrue(int.TryParse(cheetah.FrameLength.ToString(), out num), "Cheetah did not return a valid frame length.");
}
}

[TestMethod]
[DynamicData(nameof(TestParameters))]
[DynamicData(nameof(LanguageTestParameters))]
public void TestProcess(
string language,
string testAudioFile,
string referenceTranscript,
string _,
string[] punctuations,
float targetErrorRate)
{
using (Cheetah cheetah = Cheetah.Create(
accessKey: ACCESS_KEY,
accessKey: _accessKey,
modelPath: GetModelPath(language),
endpointDurationSec: 0.2f,
enableAutomaticPunctuation: false))
{
string testAudioPath = Path.Combine(_relativeDir, "resources/audio_samples", testAudioFile);
string testAudioPath = Path.Combine(ROOT_DIR, "resources/audio_samples", testAudioFile);
List<short> pcm = GetPcmFromFile(testAudioPath, cheetah.SampleRate);

int frameLen = cheetah.FrameLength;
Expand All @@ -140,25 +198,32 @@ public void TestProcess(
CheetahTranscript finalTranscriptObj = cheetah.Flush();
transcript += finalTranscriptObj.Transcript;

Assert.IsTrue(GetErrorRate(transcript, referenceTranscript) < targetErrorRate);
string normalizedTranscript = referenceTranscript;
foreach (string punctuation in punctuations)
{
normalizedTranscript = normalizedTranscript.Replace(punctuation, "");
}

Assert.IsTrue(GetErrorRate(transcript, normalizedTranscript) <= targetErrorRate);
}
}

[TestMethod]
[DynamicData(nameof(TestParameters))]
[DynamicData(nameof(LanguageTestParameters))]
public void TestProcessWithPunctuation(
string language,
string testAudioFile,
string _,
string referenceTranscript,
string[] _,
float targetErrorRate)
{
using (Cheetah cheetah = Cheetah.Create(
accessKey: ACCESS_KEY,
accessKey: _accessKey,
modelPath: GetModelPath(language),
endpointDurationSec: 0.2f,
enableAutomaticPunctuation: true))
{
string testAudioPath = Path.Combine(_relativeDir, "resources/audio_samples", testAudioFile);
string testAudioPath = Path.Combine(ROOT_DIR, "resources/audio_samples", testAudioFile);
List<short> pcm = GetPcmFromFile(testAudioPath, cheetah.SampleRate);

int frameLen = cheetah.FrameLength;
Expand All @@ -177,50 +242,14 @@ public void TestProcessWithPunctuation(
CheetahTranscript finalTranscriptObj = cheetah.Flush();
transcript += finalTranscriptObj.Transcript;

Assert.IsTrue(GetErrorRate(transcript, referenceTranscript) < targetErrorRate);
}
}

[TestMethod]
[DynamicData(nameof(TestParameters))]
public void TestCustomModel(
string language,
string testAudioFile,
string referenceTranscript,
string _,
float targetErrorRate)
{
string testModelPath = Path.Combine(_relativeDir, "lib/common/cheetah_params.pv");
using (Cheetah cheetah = Cheetah.Create(
accessKey: ACCESS_KEY,
modelPath: testModelPath,
enableAutomaticPunctuation: false))
{
string testAudioPath = Path.Combine(_relativeDir, "resources/audio_samples", testAudioFile);
List<short> pcm = GetPcmFromFile(testAudioPath, cheetah.SampleRate);

int frameLen = cheetah.FrameLength;
int framecount = (int)Math.Floor((float)(pcm.Count / frameLen));

string transcript = "";
for (int i = 0; i < framecount; i++)
{
int start = i * cheetah.FrameLength;
List<short> frame = pcm.GetRange(start, frameLen);
CheetahTranscript transcriptObj = cheetah.Process(frame.ToArray());
transcript += transcriptObj.Transcript;
}
CheetahTranscript finalTranscriptObj = cheetah.Flush();
transcript += finalTranscriptObj.Transcript;

Assert.IsTrue(GetErrorRate(transcript, referenceTranscript) < targetErrorRate);
Assert.IsTrue(GetErrorRate(transcript, referenceTranscript) <= targetErrorRate);
}
}

[TestMethod]
public void TestMessageStack()
{
string modelPath = Path.Combine(_relativeDir, "lib/common/cheetah_params.pv");
string modelPath = GetModelPath("en");

Cheetah c;
string[] messageList = new string[] { };
Expand Down Expand Up @@ -263,10 +292,10 @@ public void TestMessageStack()
[TestMethod]
public void TestProcessFlushMessageStack()
{
string modelPath = Path.Combine(_relativeDir, "lib/common/cheetah_params.pv");
string modelPath = GetModelPath("en");

Cheetah c = Cheetah.Create(
accessKey: ACCESS_KEY,
accessKey: _accessKey,
modelPath: modelPath,
enableAutomaticPunctuation: false);
short[] testPcm = new short[c.FrameLength];
Expand Down
8 changes: 7 additions & 1 deletion binding/dotnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,13 @@ using(Cheetah handle = Cheetah.Create(accessKey))
}
```

The model file contains the parameters for the Cheetah engine. You may create bespoke language models using [Picovoice Console](https://console.picovoice.ai/) and then pass in the relevant file.
### Language Model

The Cheetah .NET SDK comes preloaded with a default English language model (`.pv` file).
Default models for other supported languages can be found in [lib/common](../../lib/common).

Create custom language models using the [Picovoice Console](https://console.picovoice.ai/). Here you can train
language models with custom vocabulary and boost words in the existing vocabulary.

```csharp
using Pv;
Expand Down
1 change: 1 addition & 0 deletions resources/.lint/spell-check/dict.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ LPWSTR
Makefiles
micdemo
NETCOREAPP
Newtonsoft
NumChans
Okhttp
pathbuf
Expand Down

0 comments on commit 2c53a7a

Please sign in to comment.