Skip to content

Commit

Permalink
v2.1 java
Browse files Browse the repository at this point in the history
  • Loading branch information
ksyeo1010 committed Dec 2, 2024
1 parent 69550b2 commit 99558aa
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 18 deletions.
3 changes: 2 additions & 1 deletion binding/java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ plugins {

ext {
PUBLISH_GROUP_ID = 'ai.picovoice'
PUBLISH_VERSION = '2.0.2'
PUBLISH_VERSION = '2.1.0'
PUBLISH_ARTIFACT_ID = 'cheetah-java'
}

Expand Down Expand Up @@ -84,6 +84,7 @@ if (file("${rootDir}/publish-mavencentral.gradle").exists()) {
}

// Test-only dependencies for the Java binding.
dependencies {
    // Gson is used by the tests to read the JSON test-data file.
    testImplementation 'com.google.code.gson:gson:2.10.1'
    // Keep all JUnit Jupiter artifacts on the same release: the aggregate
    // 'junit-jupiter' artifact was pinned to 5.4.2 while 'junit-jupiter-params'
    // was 5.8.2, which mixes Jupiter versions on the test classpath.
    testImplementation 'org.junit.jupiter:junit-jupiter:5.8.2'
    testImplementation 'org.junit.jupiter:junit-jupiter-params:5.8.2'
}
Expand Down
164 changes: 150 additions & 14 deletions binding/java/test/ai/picovoice/cheetah/CheetahTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2022-2023 Picovoice Inc.
Copyright 2022-2024 Picovoice Inc.
You may not use this file except in compliance with the license. A copy of the license is
located in the "LICENSE" file accompanying this source.
Expand All @@ -12,14 +12,22 @@

package ai.picovoice.cheetah;

import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import org.junit.jupiter.api.Test;

import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.stream.Stream;


Expand All @@ -32,6 +40,103 @@
public class CheetahTest {
private final String accessKey = System.getProperty("pvTestingAccessKey");

/**
 * Builds a language-specific resource name from a base name.
 *
 * @param s        base name
 * @param language language code; {@code "en"} leaves the base name untouched
 * @return {@code s} when the language is {@code "en"}, otherwise
 *         {@code s + "_" + language}
 */
private static String appendLanguage(String s, String language) {
    final boolean isEnglish = language.equals("en");
    return isEnglish ? s : s + "_" + language;
}

/**
 * Computes the word-level Levenshtein (edit) distance between two token
 * sequences. Tokens are compared case-insensitively; insertions, deletions
 * and substitutions each cost 1.
 *
 * @param transcript tokens produced by the engine
 * @param reference  expected tokens
 * @return minimum number of single-token edits turning one sequence into the other
 */
public static int levenshteinDistance(String[] transcript, String[] reference) {
    final int rows = transcript.length + 1;
    final int cols = reference.length + 1;
    // table[r][c] = distance between the first r transcript tokens
    // and the first c reference tokens.
    final int[][] table = new int[rows][cols];

    for (int r = 0; r < rows; r++) {
        // Against an empty reference prefix: r deletions.
        table[r][0] = r;
        for (int c = 1; c < cols; c++) {
            if (r == 0) {
                // Against an empty transcript prefix: c insertions.
                table[0][c] = c;
                continue;
            }

            final int diagonal = table[r - 1][c - 1];
            if (transcript[r - 1].equalsIgnoreCase(reference[c - 1])) {
                table[r][c] = diagonal;
            } else {
                final int cheapestNeighbour = Math.min(table[r - 1][c], table[r][c - 1]);
                table[r][c] = 1 + Math.min(diagonal, cheapestNeighbour);
            }
        }
    }

    return table[rows - 1][cols - 1];
}

/**
 * Computes the word error rate of a transcript against a reference: the
 * word-level edit distance divided by the number of reference words.
 *
 * <p>Both inputs are trimmed before being split on whitespace. Without the
 * trim, {@link String#split(String)} yields a leading empty token for input
 * that starts with whitespace (counted as a bogus word), and an
 * all-whitespace reference splits into zero tokens, making the divisor zero.
 * After trimming, each input always yields at least one token (an empty
 * input yields a single empty token).
 *
 * @param transcript text produced by the engine
 * @param reference  expected text
 * @return edit distance divided by the reference word count
 */
public static float getErrorRate(String transcript, String reference) {
    final String[] transcriptWords = transcript.trim().split("\\s+");
    final String[] referenceWords = reference.trim().split("\\s+");
    final int distance = levenshteinDistance(transcriptWords, referenceWords);

    return (float) distance / (float) referenceWords.length;
}

/**
 * Reads the language test cases from
 * {@code ../../resources/.test/test_data.json}, resolved against the
 * {@code user.dir} system property.
 *
 * @return one {@code ProcessTestData} per entry of the
 *         {@code tests.language_tests} array, in file order
 * @throws IOException if the JSON file cannot be read
 */
private static ProcessTestData[] loadProcessTestData() throws IOException {
    final Path jsonPath = Paths.get(System.getProperty("user.dir"))
            .resolve("../../resources/.test")
            .resolve("test_data.json");
    final byte[] jsonBytes = Files.readAllBytes(jsonPath);
    final String jsonText = new String(jsonBytes, StandardCharsets.UTF_8);

    final JsonArray languageTests = JsonParser.parseString(jsonText)
            .getAsJsonObject()
            .getAsJsonObject("tests")
            .getAsJsonArray("language_tests");

    final int caseCount = languageTests.size();
    final ProcessTestData[] cases = new ProcessTestData[caseCount];
    for (int caseIdx = 0; caseIdx < caseCount; caseIdx++) {
        final JsonObject entry = languageTests.get(caseIdx).getAsJsonObject();

        final String language = entry.get("language").getAsString();
        final String audioFile = entry.get("audio_file").getAsString();
        final String transcript = entry.get("transcript").getAsString();
        final float errorRate = entry.get("error_rate").getAsFloat();

        // Copy the JSON array of punctuation marks into a plain String[].
        final JsonArray marks = entry.getAsJsonArray("punctuations");
        final int markCount = marks.size();
        final String[] punctuations = new String[markCount];
        for (int markIdx = 0; markIdx < markCount; markIdx++) {
            punctuations[markIdx] = marks.get(markIdx).getAsString();
        }

        cases[caseIdx] = new ProcessTestData(language, audioFile, transcript, punctuations, errorRate);
    }
    return cases;
}

/**
 * Argument source for the parameterized {@code process} test.
 *
 * <p>Every loaded test case is emitted twice, in order: first with the
 * automatic-punctuation flag set to {@code false}, then with {@code true}.
 * Each argument set is (language, audio file, transcript, punctuations,
 * punctuation flag, error rate).
 *
 * @return stream of argument sets, two per test case
 * @throws IOException if the test data cannot be loaded
 */
private static Stream<Arguments> processTestProvider() throws IOException {
    final Stream.Builder<Arguments> argumentSets = Stream.builder();
    final boolean[] punctuationFlags = {false, true};

    for (ProcessTestData testCase : loadProcessTestData()) {
        for (boolean punctuationEnabled : punctuationFlags) {
            argumentSets.add(Arguments.of(
                    testCase.language,
                    testCase.audioFile,
                    testCase.transcript,
                    testCase.punctuations,
                    punctuationEnabled,
                    testCase.errorRate));
        }
    }

    return argumentSets.build();
}

@Test
void getVersion() throws CheetahException {
Cheetah cheetah = new Cheetah.Builder()
Expand Down Expand Up @@ -84,22 +189,33 @@ void getErrorStack() {
}
}

@ParameterizedTest(name = "test transcribe with automatic punctuation set to ''{0}''")
@MethodSource("transcribeProvider")
void transcribe(boolean enableAutomaticPunctuation, String referenceTranscript) throws Exception {
@ParameterizedTest(name = "test process data for ''{0}'' with punctuation ''{4}''")
@MethodSource("processTestProvider")
void process(
String language,
String testAudioFile,
String referenceTranscript,
String[] punctuations,
boolean enableAutomaticPunctuation,
float targetErrorRate) throws Exception {
String modelPath = Paths.get(System.getProperty("user.dir"))
.resolve(String.format("../../lib/common/%s.pv", appendLanguage("cheetah_params", language)))
.toString();

Cheetah cheetah = new Cheetah.Builder()
.setAccessKey(accessKey)
.setModelPath(modelPath)
.setEnableAutomaticPunctuation(enableAutomaticPunctuation)
.build();

int frameLen = cheetah.getFrameLength();
String audioFilePath = Paths.get(System.getProperty("user.dir"))
.resolve("../../resources/audio_samples/test.wav")
.resolve(String.format("../../resources/audio_samples/%s", testAudioFile))
.toString();
File testAudioPath = new File(audioFilePath);

AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(testAudioPath);
assertEquals(audioInputStream.getFormat().getFrameRate(), 16000);
assertEquals(16000, audioInputStream.getFormat().getFrameRate());

int byteDepth = audioInputStream.getFormat().getFrameSize();
byte[] pcm = new byte[frameLen * byteDepth];
Expand All @@ -116,17 +232,37 @@ void transcribe(boolean enableAutomaticPunctuation, String referenceTranscript)
}
CheetahTranscript finalTranscriptObj = cheetah.flush();
transcript.append(finalTranscriptObj.getTranscript());
assertEquals(referenceTranscript, transcript.toString());

cheetah.delete();

String normalizedTranscript = referenceTranscript;
if (!enableAutomaticPunctuation) {
for (String punctuation : punctuations) {
normalizedTranscript = normalizedTranscript.replace(punctuation, "");
}
}

assertTrue(getErrorRate(transcript.toString(), normalizedTranscript) < targetErrorRate);
}

private static Stream<Arguments> transcribeProvider() {
return Stream.of(
Arguments.of(true,
"Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."),
Arguments.of(false,
"Mr quilter is the apostle of the middle classes and we are glad to welcome his gospel")
);
/**
 * Immutable holder for one language test case: the values are stored as
 * given and exposed through public final fields.
 */
private static class ProcessTestData {
    /** Language code of the test case. */
    public final String language;
    /** Name of the audio file to process. */
    public final String audioFile;
    /** Reference transcript for the audio file. */
    public final String transcript;
    /** Punctuation marks associated with the transcript. */
    public final String[] punctuations;
    /** Error-rate value associated with the test case. */
    public final float errorRate;

    /**
     * Stores the supplied values; the punctuation array is kept by reference.
     *
     * @param language     language code
     * @param audioFile    audio file name
     * @param transcript   reference transcript
     * @param punctuations punctuation marks
     * @param errorRate    error-rate value
     */
    public ProcessTestData(
            final String language,
            final String audioFile,
            final String transcript,
            final String[] punctuations,
            final float errorRate) {
        this.errorRate = errorRate;
        this.punctuations = punctuations;
        this.transcript = transcript;
        this.audioFile = audioFile;
        this.language = language;
    }
}
}
9 changes: 6 additions & 3 deletions demo/java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ plugins {

// Repositories searched when resolving the demo's dependencies.
repositories {
    mavenCentral()
    // NOTE(review): the path of this extra repository looks like a numbered
    // Sonatype staging repository — confirm it is still required (and still
    // reachable) once the referenced release is available on Maven Central.
    maven {
        url 'https://s01.oss.sonatype.org/content/repositories/aipicovoice-1349/'
    }
}

sourceSets {
Expand All @@ -15,14 +18,14 @@ sourceSets {
}

dependencies {
implementation 'ai.picovoice:cheetah-java:2.0.2'
implementation 'ai.picovoice:cheetah-java:2.1.0'
implementation 'commons-cli:commons-cli:1.4'
}

jar {
manifest {
attributes "Main-Class": "ai.picovoice.cheetahdemo.MicDemo",
"Class-Path": "cheetah-2.0.2.jar;commons-cli-1.4.jar"
"Class-Path": "cheetah-2.1.0.jar;commons-cli-1.4.jar"
}
from sourceSets.main.output
exclude "**/FileDemo.class"
Expand All @@ -33,7 +36,7 @@ jar {
task fileDemoJar(type: Jar) {
manifest {
attributes "Main-Class": "ai.picovoice.cheetahdemo.FileDemo",
"Class-Path": "cheetah-2.0.2.jar;commons-cli-1.4.jar"
"Class-Path": "cheetah-2.1.0.jar;commons-cli-1.4.jar"
}
from sourceSets.main.output
exclude "**/MicDemo.class"
Expand Down

0 comments on commit 99558aa

Please sign in to comment.