Skip to content

Commit

Permalink
Add optimized version of getSentenceVector
Browse files Browse the repository at this point in the history
  • Loading branch information
GotoFinal committed Feb 18, 2020
1 parent 781b736 commit 18c756b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 25 deletions.
29 changes: 16 additions & 13 deletions src/main/java/com/github/jfasttext/JFastText.java
Original file line number Diff line number Diff line change
Expand Up @@ -142,34 +142,37 @@ public List<ProbLabel> predictProba(String text, int k) {
return probaPredictions;
}

@Deprecated
public List<Float> getVector(String word) {
FastTextWrapper.RealVector rv = fta.getVector(word);
float[] vector = getArrayVector(word);
List<Float> wordVec = new ArrayList<>();
for (int i = 0; i < rv.size(); i++) {
wordVec.add(rv.get(i));
for (float f : vector) {
wordVec.add(f);
}
return wordVec;
}

public float[] getArrayVector(String word) {
FastTextWrapper.RealVector rv = fta.getVector(word);
float[] wordVec = new float[(int)rv.size()];
for (int i = 0; i < rv.size(); i++) {
wordVec[i] = rv.get(i);
}
return wordVec;
return rv.get();
}

@Deprecated
public List<Float> getSentenceVector(String sentence) {
float[] vector = getArraySentenceVector(sentence);
List<Float> sentenceVec = new ArrayList<>();
for (float f : vector) {
sentenceVec.add(f);
}
return sentenceVec;
}

public float[] getArraySentenceVector(String sentence) {
if (!sentence.endsWith("\n")) {
sentence += "\n";
}
FastTextWrapper.RealVector rv = fta.getSentenceVector(sentence);
List<Float> wordVec = new ArrayList<>();
for (int i = 0; i < rv.size(); i++) {
wordVec.add(rv.get(i));
}
return wordVec;
return rv.get();
}

public int getNWords() {
Expand Down
43 changes: 31 additions & 12 deletions src/test/java/com/github/jfasttext/JFastTextTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import java.io.FileInputStream;
import java.io.InputStream;
import java.net.URL;
import java.util.Arrays;
import java.util.List;

import static org.junit.Assert.assertEquals;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

@FixMethodOrder(MethodSorters.NAME_ASCENDING)
Expand Down Expand Up @@ -66,15 +66,15 @@ public void test04Predict() throws Exception {
System.out.printf("\nText: '%s', label: '%s'\n", text, predictedLabel);
}

@Test
public void test04getArrayVector() throws Exception {
JFastText jft = new JFastText();
jft.loadModel("src/test/resources/models/supervised.model.bin");
String text = "I like soccer";
float[] predictedArray = jft.getArrayVector(text);
float[] expected = new float[100];
assertArrayEquals("", predictedArray, expected, 0.1f);
}
@Test
public void test04getArrayVector() throws Exception {
JFastText jft = new JFastText();
jft.loadModel("src/test/resources/models/supervised.model.bin");
String text = "I like soccer";
float[] predictedArray = jft.getArrayVector(text);
float[] expected = new float[100];
assertArrayEquals("", predictedArray, expected, 0.1f);
}

@Test
public void test05PredictProba() throws Exception {
Expand Down Expand Up @@ -106,6 +106,16 @@ public void test07GetVector() throws Exception {
}
}

@Test
public void test07GetArrayVector() throws Exception {
try (InputStream is = new FileInputStream("src/test/resources/models/supervised.model.bin")) {
JFastText jft = new JFastText(is);
String word = "soccer";
float[] vec = jft.getArrayVector(word);
System.out.printf("\nWord embedding vector of '%s': %s\n", word, Arrays.toString(vec));
}
}

@Test
public void test08GetSentenceVector() throws Exception {
JFastText jft = new JFastText();
Expand All @@ -116,6 +126,16 @@ public void test08GetSentenceVector() throws Exception {
assertEquals(expectedSize, vec.size());
}

@Test
public void test08GetArraySentenceVector() throws Exception {
JFastText jft = new JFastText();
jft.loadModel("src/test/resources/models/supervised.model.bin");
String word = "soccers";
float[] vec = jft.getArraySentenceVector(word);
int expectedSize = 100;
assertEquals(expectedSize, vec.length);
}

/**
* Test retrieving model's information: words, labels, learning rate, etc.
*/
Expand Down Expand Up @@ -152,7 +172,6 @@ public void test10ModelUnloading() throws Exception {

/**
* Loads model from specified URL (resource, web, etc.)
*
*/
@Test
public void test10ModelFromURL() throws Exception {
Expand Down

0 comments on commit 18c756b

Please sign in to comment.