From 18c756b89ea610b5ffefb9a570b0854d24570bc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Mazur?= Date: Tue, 18 Feb 2020 11:33:47 +0100 Subject: [PATCH] Add optimized version of getSentenceVector --- .../java/com/github/jfasttext/JFastText.java | 29 +++++++------ .../com/github/jfasttext/JFastTextTest.java | 43 +++++++++++++------ 2 files changed, 47 insertions(+), 25 deletions(-) diff --git a/src/main/java/com/github/jfasttext/JFastText.java b/src/main/java/com/github/jfasttext/JFastText.java index bfff410..f728f78 100644 --- a/src/main/java/com/github/jfasttext/JFastText.java +++ b/src/main/java/com/github/jfasttext/JFastText.java @@ -142,34 +142,37 @@ public List predictProba(String text, int k) { return probaPredictions; } + @Deprecated public List getVector(String word) { - FastTextWrapper.RealVector rv = fta.getVector(word); + float[] vector = getArrayVector(word); List wordVec = new ArrayList<>(); - for (int i = 0; i < rv.size(); i++) { - wordVec.add(rv.get(i)); + for (float f : vector) { + wordVec.add(f); } return wordVec; } public float[] getArrayVector(String word) { FastTextWrapper.RealVector rv = fta.getVector(word); - float[] wordVec = new float[(int)rv.size()]; - for (int i = 0; i < rv.size(); i++) { - wordVec[i] = rv.get(i); - } - return wordVec; + return rv.get(); } + @Deprecated public List getSentenceVector(String sentence) { + float[] vector = getArraySentenceVector(sentence); + List sentenceVec = new ArrayList<>(); + for (float f : vector) { + sentenceVec.add(f); + } + return sentenceVec; + } + + public float[] getArraySentenceVector(String sentence) { if (!sentence.endsWith("\n")) { sentence += "\n"; } FastTextWrapper.RealVector rv = fta.getSentenceVector(sentence); - List wordVec = new ArrayList<>(); - for (int i = 0; i < rv.size(); i++) { - wordVec.add(rv.get(i)); - } - return wordVec; + return rv.get(); } public int getNWords() { diff --git a/src/test/java/com/github/jfasttext/JFastTextTest.java b/src/test/java/com/github/jfasttext/JFastTextTest.java index 23dbc2d..1fd7ffc 100644 --- a/src/test/java/com/github/jfasttext/JFastTextTest.java +++ b/src/test/java/com/github/jfasttext/JFastTextTest.java @@ -8,11 +8,11 @@ import java.io.FileInputStream; import java.io.InputStream; import java.net.URL; +import java.util.Arrays; import java.util.List; -import static org.junit.Assert.assertEquals; - import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @FixMethodOrder(MethodSorters.NAME_ASCENDING) @@ -66,15 +66,15 @@ public void test04Predict() throws Exception { System.out.printf("\nText: '%s', label: '%s'\n", text, predictedLabel); } - @Test - public void test04getArrayVector() throws Exception { - JFastText jft = new JFastText(); - jft.loadModel("src/test/resources/models/supervised.model.bin"); - String text = "I like soccer"; - float[] predictedArray = jft.getArrayVector(text); - float[] expected = new float[100]; - assertArrayEquals("", predictedArray, expected, 0.1f); - } + @Test + public void test04getArrayVector() throws Exception { + JFastText jft = new JFastText(); + jft.loadModel("src/test/resources/models/supervised.model.bin"); + String text = "I like soccer"; + float[] predictedArray = jft.getArrayVector(text); + float[] expected = new float[100]; + assertArrayEquals("", predictedArray, expected, 0.1f); + } @Test public void test05PredictProba() throws Exception { @@ -106,6 +106,16 @@ public void test07GetVector() throws Exception { } } + @Test + public void test07GetArrayVector() throws Exception { + try (InputStream is = new FileInputStream("src/test/resources/models/supervised.model.bin")) { + JFastText jft = new JFastText(is); + String word = "soccer"; + float[] vec = jft.getArrayVector(word); + System.out.printf("\nWord embedding vector of '%s': %s\n", word, Arrays.toString(vec)); + } + } + @Test public void test08GetSentenceVector() throws Exception { JFastText jft = new JFastText(); @@ -116,6 +126,16 @@ public void test08GetSentenceVector() throws Exception { assertEquals(expectedSize, vec.size()); } + @Test + public void test08GetArraySentenceVector() throws Exception { + JFastText jft = new JFastText(); + jft.loadModel("src/test/resources/models/supervised.model.bin"); + String word = "soccers"; + float[] vec = jft.getArraySentenceVector(word); + int expectedSize = 100; + assertEquals(expectedSize, vec.length); + } + /** * Test retrieving model's information: words, labels, learning rate, etc. */ @@ -152,7 +172,6 @@ public void test10ModelUnloading() throws Exception { /** * Loads model from specified URL (resource, web, etc.) - * */ @Test public void test10ModelFromURL() throws Exception {