Skip to content

Commit

Permalink
move it Babylon
Browse files Browse the repository at this point in the history
  • Loading branch information
jwijffels committed Jun 10, 2020
1 parent bae800f commit af0d7b9
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 119 deletions.
1 change: 0 additions & 1 deletion src/Makevars
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ SOURCES = word2vec/lib/huffmanTree.cpp \
word2vec/lib/vocabulary.cpp \
word2vec/lib/word2vec.cpp \
rcpp_word2vec.cpp \
rcpp_word2vec_read.cpp \
RcppExports.cpp

OBJECTS = $(SOURCES:.cpp=.o)
Expand Down
1 change: 0 additions & 1 deletion src/Makevars.win
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ SOURCES = word2vec/lib/huffmanTree.cpp \
word2vec/lib/win/mman.cpp \
word2vec/lib/word2vec.cpp \
rcpp_word2vec.cpp \
rcpp_word2vec_read.cpp \
RcppExports.cpp

OBJECTS = $(SOURCES:.cpp=.o)
Expand Down
104 changes: 104 additions & 0 deletions src/rcpp_word2vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <iostream>
#include <iomanip>
#include "word2vec.hpp"
#include "wordReader.hpp"
#include <unordered_map>

// [[Rcpp::depends(RcppProgress)]]
Expand Down Expand Up @@ -293,3 +294,106 @@ Rcpp::List w2v_nearest_vector(SEXP ptr,
return out;
}

// Read a word2vec model stored in binary format into an R numeric matrix.
//
// Layout parsed below: a text header "<number of words> <vector size>\n",
// then for every word: the word text, one ' ' separator and <vector size>
// raw floats. '\n' characters met while reading a word are dropped.
//
// modelFile: path of the binary model file.
// normalize: when true, every vector is divided by sqrt(sum(x^2) / length).
// n:         maximum number of words (rows) to read.
//
// Returns a matrix with one row per word (row names are the words) and one
// column per vector dimension. On any failure a warning carrying the error
// message is raised and an empty matrix is returned.
// [[Rcpp::export]]
Rcpp::NumericMatrix w2v_read_binary(const std::string modelFile, bool normalize, std::size_t n) {
  // Filled in by the handlers below and reported after the try block is left,
  // so a failure is no longer silently turned into an empty matrix.
  std::string errMsg;
  try {
    const std::string wrongFormatErrMsg = "model: wrong model file format";

    // map model file, exception will be thrown on empty file
    w2v::fileMapper_t input(modelFile);

    // Bounds-checked byte access: the position is validated BEFORE the read,
    // so a truncated file can never cause a read past the mapped region.
    auto charAt = [&input, &wrongFormatErrMsg](off_t pos) -> char {
      if (pos < 0 || pos >= input.size()) {
        throw std::runtime_error(wrongFormatErrMsg);
      }
      return *(input.data() + pos);
    };

    // parse header
    off_t offset = 0;
    char ch = 0;

    // get words number
    std::string nwStr;
    while ((ch = charAt(offset)) != ' ') {
      nwStr += ch;
      ++offset;
    }

    // get vector size
    ++offset; // skip ' ' char
    std::string vsStr;
    while ((ch = charAt(offset)) != '\n') {
      vsStr += ch;
      ++offset;
    }

    std::size_t m_mapSize = 0;
    uint16_t m_vectorSize = 0;
    try {
      const long long wordsNumber = std::stoll(nwStr);
      const int vectorSize = std::stoi(vsStr);
      // Reject values that would wrap around (negative count) or be silently
      // truncated by the narrowing to uint16_t (sizes above 0xFFFF).
      if (wordsNumber < 0 || vectorSize < 0 || vectorSize > 0xFFFF) {
        throw std::runtime_error(wrongFormatErrMsg);
      }
      m_mapSize = static_cast<std::size_t>(wordsNumber);
      m_vectorSize = static_cast<uint16_t>(vectorSize);
    } catch (...) {
      throw std::runtime_error(wrongFormatErrMsg);
    }
    // never read more words than the caller asked for
    if (m_mapSize > n) {
      m_mapSize = n;
    }
    Rcpp::NumericMatrix embedding(m_mapSize, m_vectorSize);
    Rcpp::StringVector embedding_words(m_mapSize);

    // get pairs of word and vector
    ++offset; // skip last '\n' char
    const std::size_t vectorBytes = m_vectorSize * sizeof(float);
    std::string word;
    // Reused for every word: memcpy below overwrites all of its elements.
    std::vector<float> v(m_vectorSize);
    for (std::size_t i = 0; i < m_mapSize; ++i) {
      // get word, dropping any '\n' found in front of / inside it
      word.clear();
      while ((ch = charAt(offset)) != ' ') {
        if (ch != '\n') {
          word += ch;
        }
        ++offset;
      }
      embedding_words[i] = word;

      // skip last ' ' char and check that the whole vector lies inside the file
      ++offset;
      if (static_cast<off_t>(offset + vectorBytes) > input.size()) {
        throw std::runtime_error(wrongFormatErrMsg);
      }

      // get word's vector
      std::memcpy(v.data(), input.data() + offset, vectorBytes);
      offset += vectorBytes; // vector size

      if (normalize) {
        // normalize vector: divide by the root of the mean squared element
        float med = 0.0f;
        for (auto const &x : v) {
          med += x * x;
        }
        if (med <= 0.0f) {
          throw std::runtime_error("failed to normalize vectors");
        }
        med = std::sqrt(med / v.size());
        for (auto &x : v) {
          x /= med;
        }
      }
      for (unsigned int j = 0; j < v.size(); j++) {
        embedding(i, j) = v[j];
      }
    }
    rownames(embedding) = embedding_words;
    return embedding;
  } catch (const std::exception &_e) {
    errMsg = _e.what();
  } catch (...) {
    errMsg = "model: unknown error";
  }
  // Keep the historical contract (empty matrix on failure) but tell the user why.
  Rcpp::warning(errMsg);
  Rcpp::NumericMatrix embedding_default;
  return embedding_default;
}

117 changes: 0 additions & 117 deletions src/rcpp_word2vec_read.cpp

This file was deleted.

0 comments on commit af0d7b9

Please sign in to comment.