diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..d2f95ec --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,14 @@ +include cuda_setup.py +include requirements.txt +include pyproject.toml +recursive-include cpp/src/cuw2v/ *.cu +recursive-include cpp/src/culda/ *.cu +recursive-include cpp/src/ioutils/ *.cc +recursive-include cpp/include/cuw2v/ *.cuh +recursive-include cpp/include/cuw2v/ *.hpp +recursive-include cpp/include/culda/ *.cuh +recursive-include cpp/include/culda/ *.hpp +recursive-include cpp/include/ioutils/ *.cuh +recursive-include cpp/include/ioutils/ *.hpp +recursive-include 3rd/json11/ * +recursive-include 3rd/spdlog/ * diff --git a/README.md b/README.md index 34a25c7..1b5cb95 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,10 @@ +### Introduction + +This project is to speed up various ML models (e.g. topic modeling, word embedding, etc) by CUDA. It would be nice to think of it as [gensim](https://github.com/RaRe-Technologies/gensim)'s GPU version project. As a starting step, I implemented the most widely used word embedding model, the [word2vec](https://arxiv.org/pdf/1301.3781.pdf) model, and the most representative topic model, the [LDA (Latent Dirichlet Allocation)](https://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf) model. + ### How to install +- install from source ```shell # clone repo and submodules @@ -14,3 +19,63 @@ python -m grpc_tools.protoc --python_out cusim/ --proto_path cusim/proto/ config # install python setup.py install ``` + +- pip installation will be available soon + +### How to use + +- `examples/example_w2v.py`, `examples/example_lda.py` and `examples/README.md` will be very helpful to understand the usage. +- paremeter description can be seen in `cusim/proto/config.proto` + +### Performance + +- [AWS g4dn 2xlarge instance](https://aws.amazon.com/ec2/instance-types/g4/) is used to the experiment. (One NVIDIA T4 GPU with 8 vcpus, Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz) +- results can be reproduced by simply running `examples/example_w2v.py` and `examples/example_lda.py` +- To evaluate w2v model, I used `evaluate_word_pairs` function ([ref link](https://radimrehurek.com/gensim/auto_examples/tutorials/run_word2vec.html#evaluating)) in gensim, note that better performance on WS-353 test set does not necessarily mean that the model will workbetter in application as desribed on the link. However, it is good to be measured quantitively and fast training time will be at least very objective measure of the performaance. + - I trained W2V model on `quora-duplicat-questions` dataset from gensim downloader api on GPU with cusim and compare the performance (both speed and model quality) with gensim. +- To evaluate LDA model, I found there is no good way to measure the quality of traing results quantitatively. But we can check the model by looking at the top words of each topic. Also, we can compare the training time quantitatively. +- W2V (skip gram, hierarchical softmax) + +| attr | 1 workers (gensim) | 2 workers (gensim) | 4 workers (gensim) | 8 workers (gensim) | NVIDIA T4 (cusim) | +|:--------------------|---------------------:|---------------------:|---------------------:|---------------------:|--------------------:| +| training time (sec) | 892.596 | 544.212 | 310.727 | 226.472 | **16.162** | +| pearson | 0.487832 | 0.487696 | 0.482821 | 0.487136 | **0.492101** | +| spearman | 0.500846 | 0.506214 | 0.501048 | **0.506718** | 0.479468 | + +- W2V (skip gram, negative sampling) + +| attr | 1 workers (gensim) | 2 workers (gensim) | 4 workers (gensim) | 8 workers (gensim) | NVIDIA T4 (cusim) | +|:--------------------|---------------------:|---------------------:|---------------------:|---------------------:|--------------------:| +| training time (sec) | 586.545 | 340.489 | 220.804 | 146.23 | **33.9173** | +| pearson | 0.354448 | 0.353952 | 0.352398 | 0.352925 | **0.360436** | +| spearman | 0.369146 | 0.369365 | **0.370565** | 0.365822 | 0.355204 | + +- W2V (CBOW, hierarchical softmax) + +| attr | 1 workers (gensim) | 2 workers (gensim) | 4 workers (gensim) | 8 workers (gensim) | NVIDIA T4 (cusim) | +|:--------------------|---------------------:|---------------------:|---------------------:|---------------------:|--------------------:| +| training time (sec) | 250.135 | 155.121 | 103.57 | 73.8073 | **6.20787** | +| pearson | 0.309651 | 0.321803 | 0.324854 | 0.314255 | **0.480298** | +| spearman | 0.294047 | 0.308723 | 0.318293 | 0.300591 | **0.480971** | + +- W2V (CBOW, negative sampling) + +| attr | 1 workers (gensim) | 2 workers (gensim) | 4 workers (gensim) | 8 workers (gensim) | NVIDIA T4 (cusim) | +|:--------------------|---------------------:|---------------------:|---------------------:|---------------------:|--------------------:| +| training time (sec) | 176.923 | 100.369 | 69.7829 | 49.9274 | **9.90391** | +| pearson | 0.18772 | 0.193152 | 0.204509 | 0.187924 | **0.368202** | +| spearman | 0.243975 | 0.24587 | 0.260531 | 0.237441 | **0.358042** | + +- LDA (`nytimes` dataset from https://archive.ics.uci.edu/ml/datasets/bag+of+words) + - I found that setting `workers` variable in gensim LdaMulticore does not work properly (it uses all cores in instance anyway), so I just compared the speed between cusim with single GPU and gensim with 8 vcpus. + - One can compare the quality of modeling by looking at `examples/cusim.topics.txt` and `examples/gensim.topics.txt`. + +| attr | gensim (8 vpus) | cusim (NVIDIA T4)| +|:--------------------|------------------:|--------:| +| training time (sec) | 447.376 | **76.6972** | + +### Future tasks + +- support half precision +- support multi device (multi device implementation on LDA model will not be that hard, while multi device training on w2v may require some considerations) +- implement other models such as FastText, BERT, etc diff --git a/cpp/include/culda/cuda_lda_kernels.cuh b/cpp/include/culda/cuda_lda_kernels.cuh index 02dbb37..ce451e6 100644 --- a/cpp/include/culda/cuda_lda_kernels.cuh +++ b/cpp/include/culda/cuda_lda_kernels.cuh @@ -26,26 +26,36 @@ float Digamma(float x) { } __global__ void EstepKernel( - const int* cols, const int* indptr, const bool* vali, - const int num_cols, const int num_indptr, + const int* cols, const int* indptr, + const bool* vali, const float* counts, + const bool init_gamma, const int num_cols, const int num_indptr, const int num_topics, const int num_iters, - float* gamma, float* new_gamma, float* phi, const float* alpha, const float* beta, - float* grad_alpha, float* new_beta, - float* train_losses, float* vali_losses, int* mutex) { + float* gamma, float* grad_alpha, float* new_beta, + float* train_losses, float* vali_losses, int* locks) { // storage for block - float* _gamma = gamma + num_topics * blockIdx.x; - float* _new_gamma = new_gamma + num_topics * blockIdx.x; - float* _phi = phi + num_topics * blockIdx.x; + extern __shared__ float shared_memory[]; + float* _new_gamma = &shared_memory[0]; + float* _phi = &shared_memory[num_topics]; + float* _loss_vec = &shared_memory[num_topics * 2]; + float* _vali_phi_sum = &shared_memory[num_topics * 3]; + float* _grad_alpha = grad_alpha + num_topics * blockIdx.x; for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) { int beg = indptr[i], end = indptr[i + 1]; - // initialize gamma - for (int j = threadIdx.x; j < num_topics; j += blockDim.x) - _gamma[j] = alpha[j] + (end - beg) / num_topics; + float* _gamma = gamma + num_topics * i; + if (init_gamma) { + for (int j = threadIdx.x; j < num_topics; j += blockDim.x) { + _gamma[j] = alpha[j] + (end - beg) / num_topics; + } + } __syncthreads(); + + // initiate phi sum for validation data for computing vali loss + for (int j = threadIdx.x; j < num_topics; j += blockDim.x) + _vali_phi_sum[j] = 0.0f; // iterate E step for (int j = 0; j < num_iters; ++j) { @@ -58,7 +68,7 @@ __global__ void EstepKernel( for (int k = beg; k < end; ++k) { const int w = cols[k]; const bool _vali = vali[k]; - + const float c = counts[k]; // compute phi if (not _vali or j + 1 == num_iters) { for (int l = threadIdx.x; l < num_topics; l += blockDim.x) @@ -70,37 +80,52 @@ __global__ void EstepKernel( for (int l = threadIdx.x; l < num_topics; l += blockDim.x) { _phi[l] /= phi_sum; - if (not _vali) _new_gamma[l] += _phi[l]; + + // update gamma for train data and phi_sum for computing loss + if (_vali) + _vali_phi_sum[l] += _phi[l] * c; + else + _new_gamma[l] += _phi[l] * c; + } __syncthreads(); } if (j + 1 == num_iters) { - // write access of w th vector of new_beta - if (threadIdx.x == 0) { - while (atomicCAS(&mutex[w], 0, 1)) {} - } + // update beta for train data + if (not _vali) { + // write access of w th vector of new_beta + if (threadIdx.x == 0) { + while (atomicCAS(&locks[w], 0, 1)) {} + } - __syncthreads(); + __syncthreads(); + for (int l = threadIdx.x; l < num_topics; l += blockDim.x) + new_beta[w * num_topics + l] += _phi[l] * c; + __syncthreads(); + + // release lock + if (threadIdx.x == 0) locks[w] = 0; + __syncthreads(); + } + + // comput loss and reset shared mem + // see Eq (15) in https://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf for (int l = threadIdx.x; l < num_topics; l += blockDim.x) { - if (j + 1 == num_iters) { - if (not _vali) new_beta[w * num_topics + l] += _phi[l]; - _phi[l] *= beta[w * num_topics + l]; - } + _loss_vec[l] = logf(fmaxf(beta[w * num_topics + l], EPS)); + _loss_vec[l] -= logf(fmaxf(_phi[l], EPS)); + _loss_vec[l] *= _phi[l]; } __syncthreads(); - - // release lock - if (threadIdx.x == 0) mutex[w] = 0; - __syncthreads(); - - float p = fmaxf(EPS, ReduceSum(_phi, num_topics)); + float _loss = ReduceSum(_loss_vec, num_topics) * c; if (threadIdx.x == 0) { - if (_vali) - vali_losses[blockIdx.x] += logf(p); + if (_vali) + vali_losses[blockIdx.x] += _loss; else - train_losses[blockIdx.x] += logf(p); - } + train_losses[blockIdx.x] += _loss; + } + __syncthreads(); + } __syncthreads(); } @@ -110,9 +135,23 @@ __global__ void EstepKernel( _gamma[k] = _new_gamma[k] + alpha[k]; __syncthreads(); } + + // update gradient of alpha and loss from E[log(theta)] float gamma_sum = ReduceSum(_gamma, num_topics); - for (int j = threadIdx.x; j < num_topics; j += blockDim.x) - _grad_alpha[j] += (Digamma(_gamma[j]) - Digamma(gamma_sum)); + for (int j = threadIdx.x; j < num_topics; j += blockDim.x) { + float Elogthetad = Digamma(_gamma[j]) - Digamma(gamma_sum); + _grad_alpha[j] += Elogthetad; + _new_gamma[j] *= Elogthetad; + _vali_phi_sum[j] *= Elogthetad; + } + + // see Eq (15) in https://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf + float train_loss = ReduceSum(_new_gamma, num_topics); + float vali_loss = ReduceSum(_vali_phi_sum, num_topics); + if (threadIdx.x == 0) { + train_losses[blockIdx.x] += train_loss; + vali_losses[blockIdx.x] += vali_loss; + } __syncthreads(); } diff --git a/cpp/include/culda/culda.hpp b/cpp/include/culda/culda.hpp index ff52dca..cbcd370 100644 --- a/cpp/include/culda/culda.hpp +++ b/cpp/include/culda/culda.hpp @@ -65,8 +65,11 @@ class CuLDA { void LoadModel(float* alpha, float* beta, float* grad_alpha, float* new_beta, const int num_words); std::pair FeedData( - const int* indices, const int* indptr, const bool* vali, - const int num_indices, const int num_indptr, const int num_iters); + const int* indices, const int* indptr, + const bool* vali, const float* counts, + float* gamma, const bool init_gamma, + const int num_indices, const int num_indptr, + const int num_iters); void Pull(); void Push(); int GetBlockCnt(); @@ -78,8 +81,7 @@ class CuLDA { std::unique_ptr logger_container_; thrust::device_vector dev_alpha_, dev_beta_; thrust::device_vector dev_grad_alpha_, dev_new_beta_; - thrust::device_vector dev_gamma_, dev_new_gamma_, dev_phi_; - thrust::device_vector dev_mutex_; + thrust::device_vector dev_locks_; float *alpha_, *beta_, *grad_alpha_, *new_beta_; int block_cnt_, block_dim_; diff --git a/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh index 8046cf3..fe8d3b0 100644 --- a/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh +++ b/cpp/include/cuw2v/cuda_w2v_base_kernels.cuh @@ -6,6 +6,8 @@ #pragma once #include "utils/cuda_utils_kernels.cuh" +#define MAX_EXP 20 + namespace cusim { @@ -13,7 +15,7 @@ __inline__ __device__ void PositiveFeedback(const float* vec1, float* vec2, float* grad, float& loss_nume, float& loss_deno, const int num_dims, const float lr) { static __shared__ float g; - float dot = Dot(vec1, vec2, num_dims); + float dot = fmaxf(-MAX_EXP, fminf(MAX_EXP, Dot(vec1, vec2, num_dims))); if (threadIdx.x == 0) { float exp_dot = expf(-dot); g = exp_dot / (1 + exp_dot) * lr; @@ -32,7 +34,7 @@ __inline__ __device__ void NegativeFeedback(const float* vec1, float* vec2, float* grad, float& loss_nume, float& loss_deno, const int num_dims, const float lr) { static __shared__ float g; - float dot = Dot(vec1, vec2, num_dims); + float dot = fmaxf(-MAX_EXP, fminf(MAX_EXP, Dot(vec1, vec2, num_dims))); if (threadIdx.x == 0) { float exp_dot = expf(dot); g = exp_dot / (1 + exp_dot) * lr; diff --git a/cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh index c7aaca5..8cecc03 100644 --- a/cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh +++ b/cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh @@ -39,11 +39,11 @@ __global__ void W2VHsSgKernel( __syncthreads(); int beg2 = max(beg, j - window_size + reduced_windows); int end2 = min(end, j + window_size - reduced_windows + 1); - float* _emb_in = emb_in + num_dims * cols[j]; for (int k = beg2; k < end2; ++k) { if (k == j) continue; - int beg3 = hs_indptr[cols[k]]; - int end3 = hs_indptr[cols[k] + 1]; + float* _emb_in = emb_in + num_dims * cols[k]; + int beg3 = hs_indptr[cols[j]]; + int end3 = hs_indptr[cols[j] + 1]; for (int l = beg3; l < end3; ++l) { if (codes[l]) { PositiveFeedback(_emb_in, emb_out + num_dims * points[l], @@ -55,7 +55,7 @@ __global__ void W2VHsSgKernel( __syncthreads(); } for (int l = threadIdx.x; l < num_dims; l += blockDim.x) { - emb_in[num_dims * cols[j] + l] += grad[l]; + _emb_in[l] += grad[l]; grad[l] = 0.0f; } __syncthreads(); @@ -70,7 +70,7 @@ __global__ void W2VHsCbowKernel( const int num_indptr, const int num_dims, const int window_size, default_random_engine* rngs, float* emb_in, float* emb_out, float* loss_nume, float* loss_deno, - const bool use_mean, const float lr) { + const bool cbow_mean, const float lr) { default_random_engine& rng = rngs[blockIdx.x]; float& _loss_nume = loss_nume[blockIdx.x]; @@ -98,7 +98,7 @@ __global__ void W2VHsCbowKernel( grad[k] = 0.0f; cbow[k] = 0.0f; } - + // compute cbow for (int k = beg2; k < end2; ++k) { if (k == j) continue; @@ -106,7 +106,7 @@ __global__ void W2VHsCbowKernel( cbow[l] += emb_in[num_dims * cols[k] + l]; } } - if (use_mean) { + if (cbow_mean) { for (int k = threadIdx.x; k < num_dims; k += blockDim.x) { cbow[k] /= (end2 - beg2 - 1); } @@ -126,8 +126,8 @@ __global__ void W2VHsCbowKernel( __syncthreads(); } - // normalize grad if use_mean = true - if (use_mean) { + // normalize grad if cbow_mean = true + if (cbow_mean) { for (int k = threadIdx.x; k < num_dims; k += blockDim.x) { grad[k] /= (end2 - beg2 - 1); } diff --git a/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh b/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh index 8f6bef0..c8cf020 100644 --- a/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh +++ b/cpp/include/cuw2v/cuda_w2v_ns_kernels.cuh @@ -39,20 +39,21 @@ __global__ void W2VNegSgKernel( __syncthreads(); int beg2 = max(beg, j - window_size + reduced_windows); int end2 = min(end, j + window_size - reduced_windows + 1); - float* _emb_in = emb_in + num_dims * cols[j]; for (int k = beg2; k < end2; ++k) { if (k == j) continue; - PositiveFeedback(_emb_in, emb_out + num_dims * cols[k], + float* _emb_in = emb_in + num_dims * cols[k]; + PositiveFeedback(_emb_in, emb_out + num_dims * cols[j], grad, _loss_nume, _loss_deno, num_dims, lr); for (int l = 0; l < neg; ++l) { if (threadIdx.x == 0) neg_word = random_table[dist_neg(rng)]; __syncthreads(); + if (neg_word == cols[j]) continue; NegativeFeedback(_emb_in, emb_out + num_dims * neg_word, grad, _loss_nume, _loss_deno, num_dims, lr); } __syncthreads(); for (int l = threadIdx.x; l < num_dims; l += blockDim.x) { - emb_in[num_dims * cols[j] + l] += grad[l]; + _emb_in[l] += grad[l]; grad[l] = 0.0f; } __syncthreads(); @@ -66,7 +67,7 @@ __global__ void W2VNegCbowKernel( const int* random_table, default_random_engine* rngs, const int random_size, const int num_indptr, const int num_dims, const int neg, const int window_size, float* emb_in, float* emb_out, - float* loss_nume, float* loss_deno, const bool use_mean, const float lr) { + float* loss_nume, float* loss_deno, const bool cbow_mean, const float lr) { default_random_engine& rng = rngs[blockIdx.x]; float& _loss_nume = loss_nume[blockIdx.x]; @@ -104,7 +105,7 @@ __global__ void W2VNegCbowKernel( cbow[l] += emb_in[num_dims * cols[k] + l]; } } - if (use_mean) { + if (cbow_mean) { for (int k = threadIdx.x; k < num_dims; k += blockDim.x) { cbow[k] /= (end2 - beg2 - 1); } @@ -119,13 +120,14 @@ __global__ void W2VNegCbowKernel( for (int k = 0; k < neg; ++k){ if (threadIdx.x == 0) neg_word = random_table[dist_neg(rng)]; __syncthreads(); + if (neg_word == cols[j]) continue; NegativeFeedback(cbow, emb_out + num_dims * neg_word, grad, _loss_nume, _loss_deno, num_dims, lr); } __syncthreads(); - // normalize grad if use_mean = true - if (use_mean) { + // normalize grad if cbow_mean = true + if (cbow_mean) { for (int k = threadIdx.x; k < num_dims; k += blockDim.x) { grad[k] /= (end2 - beg2 - 1); } diff --git a/cpp/include/cuw2v/cuw2v.hpp b/cpp/include/cuw2v/cuw2v.hpp index fba7f2b..e444e21 100644 --- a/cpp/include/cuw2v/cuw2v.hpp +++ b/cpp/include/cuw2v/cuw2v.hpp @@ -41,9 +41,7 @@ class CuW2V { bool Init(std::string opt_path); void LoadModel(float* emb_in, float* emb_out); void BuildHuffmanTree(const float* word_count, const int num_words); - void BuildRandomTable(const float* word_count, const int num_words, - const int table_size, const int num_threads); - int GetBlockCnt(); + void BuildRandomTable(const double* word_count, const int num_words, const int table_size); std::pair FeedData(const int* cols, const int* indptr, const int num_cols, const int num_indptr); void Pull(); @@ -64,12 +62,12 @@ class CuW2V { thrust::device_vector dev_points_, dev_hs_indptr_; // related to negative sampling / hierarchical softmax and skip gram / cbow - bool sg_, use_mean_; + bool sg_, cbow_mean_; int neg_; // variables to construct random table thrust::device_vector dev_random_table_; - int random_size_, table_seed_, cuda_seed_; + int random_size_, seed_; thrust::device_vector dev_rngs_; }; diff --git a/cpp/include/utils/ioutils.hpp b/cpp/include/utils/ioutils.hpp index 54ca2e1..3ba63d4 100644 --- a/cpp/include/utils/ioutils.hpp +++ b/cpp/include/utils/ioutils.hpp @@ -6,6 +6,7 @@ #pragma once #include +#include #include #include #include @@ -13,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -30,11 +32,14 @@ class IoUtils { IoUtils(); ~IoUtils(); bool Init(std::string opt_path); - int LoadStreamFile(std::string filepath); + int64_t LoadStreamFile(std::string filepath); std::pair ReadStreamForVocab(int num_lines, int num_threads); std::pair TokenizeStream(int num_lines, int num_threads); void GetWordVocab(int min_count, std::string keys_path, std::string count_path); void GetToken(int* rows, int* cols, int* indptr); + std::tuple ReadBagOfWordsHeader(std::string filepath); + void ReadBagOfWordsContent(int64_t* rows, int* cols, float* counts, const int num_lines); + private: void ParseLine(std::string line, std::vector& line_vec); void ParseLineImpl(std::string line, std::vector& line_vec); @@ -42,13 +47,14 @@ class IoUtils { std::vector> cols_; std::vector indptr_; std::mutex global_lock_; - std::ifstream stream_fin_; + std::ifstream fin_; json11::Json opt_; std::shared_ptr logger_; std::unique_ptr logger_container_; std::unordered_map word_idmap_, word_count_; std::vector word_list_; - int num_lines_, remain_lines_; + int64_t num_lines_, remain_lines_; + bool lower_; }; // class IoUtils } // namespace cusim diff --git a/cpp/src/culda/culda.cu b/cpp/src/culda/culda.cu index c92bfeb..c5b0ff5 100644 --- a/cpp/src/culda/culda.cu +++ b/cpp/src/culda/culda.cu @@ -55,50 +55,52 @@ void CuLDA::LoadModel(float* alpha, float* beta, // copy to device thrust::copy(grad_alpha_, grad_alpha_ + block_cnt_ * num_topics_, dev_grad_alpha_.begin()); thrust::copy(new_beta_, new_beta_ + num_words_ * num_topics_, dev_new_beta_.begin()); - dev_gamma_.resize(num_topics_ * block_cnt_); - dev_new_gamma_.resize(num_topics_ * block_cnt_); - dev_phi_.resize(num_topics_ * block_cnt_); - - // set mutex - dev_mutex_.resize(num_words_); - std::vector host_mutex(num_words_, 0); - thrust::copy(host_mutex.begin(), host_mutex.end(), dev_mutex_.begin()); + // set locks + dev_locks_.resize(num_words_); + std::vector host_locks(num_words_, 0); + thrust::copy(host_locks.begin(), host_locks.end(), dev_locks_.begin()); CHECK_CUDA(cudaDeviceSynchronize()); } std::pair CuLDA::FeedData( - const int* cols, const int* indptr, const bool* vali, - const int num_cols, const int num_indptr, const int num_iters) { + const int* cols, const int* indptr, + const bool* vali, const float* counts, float* gamma, + const bool init_gamma, const int num_cols, const int num_indptr, + const int num_iters) { // copy feed data to GPU memory thrust::device_vector dev_cols(num_cols); thrust::device_vector dev_indptr(num_indptr + 1); thrust::device_vector dev_vali(num_cols); + thrust::device_vector dev_counts(num_cols); + thrust::device_vector dev_gamma(num_indptr * num_topics_); thrust::device_vector dev_train_losses(block_cnt_, 0.0f); thrust::device_vector dev_vali_losses(block_cnt_, 0.0f); thrust::copy(cols, cols + num_cols, dev_cols.begin()); thrust::copy(indptr, indptr + num_indptr + 1, dev_indptr.begin()); thrust::copy(vali, vali + num_cols, dev_vali.begin()); + thrust::copy(counts, counts + num_cols, dev_counts.begin()); + thrust::copy(gamma, gamma + num_indptr * num_topics_, dev_gamma.begin()); CHECK_CUDA(cudaDeviceSynchronize()); DEBUG0("copy feed data to GPU memory"); // run E step in GPU - EstepKernel<<>>( + EstepKernel<<>>( thrust::raw_pointer_cast(dev_cols.data()), thrust::raw_pointer_cast(dev_indptr.data()), thrust::raw_pointer_cast(dev_vali.data()), - num_cols, num_indptr, num_topics_, num_iters, - thrust::raw_pointer_cast(dev_gamma_.data()), - thrust::raw_pointer_cast(dev_new_gamma_.data()), - thrust::raw_pointer_cast(dev_phi_.data()), + thrust::raw_pointer_cast(dev_counts.data()), + init_gamma, num_cols, num_indptr, num_topics_, num_iters, thrust::raw_pointer_cast(dev_alpha_.data()), thrust::raw_pointer_cast(dev_beta_.data()), + thrust::raw_pointer_cast(dev_gamma.data()), thrust::raw_pointer_cast(dev_grad_alpha_.data()), thrust::raw_pointer_cast(dev_new_beta_.data()), thrust::raw_pointer_cast(dev_train_losses.data()), thrust::raw_pointer_cast(dev_vali_losses.data()), - thrust::raw_pointer_cast(dev_mutex_.data())); + thrust::raw_pointer_cast(dev_locks_.data())); CHECK_CUDA(cudaDeviceSynchronize()); DEBUG0("run E step in GPU"); @@ -106,6 +108,7 @@ std::pair CuLDA::FeedData( std::vector train_losses(block_cnt_), vali_losses(block_cnt_); thrust::copy(dev_train_losses.begin(), dev_train_losses.end(), train_losses.begin()); thrust::copy(dev_vali_losses.begin(), dev_vali_losses.end(), vali_losses.begin()); + thrust::copy(dev_gamma.begin(), dev_gamma.end(), gamma); CHECK_CUDA(cudaDeviceSynchronize()); DEBUG0("pull loss values"); diff --git a/cpp/src/cuw2v/cuw2v.cu b/cpp/src/cuw2v/cuw2v.cu index 20f6640..b419cd9 100644 --- a/cpp/src/cuw2v/cuw2v.cu +++ b/cpp/src/cuw2v/cuw2v.cu @@ -49,54 +49,39 @@ bool CuW2V::Init(std::string opt_path) { block_dim_ = opt_["block_dim"].int_value(); block_cnt_ = opt_["hyper_threads"].number_value() * (dev_info_.cores / block_dim_); sg_ = opt_["skip_gram"].bool_value(); - use_mean_ = opt_["use_mean"].bool_value(); + cbow_mean_ = opt_["cbow_mean"].bool_value(); window_size_ = opt_["window_size"].int_value(); lr_ = opt_["lr"].number_value(); // if zero, we will use hierarchical softmax neg_ = opt_["neg"].int_value(); - // random seed - table_seed_ = opt_["table_seed"].int_value(); - cuda_seed_ = opt_["cuda_seed"].int_value(); + // random seed + seed_ = opt_["seed"].int_value(); dev_rngs_.resize(block_cnt_); InitRngsKernel<<>>( - thrust::raw_pointer_cast(dev_rngs_.data()), cuda_seed_); + thrust::raw_pointer_cast(dev_rngs_.data()), seed_); INFO("num_dims: {}, block_dim: {}, block_cnt: {}, objective type: {}, neg: {}", num_dims_, block_dim_, block_cnt_, sg_? "skip gram": "cbow", neg_); return true; } -void CuW2V::BuildRandomTable(const float* word_count, const int num_words, - const int table_size, const int num_threads) { +void CuW2V::BuildRandomTable(const double* word_count, const int num_words, const int table_size) { num_words_ = num_words; - random_size_ = table_size; - std::vector acc; - float cumsum = 0; + std::vector host_random_table; for (int i = 0; i < num_words; ++i) { - acc.push_back(cumsum); - cumsum += word_count[i]; + int weight = std::max(1, static_cast(word_count[i] * static_cast(table_size))); + for (int j = 0; j < weight; ++j) + host_random_table.push_back(i); } - + + random_size_ = host_random_table.size(); dev_random_table_.resize(random_size_); - std::vector host_random_table(table_size); - #pragma omp parallel num_threads(num_threads) - { - const unsigned int table_seed = table_seed_ + omp_get_thread_num(); - std::mt19937 rng(table_seed); - std::uniform_real_distribution dist(0.0f, cumsum); - #pragma omp for schedule(static) - for (int i = 0; i < random_size_; ++i) { - float r = dist(rng); - int pos = std::lower_bound(acc.begin(), acc.end(), r) - acc.begin(); - host_random_table[i] = pos; - } - } - table_seed_ += num_threads; - thrust::copy(host_random_table.begin(), host_random_table.end(), dev_random_table_.begin()); CHECK_CUDA(cudaDeviceSynchronize()); + + INFO("random table initialzied, size: {} => {}", table_size, random_size_); } void CuW2V::BuildHuffmanTree(const float* word_count, const int num_words) { @@ -181,10 +166,6 @@ void CuW2V::LoadModel(float* emb_in, float* emb_out) { CHECK_CUDA(cudaDeviceSynchronize()); } -int CuW2V::GetBlockCnt() { - return block_cnt_; -} - std::pair CuW2V::FeedData(const int* cols, const int* indptr, const int num_cols, const int num_indptr) { @@ -224,7 +205,7 @@ std::pair CuW2V::FeedData(const int* cols, const int* indptr, thrust::raw_pointer_cast(dev_emb_out_.data()), thrust::raw_pointer_cast(dev_loss_nume.data()), thrust::raw_pointer_cast(dev_loss_deno.data()), - use_mean_, lr_); + cbow_mean_, lr_); } } else { if (sg_) { @@ -255,7 +236,7 @@ std::pair CuW2V::FeedData(const int* cols, const int* indptr, thrust::raw_pointer_cast(dev_emb_out_.data()), thrust::raw_pointer_cast(dev_loss_nume.data()), thrust::raw_pointer_cast(dev_loss_deno.data()), - use_mean_, lr_); + cbow_mean_, lr_); } diff --git a/cpp/src/utils/ioutils.cc b/cpp/src/utils/ioutils.cc index 53833a3..2946536 100644 --- a/cpp/src/utils/ioutils.cc +++ b/cpp/src/utils/ioutils.cc @@ -25,6 +25,7 @@ bool IoUtils::Init(std::string opt_path) { if (not err_cmt.empty()) return false; opt_ = _opt; logger_container_->set_log_level(opt_["c_log_level"].int_value()); + lower_ = opt_["lower"].bool_value(); return true; } @@ -42,7 +43,7 @@ void IoUtils::ParseLineImpl(std::string line, std::vector& ret) { ret.push_back(element); element.clear(); } else { - element += std::tolower(line[i]); + element += (lower_? std::tolower(line[i]): line[i]); } } if (element.size() > 0) { @@ -50,16 +51,16 @@ void IoUtils::ParseLineImpl(std::string line, std::vector& ret) { } } -int IoUtils::LoadStreamFile(std::string filepath) { - INFO("read gensim file to generate vocabulary: {}", filepath); - if (stream_fin_.is_open()) stream_fin_.close(); - stream_fin_.open(filepath.c_str()); - int count = 0; +int64_t IoUtils::LoadStreamFile(std::string filepath) { + INFO("read stream file to generate vocabulary: {}", filepath); + if (fin_.is_open()) fin_.close(); + fin_.open(filepath.c_str()); + int64_t count = 0; std::string line; - while (getline(stream_fin_, line)) + while (getline(fin_, line)) count++; - stream_fin_.close(); - stream_fin_.open(filepath.c_str()); + fin_.close(); + fin_.open(filepath.c_str()); num_lines_ = count; remain_lines_ = num_lines_; INFO("number of lines: {}", num_lines_); @@ -67,7 +68,7 @@ int IoUtils::LoadStreamFile(std::string filepath) { } std::pair IoUtils::TokenizeStream(int num_lines, int num_threads) { - int read_lines = std::min(num_lines, remain_lines_); + int read_lines = static_cast(std::min(static_cast(num_lines), remain_lines_)); if (not read_lines) return {0, 0}; remain_lines_ -= read_lines; cols_.clear(); @@ -83,7 +84,7 @@ std::pair IoUtils::TokenizeStream(int num_lines, int num_threads) { // get line thread-safely { std::unique_lock lock(global_lock_); - getline(stream_fin_, line); + getline(fin_, line); } // seems to be bottle-neck @@ -118,7 +119,7 @@ void IoUtils::GetToken(int* rows, int* cols, int* indptr) { } std::pair IoUtils::ReadStreamForVocab(int num_lines, int num_threads) { - int read_lines = std::min(num_lines, remain_lines_); + int read_lines = static_cast(std::min(static_cast(num_lines), remain_lines_)); remain_lines_ -= read_lines; #pragma omp parallel num_threads(num_threads) { @@ -130,7 +131,7 @@ std::pair IoUtils::ReadStreamForVocab(int num_lines, int num_threads) // get line thread-safely { std::unique_lock lock(global_lock_); - getline(stream_fin_, line); + getline(fin_, line); } // seems to be bottle-neck @@ -150,7 +151,7 @@ std::pair IoUtils::ReadStreamForVocab(int num_lines, int num_threads) } } } - if (not remain_lines_) stream_fin_.close(); + if (not remain_lines_) fin_.close(); return {read_lines, word_count_.size()}; } @@ -179,4 +180,42 @@ void IoUtils::GetWordVocab(int min_count, std::string keys_path, std::string cou fout1.close(); fout2.close(); } +std::tuple IoUtils::ReadBagOfWordsHeader(std::string filepath) { + INFO("read bag of words file: {} (format reference: https://archive.ics.uci.edu/ml/datasets/bag+of+words)", + filepath); + if (fin_.is_open()) fin_.close(); + fin_.open(filepath.c_str()); + std::string line; + std::stringstream sstr; + int64_t num_docs, nnz; + int num_words; + getline(fin_, line); + sstr << line; sstr >> num_docs; sstr.clear(); + getline(fin_, line); + num_words = std::stoi(line); + getline(fin_, line); + sstr << line; sstr >> nnz; sstr.clear(); + return {num_docs, num_words, nnz}; +} + +void IoUtils::ReadBagOfWordsContent(int64_t* rows, int* cols, float* counts, const int num_lines) { + if (not fin_.is_open()) throw std::runtime_error("file is not open"); + std::string line; + std::stringstream sstr; + int64_t row; + int col; + float count; + std::vector line_vec; + for (int i = 0; i < num_lines; ++i) { + getline(fin_, line); + ParseLine(line, line_vec); + sstr << line_vec[0]; sstr >> row; sstr.clear(); + col = std::stoi(line_vec[1]); + count = std::stof(line_vec[2]); + rows[i] = row - 1; cols[i] = col - 1; counts[i] = count; + line_vec.clear(); + } + if (fin_.eof()) fin_.close(); +} + } // namespace cusim diff --git a/cusim/culda/bindings.cc b/cusim/culda/bindings.cc index f85b2d5..adc5bc2 100644 --- a/cusim/culda/bindings.cc +++ b/cusim/culda/bindings.cc @@ -55,21 +55,38 @@ class CuLDABind { _new_beta.mutable_data(0), num_words); } - std::pair FeedData(py::object& cols, py::object& indptr, py::object& vali, const int num_iters) { + std::pair FeedData(py::object& cols, + py::object& indptr, py::object& vali, py::object& counts, + py::object& gamma, const bool init_gamma, + const int num_iters) { int_array _cols(cols); int_array _indptr(indptr); bool_array _vali(vali); + float_array _counts(counts); + float_array _gamma(gamma); auto cols_buffer = _cols.request(); auto indptr_buffer = _indptr.request(); auto vali_buffer = _vali.request(); - if (cols_buffer.ndim != 1 or indptr_buffer.ndim != 1 or vali_buffer.ndim != 1 - or cols_buffer.shape[0] != vali_buffer.shape[0]) { - throw std::runtime_error("invalid cols or indptr"); + auto counts_buffer = _counts.request(); + auto gamma_buffer = _gamma.request(); + if (cols_buffer.ndim != 1 or + indptr_buffer.ndim != 1 or + vali_buffer.ndim != 1 or + counts_buffer.ndim != 1 or + gamma_buffer.ndim != 2) { + throw std::runtime_error("invalid ndim"); } int num_cols = cols_buffer.shape[0]; int num_indptr = indptr_buffer.shape[0] - 1; - return obj_.FeedData(_cols.data(0), _indptr.data(0), _vali.data(0), - num_cols, num_indptr, num_iters); + + if (vali_buffer.shape[0] != num_cols or + counts_buffer.shape[0] != num_cols or + gamma_buffer.shape[0] != num_indptr) { + throw std::runtime_error("invalid length"); + } + return obj_.FeedData(_cols.data(0), _indptr.data(0), + _vali.data(0), _counts.data(0), _gamma.mutable_data(0), + init_gamma, num_cols, num_indptr, num_iters); } void Pull() { @@ -98,7 +115,9 @@ PYBIND11_PLUGIN(culda_bind) { py::arg("alpha"), py::arg("beta"), py::arg("grad_alpha"), py::arg("new_beta")) .def("feed_data", &CuLDABind::FeedData, - py::arg("cols"), py::arg("indptr"), py::arg("vali"), py::arg("num_iters")) + py::arg("cols"), py::arg("indptr"), py::arg("vali"), + py::arg("counts"), py::arg("gamma"), + py::arg("init_gamma"), py::arg("num_iters")) .def("pull", &CuLDABind::Pull) .def("push", &CuLDABind::Push) .def("get_block_cnt", &CuLDABind::GetBlockCnt) diff --git a/cusim/culda/pyculda.py b/cusim/culda/pyculda.py index dbd0da3..ef51aad 100644 --- a/cusim/culda/pyculda.py +++ b/cusim/culda/pyculda.py @@ -9,6 +9,8 @@ from os.path import join as pjoin import json +import atexit +import shutil import tempfile import h5py @@ -44,27 +46,29 @@ def __init__(self, opt=None): self.alpha, self.beta, self.grad_alpha, self.new_beta = \ None, None, None, None + self.tmp_dirs = [] + atexit.register(self.remove_tmp) + def preprocess_data(self): if self.opt.skip_preprocess: return - iou = IoUtils() - if not self.opt.processed_data_dir: - self.opt.processed_data_dir = tempfile.TemporaryDirectory().name - iou.convert_stream_to_h5(self.opt.data_path, self.opt.word_min_count, - self.opt.processed_data_dir) + iou = IoUtils(aux.proto_to_dict(self.opt.io)) + if not self.opt.processed_data_path: + data_dir = tempfile.TemporaryDirectory().name + self.tmp_dirs.append(data_dir) + self.opt.processed_data_path = pjoin(data_dir, "token.h5") + iou.convert_bow_to_h5(self.opt.data_path, self.opt.processed_data_path) def init_model(self): - # load voca - data_dir = self.opt.processed_data_dir - self.logger.info("load key from %s", pjoin(data_dir, "keys.txt")) - with open(pjoin(data_dir, "keys.txt"), "rb") as fin: - self.words = [line.strip() for line in fin] - self.num_words = len(self.words) - - # count number of docs - h5f = h5py.File(pjoin(data_dir, "token.h5"), "r") + # count number of docs and load voca + assert os.path.exists(self.opt.processed_data_path) + assert os.path.exists(self.opt.keys_path) + h5f = h5py.File(self.opt.processed_data_path, "r") self.num_docs = h5f["indptr"].shape[0] - 1 h5f.close() + with open(self.opt.keys_path, "rb") as fin: + self.words = [line.strip().decode("utf8") for line in fin] + self.num_words = len(self.words) self.logger.info("number of words: %d, docs: %d", self.num_words, self.num_docs) @@ -87,20 +91,34 @@ def init_model(self): self.logger.info("grad alpha %s, new beta %s initialized", self.grad_alpha.shape, self.new_beta.shape) + # set h5 file path to backup gamma + if not self.opt.gamma_path: + data_dir = tempfile.TemporaryDirectory().name + self.tmp_dirs.append(data_dir) + self.opt.gamma_path = pjoin(data_dir, "gamma.h5") + self.logger.info("backup gamma to %s", self.opt.gamma_path) + os.makedirs(os.path.dirname(self.opt.gamma_path), exist_ok=True) + h5f = h5py.File(self.opt.gamma_path, "w") + h5f.create_dataset("gamma", shape=(self.num_docs, self.opt.num_topics), + dtype=np.float32) + h5f.close() + # push it to gpu self.obj.load_model(self.alpha, self.beta, self.grad_alpha, self.new_beta) def train_model(self): self.preprocess_data() self.init_model() - h5f = h5py.File(pjoin(self.opt.processed_data_dir, "token.h5"), "r") + h5f = h5py.File(self.opt.processed_data_path, "r") for epoch in range(1, self.opt.epochs + 1): + gamma_h5f = h5py.File(self.opt.gamma_path, "r+") self.logger.info("Epoch %d / %d", epoch, self.opt.epochs) - self._train_e_step(h5f) + self._train_e_step(h5f, gamma_h5f["gamma"], epoch) self._train_m_step() + gamma_h5f.close() h5f.close() - def _train_e_step(self, h5f): + def _train_e_step(self, h5f, gamma_h5f, epoch): offset, size = 0, h5f["cols"].shape[0] pbar = aux.Progbar(size, stateful_metrics=["train_loss", "vali_loss"]) train_loss_nume, train_loss_deno = 0, 0 @@ -115,26 +133,31 @@ def _train_e_step(self, h5f): beg, end = indptr[0], indptr[-1] indptr -= beg cols = h5f["cols"][beg:end] + counts = h5f["counts"][beg:end] vali = (h5f["vali"][beg:end] < self.opt.vali_p).astype(np.bool) - offset = next_offset + gamma = gamma_h5f[offset:next_offset, :] # call cuda kernel train_loss, vali_loss = \ - self.obj.feed_data(cols, indptr, vali, self.opt.num_iters_in_e_step) + self.obj.feed_data(cols, indptr.astype(np.int32), + vali, counts, gamma, + epoch == 1 or self.opt.reuse_gamma, + self.opt.num_iters_in_e_step) + gamma_h5f[offset:next_offset, :] = gamma # accumulate loss train_loss_nume -= train_loss vali_loss_nume -= vali_loss - vali_cnt = np.count_nonzero(vali) - train_cnt = len(vali) - vali_cnt - train_loss_deno += train_cnt - vali_loss_deno += vali_cnt + train_loss_deno += np.sum(counts[~vali]) + vali_loss_deno += np.sum(counts[vali]) train_loss = train_loss_nume / (train_loss_deno + EPS) vali_loss = vali_loss_nume / (vali_loss_deno + EPS) # update progress bar pbar.update(end, values=[("train_loss", train_loss), ("vali_loss", vali_loss)]) + offset = next_offset + if end == size: break @@ -162,10 +185,27 @@ def _train_m_step(self): self.obj.push() - def save_model(self, model_path): - self.logger.info("save model path: %s", model_path) - h5f = h5py.File(model_path, "w") + def save_h5_model(self, filepath, chunk_size=10000): + self.logger.info("save h5 format model path to %s", filepath) + os.makedirs(os.path.dirname(filepath), exist_ok=True) + h5f = h5py.File(filepath, "w") h5f.create_dataset("alpha", data=self.alpha) h5f.create_dataset("beta", data=self.beta) - h5f.create_dataset("keys", data=np.array(self.words)) + h5f.create_dataset("keys", data=np.array([word.encode("utf") + for word in self.words])) + gamma = h5f.create_dataset("gamma", dtype=np.float32, + shape=(self.num_docs, self.opt.num_topics)) + h5f_gamma = h5py.File(self.opt.gamma_path, "r") + for offset in range(0, self.num_docs, chunk_size): + next_offset = min(self.num_docs, offset + chunk_size) + gamma[offset:next_offset, :] = h5f_gamma["gamma"][offset:next_offset, :] + h5f_gamma.close() h5f.close() + + def remove_tmp(self): + if not self.opt.remove_tmp: + return + for tmp_dir in self.tmp_dirs: + if os.path.exists(tmp_dir): + self.logger.info("remove %s", tmp_dir) + shutil.rmtree(tmp_dir) diff --git a/cusim/cuw2v/bindings.cc b/cusim/cuw2v/bindings.cc index 3ca45d6..4c2237f 100644 --- a/cusim/cuw2v/bindings.cc +++ b/cusim/cuw2v/bindings.cc @@ -13,6 +13,7 @@ namespace py = pybind11; typedef py::array_t float_array; +typedef py::array_t double_array; typedef py::array_t int_array; class CuW2VBind { @@ -37,14 +38,14 @@ class CuW2VBind { return obj_.LoadModel(_emb_in.mutable_data(0), _emb_out.mutable_data(0)); } - void BuildRandomTable(py::object& word_count, int table_size, int num_threads) { - float_array _word_count(word_count); + void BuildRandomTable(py::object& word_count, int table_size) { + double_array _word_count(word_count); auto wc_buffer = _word_count.request(); if (wc_buffer.ndim != 1) { throw std::runtime_error("invalid word count"); } int num_words = wc_buffer.shape[0]; - obj_.BuildRandomTable(_word_count.data(0), num_words, table_size, num_threads); + obj_.BuildRandomTable(_word_count.data(0), num_words, table_size); } void BuildHuffmanTree(py::object& word_count) { @@ -74,10 +75,6 @@ class CuW2VBind { obj_.Pull(); } - int GetBlockCnt() { - return obj_.GetBlockCnt(); - } - private: cusim::CuW2V obj_; }; @@ -94,10 +91,9 @@ PYBIND11_PLUGIN(cuw2v_bind) { py::arg("cols"), py::arg("indptr")) .def("pull", &CuW2VBind::Pull) .def("build_random_table", &CuW2VBind::BuildRandomTable, - py::arg("word_count"), py::arg("table_size"), py::arg("num_threads")) + py::arg("word_count"), py::arg("table_size")) .def("build_huffman_tree", &CuW2VBind::BuildHuffmanTree, py::arg("word_count")) - .def("get_block_cnt", &CuW2VBind::GetBlockCnt) .def("__repr__", [](const CuW2VBind &a) { return ""; diff --git a/cusim/cuw2v/pycuw2v.py b/cusim/cuw2v/pycuw2v.py index f2bd265..7db96ff 100644 --- a/cusim/cuw2v/pycuw2v.py +++ b/cusim/cuw2v/pycuw2v.py @@ -9,6 +9,8 @@ from os.path import join as pjoin import json +import atexit +import shutil import tempfile import h5py @@ -41,13 +43,16 @@ def __init__(self, opt=None): self.words, self.word_count, self.num_words, self.num_docs = \ None, None, None, None self.emb_in, self.emb_out = None, None + self.tmp_dirs = [] + atexit.register(self.remove_tmp) def preprocess_data(self): if self.opt.skip_preprocess: return - iou = IoUtils() + iou = IoUtils(aux.proto_to_dict(self.opt.io)) if not self.opt.processed_data_dir: self.opt.processed_data_dir = tempfile.TemporaryDirectory().name + self.tmp_dirs.append(self.opt.processed_data_dir) iou.convert_stream_to_h5(self.opt.data_path, self.opt.word_min_count, self.opt.processed_data_dir) @@ -58,11 +63,10 @@ def init_model(self): count_path = pjoin(data_dir, "count.txt") self.logger.info("load key, count from %s, %s", keys_path, count_path) with open(keys_path, "rb") as fin: - self.words = [line.strip() for line in fin] + self.words = [line.strip().decode("utf8") for line in fin] with open(count_path, "rb") as fin: - self.word_count = np.array([float(line.strip()) for line in fin], - dtype=np.float32) - self.word_count = np.power(self.word_count, self.opt.count_power) + self.word_count = np.array([int(line.strip()) for line in fin], + dtype=np.int64) self.num_words = len(self.words) assert len(self.words) == len(self.word_count) @@ -74,22 +78,29 @@ def init_model(self): self.logger.info("number of words: %d, docs: %d", self.num_words, self.num_docs) + # normalize word count + word_count = np.power(self.word_count, self.opt.count_power, + dtype=np.float64) + word_count /= np.sum(word_count) if self.opt.neg: - self.obj.build_random_table( \ - self.word_count, self.opt.random_size, self.opt.num_threads) + self.obj.build_random_table(word_count, self.opt.random_size) else: - self.obj.build_huffman_tree(self.word_count) + self.obj.build_huffman_tree(word_count.astype(np.float32)) # random initialize alpha and beta np.random.seed(self.opt.seed) - self.emb_in = np.random.normal( \ + scale = 1 / np.sqrt(self.opt.num_dims) + self.emb_in = np.random.normal(loc=0, scale=scale, \ size=(self.num_words, self.opt.num_dims)).astype(np.float32) out_words = self.num_words if self.opt.neg else self.num_words - 1 - self.emb_out = np.random.uniform( \ + self.emb_out = np.random.normal(loc=0, scale=scale, \ size=(out_words, self.opt.num_dims)).astype(np.float32) self.logger.info("emb_in %s, emb_out %s initialized", self.emb_in.shape, self.emb_out.shape) + if self.opt.pretrained_model.filename: + self.load_word2vec_format(**aux.proto_to_dict(self.opt.pretrained_model)) + # push it to gpu self.obj.load_model(self.emb_in, self.emb_out) @@ -121,7 +132,7 @@ def _train_epoch(self, h5f): # call cuda kernel _loss_nume, _loss_deno = \ - self.obj.feed_data(cols, indptr) + self.obj.feed_data(cols, indptr.astype(np.int32)) # accumulate loss loss_nume += _loss_nume @@ -132,3 +143,72 @@ def _train_epoch(self, h5f): pbar.update(end, values=[("loss", loss)]) if end == size: break + + def save_h5_model(self, filename): + self.logger.info("save h5 format model to %s", filename) + os.makedirs(os.path.dirname(filename), exist_ok=True) + h5f = h5py.File(filename, "w") + h5f.create_dataset("emb_in", data=self.emb_in) + h5f.create_dataset("emb_out", data=self.emb_out) + h5f.create_dataset("keys", data=np.array([word.encode("utf") + for word in self.words])) + h5f.close() + + def save_word2vec_format(self, filename, binary=False, prefix=""): + self.logger.info("save word2vec format model to %s, " + "binary: %s, prefix: '%s'", filename, binary, prefix) + # save model compatible with gensim and original w2v code by Google + with open(filename, "wb") as fout: + fout.write(f"{self.num_words} {self.opt.num_dims}\n".encode("utf8")) + for idx, word in enumerate(self.words): + vec = self.emb_in[idx] + if binary: + fout.write(f"{prefix}{word} ".encode("utf8") + vec.tobytes()) + else: + fout.write(f"{prefix}{word} " + f"{' '.join(repr(val) for val in vec)}\n".encode("utf8")) + + def load_word2vec_format(self, filename, binary=False, + symmetry=False, no_header=False): + self.logger.info("load pretrained model from %s", filename) + # copy pretrained model to emb_out as well only if + # we use negative sampling, NOT hierarchical softmax + assert not symmetry or self.opt.neg, "no symmetry in hierarchical softmax" + + # read variable + vector_dict = {} + with open(filename, "rb") as fin: + if not no_header: + fin.readline() # throw one line + for line in fin: + if binary: + key, vec = line.split() + vector_dict[key] = np.fromstring(vec, dtype=np.float32) + else: + line_vec = line.strip().split() + key = line_vec[0].decode("utf8") + vec = np.array([float(val) for val in line_vec[1:]], + dtype=np.float32) + vector_dict[key] = vec + + # copy to variable + loaded_cnt = 0 + word_idmap = {word: idx for idx, word in enumerate(self.words)} + for key, vec in vector_dict.items(): + assert len(vec) == self.opt.num_dims + if key not in word_idmap: + continue + idx = word_idmap[key] + loaded_cnt += 1 + self.emb_in[idx, :] = vec + if symmetry: + self.emb_out[idx, :] = vec + self.logger.info("loaded count: %d", loaded_cnt) + + def remove_tmp(self): + if not self.opt.remove_tmp: + return + for tmp_dir in self.tmp_dirs: + if os.path.exists(tmp_dir): + self.logger.info("remove %s", tmp_dir) + shutil.rmtree(tmp_dir) diff --git a/cusim/ioutils/bindings.cc b/cusim/ioutils/bindings.cc index 06204f8..3fae0b8 100644 --- a/cusim/ioutils/bindings.cc +++ b/cusim/ioutils/bindings.cc @@ -14,6 +14,7 @@ namespace py = pybind11; typedef py::array_t float_array; typedef py::array_t int_array; +typedef py::array_t int64_array; class IoUtilsBind { public: @@ -23,7 +24,7 @@ class IoUtilsBind { return obj_.Init(opt_path); } - int LoadStreamFile(std::string filepath) { + int64_t LoadStreamFile(std::string filepath) { return obj_.LoadStreamFile(filepath); } @@ -46,6 +47,26 @@ class IoUtilsBind { obj_.GetToken(_rows.mutable_data(0), _cols.mutable_data(0), _indptr.mutable_data(0)); } + std::tuple ReadBagOfWordsHeader(std::string filepath) { + return obj_.ReadBagOfWordsHeader(filepath); + } + + void ReadBagOfWordsContent(py::object& rows, py::object& cols, + py::object counts) { + int64_array _rows(rows); + int_array _cols(cols); + float_array _counts(counts); + auto rows_buffer = _rows.request(); + auto cols_buffer = _cols.request(); + auto counts_buffer = _counts.request(); + int num_lines = rows_buffer.shape[0]; + if (cols_buffer.shape[0] != num_lines or counts_buffer.shape[0] != num_lines) { + throw std::runtime_error("invalid shape"); + } + obj_.ReadBagOfWordsContent(_rows.mutable_data(0), + _cols.mutable_data(0), _counts.mutable_data(0), num_lines); + } + private: cusim::IoUtils obj_; }; @@ -65,6 +86,10 @@ PYBIND11_PLUGIN(ioutils_bind) { py::arg("min_count"), py::arg("keys_path"), py::arg("count_path")) .def("get_token", &IoUtilsBind::GetToken, py::arg("indices"), py::arg("indptr"), py::arg("offset")) + .def("read_bag_of_words_header", &IoUtilsBind::ReadBagOfWordsHeader, + py::arg("filepath")) + .def("read_bag_of_words_content", &IoUtilsBind::ReadBagOfWordsContent, + py::arg("rows"), py::arg("cols"), py::arg("counts")) .def("__repr__", [](const IoUtilsBind &a) { return ""; diff --git a/cusim/ioutils/pyioutils.py b/cusim/ioutils/pyioutils.py index 71b52fd..dd8f945 100644 --- a/cusim/ioutils/pyioutils.py +++ b/cusim/ioutils/pyioutils.py @@ -64,7 +64,7 @@ def convert_stream_to_h5(self, filepath, min_count, out_dir, processed = 0 h5f = h5py.File(token_path, "w") rows = h5f.create_dataset("rows", shape=(chunk_indices,), - maxshape=(None,), dtype=np.int32, + maxshape=(None,), dtype=np.int64, chunks=(chunk_indices,)) cols = h5f.create_dataset("cols", shape=(chunk_indices,), maxshape=(None,), dtype=np.int32, @@ -73,7 +73,7 @@ def convert_stream_to_h5(self, filepath, min_count, out_dir, maxshape=(None,), dtype=np.float32, chunks=(chunk_indices,)) indptr = h5f.create_dataset("indptr", shape=(full_num_lines + 1,), - dtype=np.int32, chunks=True) + dtype=np.int64, chunks=True) processed, offset = 1, 0 indptr[0] = 0 while True: @@ -84,16 +84,78 @@ def convert_stream_to_h5(self, filepath, min_count, out_dir, _indptr = np.empty(shape=(read_lines,), dtype=np.int32) self.obj.get_token(_rows, _cols, _indptr) rows.resize((offset + data_size,)) - rows[offset:offset + data_size] = _rows + (processed - 1) + rows[offset:offset + data_size] = \ + _rows.astype(np.int64) + (processed - 1) cols.resize((offset + data_size,)) cols[offset:offset + data_size] = _cols vali.resize((offset + data_size,)) vali[offset:offset + data_size] = \ np.random.uniform(size=(data_size,)).astype(np.float32) - indptr[processed:processed + read_lines] = _indptr + offset + indptr[processed:processed + read_lines] = \ + _indptr.astype(np.int64) + offset offset += data_size processed += read_lines pbar.update(processed - 1) if processed == full_num_lines + 1: break h5f.close() + + def convert_bow_to_h5(self, filepath, h5_path): + self.logger.info("convert bow %s to h5 %s", filepath, h5_path) + num_docs, num_words, num_lines = \ + self.obj.read_bag_of_words_header(filepath) + self.logger.info("number of docs: %d, words: %d, nnz: %d", + num_docs, num_words, num_lines) + h5f = h5py.File(h5_path, "w") + rows = h5f.create_dataset("rows", dtype=np.int64, + shape=(num_lines,), chunks=True) + cols = h5f.create_dataset("cols", dtype=np.int32, + shape=(num_lines,), chunks=True) + counts = h5f.create_dataset("counts", dtype=np.float32, + shape=(num_lines,), chunks=True) + vali = h5f.create_dataset("vali", dtype=np.float32, + shape=(num_lines,), chunks=True) + indptr = h5f.create_dataset("indptr", dtype=np.int64, + shape=(num_docs + 1,), chunks=True) + indptr[0] = 0 + processed, recent_row, indptr_offset = 0, 0, 1 + pbar = aux.Progbar(num_lines, unit_name="line") + while processed < num_lines: + # get chunk size + read_lines = min(num_lines - processed, self.opt.chunk_lines) + + # copy rows, cols, counts to h5 + _rows = np.empty((read_lines,), dtype=np.int64) + _cols = np.empty((read_lines,), dtype=np.int32) + _counts = np.empty((read_lines,), dtype=np.float32) + self.obj.read_bag_of_words_content(_rows, _cols, _counts) + rows[processed:processed + read_lines] = _rows + cols[processed:processed + read_lines] = _cols + counts[processed:processed + read_lines] = _counts + vali[processed:processed + read_lines] = \ + np.random.uniform(size=(read_lines,)).astype(np.float32) + + # compute indptr + prev_rows = np.zeros((read_lines,), dtype=np.int64) + prev_rows[1:] = _rows[:-1] + prev_rows[0] = recent_row + diff = _rows - prev_rows + indices = np.where(diff > 0)[0] + _indptr = [] + for idx in indices: + _indptr += ([processed + idx] * diff[idx]) + if _indptr: + indptr[indptr_offset:indptr_offset + len(_indptr)] = \ + np.array(_indptr, dtype=np.int64) + indptr_offset += len(_indptr) + + # udpate processed + processed += read_lines + pbar.update(processed) + recent_row = _rows[-1] + + # finalize indptr + _indptr = [num_lines] * (num_docs + 1 - indptr_offset) + indptr[indptr_offset:num_docs + 1] = np.array(_indptr, dtype=np.int64) + + h5f.close() diff --git a/cusim/proto/config.proto b/cusim/proto/config.proto index f468bb0..794c51f 100644 --- a/cusim/proto/config.proto +++ b/cusim/proto/config.proto @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Jisang Yoon +// Copyright (c) 2021 Jisang Yoon // All rights reserved. // // This source code is licensed under the Apache 2.0 license found in the @@ -6,59 +6,154 @@ syntax = "proto2"; + +// option for data preprocessing message IoUtilsConfigProto { + // logging levels in python and C++ optional int32 py_log_level = 1 [default = 2]; optional int32 c_log_level = 2 [default = 2]; + + // number of chunk lines to preprocess (txt => hdf5 format) data optional int32 chunk_lines = 3 [default = 100000]; + + // number of concurrent threads in data preprocessing optional int32 num_threads = 4 [default = 4]; + + // convert charater to lower case if true + optional bool lower = 5 [default = true]; } -message CuLDAConfigProto { - required string data_path = 7; +// option for LDA model +message CuLDAConfigProto { + // logging levels in python and C++ optional int32 py_log_level = 1 [default = 2]; optional int32 c_log_level = 2 [default = 2]; + // raw data path (format from https://archive.ics.uci.edu/ml/datasets/bag+of+words) + optional string data_path = 7; + + // preprocessed data path (hdf5 format) + // if empty, make temporary directory + optional string processed_data_path = 6; + + // vocabulary path + required string keys_path = 16; + + // skip preprocess (there should be already preprocessed hdf5 format) if true + optional bool skip_preprocess = 8; + + // path to store gamma in E step + // if empty, make temporary directory + optional string gamma_path = 17; + + // reuse gamma from previous epoch if true + // if false, initiate gamma as Figure 6 in https://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf + optional bool reuse_gamma = 18; + + // number of topics optional int32 num_topics = 3 [default = 10]; + + // block dimension in CUDA + // should be multiple of WARP_SIZE (=32) optional int32 block_dim = 4 [default = 32]; - optional int32 hyper_threads = 5 [default = 10]; - optional string processed_data_dir = 6; - optional bool skip_preprocess = 8; - optional int32 word_min_count = 9 [default = 5]; - optional int32 batch_size = 10 [default = 100000]; + + // set the number blocks as num_blocks * block_dim = physical_cores_in_GPU * hyper_threads + optional int32 hyper_threads = 5 [default = 100]; + + // batch size in training + optional int32 batch_size = 10 [default = 1000000]; + + // number of epochs in training optional int32 epochs = 11 [default = 10]; + + // number of iterations in each E step optional int32 num_iters_in_e_step = 12 [default = 5]; + + // validation ratio, should be between 0 and 1 optional double vali_p = 13 [default = 0.2]; + + // random seed optional int32 seed = 14 [default = 777]; + + // remove all tempory directorys generated by package when program finnished if true + optional bool remove_tmp = 19 [default = true]; + + optional IoUtilsConfigProto io = 15; } -message CuW2VConfigProto { - required string data_path = 7; +// options for loading pretrained w2v model +// can load w2v model file generated by gensim or original w2v code by Google +message W2VPretrainedModel { + optional string filename = 1; + optional bool no_header = 2; + optional bool binary = 3; + optional bool symmetry = 4; +} + +// option for training Word2Vec model +message CuW2VConfigProto { + // logging levels in python and C++ optional int32 py_log_level = 1 [default = 2]; optional int32 c_log_level = 2 [default = 2]; - optional int32 num_dims = 3 [default = 50]; - optional int32 block_dim = 4 [default = 32]; - optional int32 hyper_threads = 5 [default = 10]; + // raw data path (stream txt format) + optional string data_path = 7; + + // path to save preprocessed data (hdf5 format) optional string processed_data_dir = 6; + + // skip data preprocessing (therefore, there should be + // already preprocessed hdf5 format file) if true optional bool skip_preprocess = 8; + + // number of embedding dimensions + optional int32 num_dims = 3 [default = 50]; + + // block_dim in CUDA + optional int32 block_dim = 4 [default = 32]; + + // set number of blocks as num_blocks * block_dim = physical_cores_in_GPU * hyper_threads + optional int32 hyper_threads = 5 [default = 100]; + + // generate vocabulary with words appreared in corpus at least word_min_count times optional int32 word_min_count = 9 [default = 5]; - optional int32 batch_size = 10 [default = 100000]; + + // batch size and number of epochs in training + optional int32 batch_size = 10 [default = 1000000]; optional int32 epochs = 11 [default = 10]; // seed fields optional int32 seed = 14 [default = 777]; - optional int32 table_seed = 15 [default = 777]; - optional int32 cuda_seed = 16 [default = 777]; - optional int32 random_size = 12 [default = 1000000]; + // random table size in negative sampling + optional int32 random_size = 12 [default = 100000000]; + + // number of negative samples + // if zero, it uses hierarchical softmax optional int32 neg = 17 [default = 10]; - // as recommended in w2v paper + + // weight in negative sampling will be word_count ** count_power for each word + // default value 0.75 is recommended in w2v paper optional double count_power = 18 [default = 0.75]; + + // if true, train skip gram model, else train cbow model optional bool skip_gram = 19 [default = true]; - optional bool use_mean = 20 [default = true]; + + // if true, use average context vector in cbow model + // else use summation of context vectors + optional bool cbow_mean = 20 [default = true]; + + // learning rate optional double lr = 21 [default = 0.001]; + + // window size in both skip gram and cbow model optional int32 window_size = 22 [default = 5]; - optional int32 num_threads = 23 [default = 4]; + + // remove all tempory directorys generated by package when program finnished if true + optional bool remove_tmp = 26 [default = true]; + + optional IoUtilsConfigProto io = 24; + optional W2VPretrainedModel pretrained_model = 25; } diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..3386394 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,21 @@ +### How to run example code + +0. install requirements + +```shell +pip install -r requirements.txt +``` + +1. first, it is good to know about python-fire in https://github.com/google/python-fire, if you haven't heard yet. + +2. run w2v experiments on various setting (e.g. skip gram with hierarchical softmax) + +```shell +python example_w2v.py run_experiments --sg0=True --hs0=True +``` + +3. run lda experiments + +```shell +python example_lda.py run_experiments +``` diff --git a/examples/cusim.topics.txt b/examples/cusim.topics.txt new file mode 100644 index 0000000..32bc19a --- /dev/null +++ b/examples/cusim.topics.txt @@ -0,0 +1,650 @@ +================================================== +topic 1 +-------------------------------------------------- +rank 1. car: 0.02677285298705101 +rank 2. vehicle: 0.006062767934054136 +rank 3. wheel: 0.005854051560163498 +rank 4. door: 0.0056894212029874325 +rank 5. vehicles: 0.005506897810846567 +rank 6. model: 0.005505426321178675 +rank 7. seat: 0.00544615276157856 +rank 8. zzz_ford: 0.004928195849061012 +rank 9. truck: 0.00481862248852849 +rank 10. front: 0.004714458715170622 +================================================== +topic 2 +-------------------------------------------------- +rank 1. priest: 0.020068475976586342 +rank 2. church: 0.018575558438897133 +rank 3. abuse: 0.014300045557320118 +rank 4. sexual: 0.012478752993047237 +rank 5. information: 0.011768681928515434 +rank 6. bishop: 0.010295535437762737 +rank 7. privacy: 0.00979470182210207 +rank 8. enditalic: 0.007282644044607878 +rank 9. zzz_government: 0.007169242016971111 +rank 10. beginitalic: 0.007022677455097437 +================================================== +topic 3 +-------------------------------------------------- +rank 1. scientist: 0.012924057431519032 +rank 2. plant: 0.010201558470726013 +rank 3. animal: 0.00955168902873993 +rank 4. human: 0.006574922241270542 +rank 5. water: 0.006187545135617256 +rank 6. species: 0.005247119814157486 +rank 7. science: 0.003869544016197324 +rank 8. research: 0.0037548812106251717 +rank 9. chemical: 0.0036675187293440104 +rank 10. researcher: 0.003376629902049899 +================================================== +topic 4 +-------------------------------------------------- +rank 1. room: 0.0097695617005229 +rank 2. building: 0.009012440219521523 +rank 3. hotel: 0.007701032795011997 +rank 4. town: 0.007017012685537338 +rank 5. visitor: 0.005999790038913488 +rank 6. park: 0.004900350235402584 +rank 7. water: 0.00483303889632225 +rank 8. restaurant: 0.004804808646440506 +rank 9. tour: 0.004689469002187252 +rank 10. house: 0.004657984711229801 +================================================== +topic 5 +-------------------------------------------------- +rank 1. executive: 0.010459953919053078 +rank 2. president: 0.009247427806258202 +rank 3. chief: 0.007263457868248224 +rank 4. deal: 0.006793353706598282 +rank 5. media: 0.006781542673707008 +rank 6. zzz_u_s: 0.006448546424508095 +rank 7. question: 0.006310692522674799 +rank 8. public: 0.0058334809727966785 +rank 9. client: 0.0057419463992118835 +rank 10. com: 0.005699603818356991 +================================================== +topic 6 +-------------------------------------------------- +rank 1. official: 0.010707475244998932 +rank 2. zzz_new_york: 0.010631673038005829 +rank 3. building: 0.006501882337033749 +rank 4. found: 0.005982889328151941 +rank 5. worker: 0.005903987213969231 +rank 6. officer: 0.00562720000743866 +rank 7. hour: 0.005579715128988028 +rank 8. security: 0.0047438195906579494 +rank 9. plane: 0.004530096426606178 +rank 10. attack: 0.0045171682722866535 +================================================== +topic 7 +-------------------------------------------------- +rank 1. gold: 0.008730198256671429 +rank 2. hour: 0.00799502432346344 +rank 3. floor: 0.007750170771032572 +rank 4. medal: 0.005479565821588039 +rank 5. rider: 0.005427593365311623 +rank 6. ice: 0.005187307950109243 +rank 7. event: 0.004164813086390495 +rank 8. silver: 0.00394535344094038 +rank 9. hand: 0.003944935742765665 +rank 10. moment: 0.003745123278349638 +================================================== +topic 8 +-------------------------------------------------- +rank 1. customer: 0.020594391971826553 +rank 2. product: 0.01662433333694935 +rank 3. weather: 0.010188293643295765 +rank 4. stores: 0.009588934481143951 +rank 5. marketing: 0.007573566399514675 +rank 6. consumer: 0.007247460074722767 +rank 7. need: 0.00708211213350296 +rank 8. business: 0.006656122859567404 +rank 9. problem: 0.006193865556269884 +rank 10. sales: 0.00576401362195611 +================================================== +topic 9 +-------------------------------------------------- +rank 1. zzz_enron: 0.0346914604306221 +rank 2. anthrax: 0.018805652856826782 +rank 3. firm: 0.017304280772805214 +rank 4. employees: 0.013712462969124317 +rank 5. accounting: 0.011462894268333912 +rank 6. company: 0.010516936890780926 +rank 7. letter: 0.009165323339402676 +rank 8. zzz_arthur_andersen: 0.008399050682783127 +rank 9. financial: 0.006972334813326597 +rank 10. official: 0.006733026821166277 +================================================== +topic 10 +-------------------------------------------------- +rank 1. game: 0.02054956741631031 +rank 2. yard: 0.01949656568467617 +rank 3. season: 0.018450269475579262 +rank 4. play: 0.01595749706029892 +rank 5. team: 0.014850640669465065 +rank 6. coach: 0.012072306126356125 +rank 7. football: 0.010657819919288158 +rank 8. player: 0.010432523675262928 +rank 9. zzz_nfl: 0.009206585586071014 +rank 10. defensive: 0.008976943790912628 +================================================== +topic 11 +-------------------------------------------------- +rank 1. con: 0.020736297592520714 +rank 2. una: 0.013567320071160793 +rank 3. las: 0.01041751354932785 +rank 4. mas: 0.010156860575079918 +rank 5. dice: 0.009438637644052505 +rank 6. por: 0.00928747933357954 +rank 7. como: 0.008855272084474564 +rank 8. los: 0.008734958246350288 +rank 9. zzz_argentina: 0.0077548702247440815 +rank 10. anos: 0.0052759042009711266 +================================================== +topic 12 +-------------------------------------------------- +rank 1. zzz_afghanistan: 0.02263963408768177 +rank 2. zzz_taliban: 0.019689183682203293 +rank 3. military: 0.014852428808808327 +rank 4. bin: 0.014605461619794369 +rank 5. laden: 0.014458988793194294 +rank 6. war: 0.01199477817863226 +rank 7. zzz_pakistan: 0.01184108667075634 +rank 8. terrorist: 0.011557201854884624 +rank 9. zzz_u_s: 0.01051971036940813 +rank 10. attack: 0.009562982246279716 +================================================== +topic 13 +-------------------------------------------------- +rank 1. court: 0.02521500550210476 +rank 2. case: 0.023994332179427147 +rank 3. lawyer: 0.019229630008339882 +rank 4. trial: 0.012606462463736534 +rank 5. attorney: 0.011963741853833199 +rank 6. law: 0.010776755400002003 +rank 7. prosecutor: 0.010139403864741325 +rank 8. judge: 0.010069739073514938 +rank 9. federal: 0.01000827457755804 +rank 10. charges: 0.009131026454269886 +================================================== +topic 14 +-------------------------------------------------- +rank 1. children: 0.024856505915522575 +rank 2. family: 0.023518990725278854 +rank 3. mother: 0.021585773676633835 +rank 4. parent: 0.018566781654953957 +rank 5. father: 0.017965450882911682 +rank 6. child: 0.016640648245811462 +rank 7. son: 0.014798246324062347 +rank 8. boy: 0.013485484756529331 +rank 9. girl: 0.01209142617881298 +rank 10. daughter: 0.011482957750558853 +================================================== +topic 15 +-------------------------------------------------- +rank 1. home: 0.00700838677585125 +rank 2. run: 0.006053796038031578 +rank 3. right: 0.005981859751045704 +rank 4. left: 0.005203519947826862 +rank 5. part: 0.005086812656372786 +rank 6. night: 0.004532037302851677 +rank 7. put: 0.004220300819724798 +rank 8. took: 0.003913923632353544 +rank 9. called: 0.003663261653855443 +rank 10. early: 0.0034683081321418285 +================================================== +topic 16 +-------------------------------------------------- +rank 1. computer: 0.05519622564315796 +rank 2. system: 0.038603898137807846 +rank 3. zzz_microsoft: 0.0243679191917181 +rank 4. software: 0.02125958725810051 +rank 5. technology: 0.016031846404075623 +rank 6. window: 0.015655480325222015 +rank 7. mail: 0.01430702954530716 +rank 8. user: 0.011626251973211765 +rank 9. information: 0.010091814212501049 +rank 10. operating: 0.00756523571908474 +================================================== +topic 17 +-------------------------------------------------- +rank 1. law: 0.02723626047372818 +rank 2. right: 0.013789551332592964 +rank 3. political: 0.012496591545641422 +rank 4. government: 0.012413491494953632 +rank 5. religious: 0.01058514229953289 +rank 6. immigrant: 0.010227411054074764 +rank 7. power: 0.008888700045645237 +rank 8. ruling: 0.006956734228879213 +rank 9. court: 0.006303566973656416 +rank 10. opposition: 0.006150286644697189 +================================================== +topic 18 +-------------------------------------------------- +rank 1. driver: 0.026438318192958832 +rank 2. car: 0.021013904362916946 +rank 3. race: 0.02072136662900448 +rank 4. racing: 0.013081525452435017 +rank 5. airline: 0.012061871588230133 +rank 6. flight: 0.009761194698512554 +rank 7. track: 0.008779071271419525 +rank 8. races: 0.007440405432134867 +rank 9. airlines: 0.00735550606623292 +rank 10. carrier: 0.0064879427663981915 +================================================== +topic 19 +-------------------------------------------------- +rank 1. zzz_bush: 0.018247155472636223 +rank 2. official: 0.015417532064020634 +rank 3. zzz_united_states: 0.01530768908560276 +rank 4. administration: 0.013708231039345264 +rank 5. leader: 0.010346359573304653 +rank 6. countries: 0.009353779256343842 +rank 7. zzz_u_s: 0.009245852008461952 +rank 8. government: 0.009168907068669796 +rank 9. zzz_iraq: 0.009057393297553062 +rank 10. military: 0.008723369799554348 +================================================== +topic 20 +-------------------------------------------------- +rank 1. percent: 0.044513553380966187 +rank 2. stock: 0.023978371173143387 +rank 3. market: 0.022495266050100327 +rank 4. fund: 0.013825907371938229 +rank 5. billion: 0.012179437093436718 +rank 6. quarter: 0.010966183617711067 +rank 7. investor: 0.01015525683760643 +rank 8. investment: 0.009771433658897877 +rank 9. million: 0.009703823365271091 +rank 10. analyst: 0.00947034452110529 +================================================== +topic 21 +-------------------------------------------------- +rank 1. book: 0.02586548589169979 +rank 2. art: 0.009416724555194378 +rank 3. artist: 0.007856832817196846 +rank 4. collection: 0.007611869368702173 +rank 5. painting: 0.0066984654404222965 +rank 6. fashion: 0.005222611129283905 +rank 7. century: 0.005118153523653746 +rank 8. writer: 0.004741827957332134 +rank 9. designer: 0.004720605909824371 +rank 10. author: 0.004426541272550821 +================================================== +topic 22 +-------------------------------------------------- +rank 1. music: 0.03731286898255348 +rank 2. song: 0.023323602974414825 +rank 3. cell: 0.015249419026076794 +rank 4. album: 0.011770840734243393 +rank 5. band: 0.011705778539180756 +rank 6. musical: 0.008127620443701744 +rank 7. singer: 0.006815788336098194 +rank 8. concert: 0.006784858647733927 +rank 9. jazz: 0.006698825396597385 +rank 10. sound: 0.006471691187471151 +================================================== +topic 23 +-------------------------------------------------- +rank 1. web: 0.04922621324658394 +rank 2. site: 0.03805321082472801 +rank 3. www: 0.03708707541227341 +rank 4. com: 0.03255585581064224 +rank 5. online: 0.027454305440187454 +rank 6. zzz_internet: 0.019746430218219757 +rank 7. sites: 0.018789643421769142 +rank 8. information: 0.012109276838600636 +rank 9. mail: 0.010703440755605698 +rank 10. internet: 0.010465497151017189 +================================================== +topic 24 +-------------------------------------------------- +rank 1. cup: 0.013092475943267345 +rank 2. food: 0.011349334381520748 +rank 3. minutes: 0.008257454261183739 +rank 4. add: 0.007631846237927675 +rank 5. tablespoon: 0.006674299016594887 +rank 6. oil: 0.006410549394786358 +rank 7. pepper: 0.005671842489391565 +rank 8. sugar: 0.005601006560027599 +rank 9. teaspoon: 0.005426750052720308 +rank 10. water: 0.005266525782644749 +================================================== +topic 25 +-------------------------------------------------- +rank 1. team: 0.03806942701339722 +rank 2. season: 0.0169094055891037 +rank 3. games: 0.014860374853014946 +rank 4. zzz_olympic: 0.013387414626777172 +rank 5. coach: 0.011522733606398106 +rank 6. zzz_miami: 0.00958396214991808 +rank 7. athletes: 0.009186827577650547 +rank 8. player: 0.00904573779553175 +rank 9. football: 0.008782983757555485 +rank 10. defense: 0.007022276986390352 +================================================== +topic 26 +-------------------------------------------------- +rank 1. zzz_united_states: 0.00901876762509346 +rank 2. zzz_american: 0.008501513861119747 +rank 3. american: 0.008129569701850414 +rank 4. country: 0.006140291225165129 +rank 5. government: 0.005337539594620466 +rank 6. group: 0.005324787925928831 +rank 7. german: 0.0052407230250537395 +rank 8. history: 0.004972025752067566 +rank 9. french: 0.0047142705880105495 +rank 10. family: 0.004625052213668823 +================================================== +topic 27 +-------------------------------------------------- +rank 1. women: 0.07863874733448029 +rank 2. gay: 0.020628171041607857 +rank 3. dog: 0.014878431335091591 +rank 4. magazine: 0.01347420085221529 +rank 5. woman: 0.012085708789527416 +rank 6. sex: 0.009264894761145115 +rank 7. female: 0.008259394206106663 +rank 8. cat: 0.006200404372066259 +rank 9. male: 0.0057057044468820095 +rank 10. lesbian: 0.0040387725457549095 +================================================== +topic 28 +-------------------------------------------------- +rank 1. digital: 0.011757075786590576 +rank 2. screen: 0.0080463457852602 +rank 3. wine: 0.007102092728018761 +rank 4. device: 0.006819858215749264 +rank 5. wines: 0.0068092974834144115 +rank 6. chip: 0.006679498124867678 +rank 7. computer: 0.006480266340076923 +rank 8. devices: 0.005909178406000137 +rank 9. electronic: 0.0056115672923624516 +rank 10. images: 0.004710317123681307 +================================================== +topic 29 +-------------------------------------------------- +rank 1. campaign: 0.03327873349189758 +rank 2. political: 0.014918365515768528 +rank 3. democratic: 0.014790846966207027 +rank 4. election: 0.014583878219127655 +rank 5. republican: 0.014538025483489037 +rank 6. voter: 0.01402147114276886 +rank 7. zzz_al_gore: 0.013029148802161217 +rank 8. zzz_party: 0.012214157730340958 +rank 9. zzz_republican: 0.011119640432298183 +rank 10. candidates: 0.010824044235050678 +================================================== +topic 30 +-------------------------------------------------- +rank 1. school: 0.03626062348484993 +rank 2. student: 0.021992284804582596 +rank 3. black: 0.015230956487357616 +rank 4. group: 0.013538197614252567 +rank 5. public: 0.010991621762514114 +rank 6. percent: 0.010974901728332043 +rank 7. zzz_texas: 0.008697726763784885 +rank 8. gun: 0.007661579176783562 +rank 9. member: 0.0075561245903372765 +rank 10. white: 0.007528342306613922 +================================================== +topic 31 +-------------------------------------------------- +rank 1. zzz_fbi: 0.025642145425081253 +rank 2. fish: 0.020048771053552628 +rank 3. bird: 0.013764469884335995 +rank 4. agent: 0.011454230174422264 +rank 5. irish: 0.009724821895360947 +rank 6. fishing: 0.00831819698214531 +rank 7. zzz_timothy_mcveigh: 0.006179510150104761 +rank 8. zzz_brazil: 0.006174848414957523 +rank 9. hijacker: 0.0060051921755075455 +rank 10. zzz_simon: 0.005628513637930155 +================================================== +topic 32 +-------------------------------------------------- +rank 1. company: 0.07715368270874023 +rank 2. companies: 0.033467356115579605 +rank 3. business: 0.019932780414819717 +rank 4. million: 0.01110815443098545 +rank 5. deal: 0.01099175214767456 +rank 6. executives: 0.010963932611048222 +rank 7. executive: 0.010428434237837791 +rank 8. market: 0.0098022585734725 +rank 9. stock: 0.009284550324082375 +rank 10. chief: 0.008711854927241802 +================================================== +topic 33 +-------------------------------------------------- +rank 1. consumer: 0.02195882610976696 +rank 2. percent: 0.020870916545391083 +rank 3. companies: 0.015635766088962555 +rank 4. industry: 0.015347079373896122 +rank 5. market: 0.014645704068243504 +rank 6. cost: 0.012568947859108448 +rank 7. customer: 0.012199653312563896 +rank 8. prices: 0.010143699124455452 +rank 9. high: 0.009660380892455578 +rank 10. worker: 0.006465692073106766 +================================================== +topic 34 +-------------------------------------------------- +rank 1. season: 0.021334033459424973 +rank 2. team: 0.016839321702718735 +rank 3. game: 0.014815553091466427 +rank 4. inning: 0.014347057789564133 +rank 5. player: 0.013774506747722626 +rank 6. yankees: 0.011174232698976994 +rank 7. run: 0.010817022994160652 +rank 8. baseball: 0.01055373065173626 +rank 9. games: 0.010321191512048244 +rank 10. hit: 0.010284436866641045 +================================================== +topic 35 +-------------------------------------------------- +rank 1. zzz_george_bush: 0.05796745792031288 +rank 2. zzz_al_gore: 0.04237228259444237 +rank 3. election: 0.022491727024316788 +rank 4. president: 0.020312432199716568 +rank 5. ballot: 0.019908472895622253 +rank 6. zzz_florida: 0.016183944419026375 +rank 7. presidential: 0.015332216396927834 +rank 8. votes: 0.01442129909992218 +rank 9. vote: 0.009808804839849472 +rank 10. zzz_bush: 0.00961968582123518 +================================================== +topic 36 +-------------------------------------------------- +rank 1. palestinian: 0.02687947452068329 +rank 2. zzz_israel: 0.023833250626921654 +rank 3. zzz_israeli: 0.013304143212735653 +rank 4. soldier: 0.010826818645000458 +rank 5. peace: 0.010164049454033375 +rank 6. zzz_yasser_arafat: 0.009658769704401493 +rank 7. israeli: 0.009265914559364319 +rank 8. war: 0.00923923496156931 +rank 9. israelis: 0.008119330741465092 +rank 10. military: 0.007811776362359524 +================================================== +topic 37 +-------------------------------------------------- +rank 1. death: 0.023664837703108788 +rank 2. prison: 0.016880618408322334 +rank 3. murder: 0.01633421890437603 +rank 4. book: 0.009351547807455063 +rank 5. killed: 0.009010221809148788 +rank 6. prisoner: 0.007692103274166584 +rank 7. killing: 0.007337935268878937 +rank 8. woman: 0.007256744429469109 +rank 9. victim: 0.007001840975135565 +rank 10. shooting: 0.006456068251281977 +================================================== +topic 38 +-------------------------------------------------- +rank 1. million: 0.01617966778576374 +rank 2. newspaper: 0.009461159817874432 +rank 3. show: 0.006403861101716757 +rank 4. program: 0.005598483607172966 +rank 5. network: 0.0053542195819318295 +rank 6. money: 0.00485030934214592 +rank 7. according: 0.004323051776736975 +rank 8. special: 0.0040418170392513275 +rank 9. help: 0.004037346225231886 +rank 10. past: 0.0039222449995577335 +================================================== +topic 39 +-------------------------------------------------- +rank 1. show: 0.022530050948262215 +rank 2. character: 0.009580017998814583 +rank 3. audience: 0.005444356705993414 +rank 4. television: 0.004325090907514095 +rank 5. series: 0.004303744062781334 +rank 6. look: 0.004119543824344873 +rank 7. love: 0.00407353974878788 +rank 8. film: 0.004058054182678461 +rank 9. find: 0.003848094493150711 +rank 10. young: 0.0036786773707717657 +================================================== +topic 40 +-------------------------------------------------- +rank 1. drug: 0.047516606748104095 +rank 2. government: 0.012602291069924831 +rank 3. zzz_aid: 0.01227615773677826 +rank 4. zzz_india: 0.010664834640920162 +rank 5. countries: 0.008103608153760433 +rank 6. million: 0.007103894371539354 +rank 7. food: 0.006576470099389553 +rank 8. farmer: 0.006402278784662485 +rank 9. country: 0.006317282561212778 +rank 10. zzz_united_states: 0.0062563237734138966 +================================================== +topic 41 +-------------------------------------------------- +rank 1. game: 0.026529431343078613 +rank 2. player: 0.022719431668519974 +rank 3. games: 0.0206462275236845 +rank 4. sport: 0.016915155574679375 +rank 5. fan: 0.012125855311751366 +rank 6. soccer: 0.011505456641316414 +rank 7. video: 0.010653939098119736 +rank 8. zzz_nbc: 0.009938360191881657 +rank 9. zzz_nba: 0.009428229182958603 +rank 10. team: 0.008263841271400452 +================================================== +topic 42 +-------------------------------------------------- +rank 1. tax: 0.04971655085682869 +rank 2. cut: 0.026394149288535118 +rank 3. economy: 0.0230980534106493 +rank 4. economic: 0.017415864393115044 +rank 5. zzz_mexico: 0.01618388667702675 +rank 6. government: 0.01595328189432621 +rank 7. taxes: 0.014780825935304165 +rank 8. spending: 0.01243556011468172 +rank 9. income: 0.012374772690236568 +rank 10. zzz_social_security: 0.010477164760231972 +================================================== +topic 43 +-------------------------------------------------- +rank 1. zzz_bush: 0.027270827442407608 +rank 2. bill: 0.024806691333651543 +rank 3. zzz_congress: 0.018335092812776566 +rank 4. zzz_white_house: 0.016858264803886414 +rank 5. federal: 0.01354345865547657 +rank 6. zzz_senate: 0.01329002995043993 +rank 7. plan: 0.012937983497977257 +rank 8. proposal: 0.010213974863290787 +rank 9. administration: 0.009349077008664608 +rank 10. health: 0.008263114839792252 +================================================== +topic 44 +-------------------------------------------------- +rank 1. point: 0.020692508667707443 +rank 2. team: 0.018113387748599052 +rank 3. game: 0.015103872865438461 +rank 4. season: 0.013727625831961632 +rank 5. play: 0.012306117452681065 +rank 6. goal: 0.012093267403542995 +rank 7. games: 0.011415580287575722 +rank 8. shot: 0.011306485161185265 +rank 9. king: 0.011238034814596176 +rank 10. player: 0.008728481829166412 +================================================== +topic 45 +-------------------------------------------------- +rank 1. player: 0.013769405893981457 +rank 2. point: 0.012727474793791771 +rank 3. win: 0.012649298645555973 +rank 4. play: 0.011700315400958061 +rank 5. round: 0.010591110214591026 +rank 6. season: 0.010317614302039146 +rank 7. shot: 0.01031588576734066 +rank 8. game: 0.00999273732304573 +rank 9. team: 0.009904314763844013 +rank 10. final: 0.009542282670736313 +================================================== +topic 46 +-------------------------------------------------- +rank 1. zzz_china: 0.015810564160346985 +rank 2. oil: 0.014123033732175827 +rank 3. power: 0.013019545003771782 +rank 4. zzz_russia: 0.012522333301603794 +rank 5. energy: 0.01063102949410677 +rank 6. plant: 0.010357524268329144 +rank 7. gas: 0.00931472983211279 +rank 8. nuclear: 0.008214462548494339 +rank 9. missile: 0.007829232141375542 +rank 10. environmental: 0.007554346229881048 +================================================== +topic 47 +-------------------------------------------------- +rank 1. com: 0.02512955479323864 +rank 2. zzz_laker: 0.015019885264337063 +rank 3. palm: 0.013598510064184666 +rank 4. daily: 0.013184287585318089 +rank 5. statesman: 0.013182769529521465 +rank 6. beach: 0.01314060389995575 +rank 7. question: 0.010342201218008995 +rank 8. zzz_eastern: 0.009052561596035957 +rank 9. information: 0.008214504458010197 +rank 10. austin: 0.007981293834745884 +================================================== +topic 48 +-------------------------------------------------- +rank 1. film: 0.034848302602767944 +rank 2. movie: 0.02526075392961502 +rank 3. actor: 0.013231894932687283 +rank 4. movies: 0.008959283120930195 +rank 5. zzz_hollywood: 0.008070441894233227 +rank 6. play: 0.007740044500678778 +rank 7. theater: 0.00727312033995986 +rank 8. director: 0.005834080744534731 +rank 9. character: 0.005199376493692398 +rank 10. zzz_oscar: 0.004690317437052727 +================================================== +topic 49 +-------------------------------------------------- +rank 1. patient: 0.02304932475090027 +rank 2. doctor: 0.01952706277370453 +rank 3. cancer: 0.011629555374383926 +rank 4. medical: 0.011445121839642525 +rank 5. disease: 0.011433145962655544 +rank 6. hospital: 0.009982189163565636 +rank 7. study: 0.008990893140435219 +rank 8. treatment: 0.007559608668088913 +rank 9. blood: 0.007204002235084772 +rank 10. test: 0.007000159937888384 +================================================== +topic 50 +-------------------------------------------------- +rank 1. million: 0.028744814917445183 +rank 2. contract: 0.016937075182795525 +rank 3. agent: 0.009414087980985641 +rank 4. manager: 0.007703984156250954 +rank 5. business: 0.006961227394640446 +rank 6. high: 0.005569536704570055 +rank 7. club: 0.005377542693167925 +rank 8. past: 0.005371585488319397 +rank 9. career: 0.005363883450627327 +rank 10. hand: 0.005337761249393225 diff --git a/examples/example1.py b/examples/example1.py deleted file mode 100644 index 1bea971..0000000 --- a/examples/example1.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) 2021 Jisang Yoon -# All rights reserved. -# -# This source code is licensed under the Apache 2.0 license found in the -# LICENSE file in the root directory of this source tree. - -# pylint: disable=no-name-in-module,logging-format-truncated -import os -import subprocess -import fire - -import h5py -import numpy as np -from gensim import downloader as api -from cusim import aux, IoUtils, CuLDA, CuW2V - -LOGGER = aux.get_logger() -DOWNLOAD_PATH = "./res" -# DATASET = "wiki-english-20171001" -DATASET = "quora-duplicate-questions" -DATA_PATH = f"./res/{DATASET}.stream.txt" -LDA_PATH = f"./res/{DATASET}-lda.h5" -PROCESSED_DATA_DIR = f"./res/{DATASET}-converted" -MIN_COUNT = 5 -TOPK = 10 - -def download(): - if os.path.exists(DATA_PATH): - LOGGER.info("%s already exists", DATA_PATH) - return - api.BASE_DIR = DOWNLOAD_PATH - filepath = api.load(DATASET, return_path=True) - LOGGER.info("filepath: %s", filepath) - cmd = ["gunzip", "-c", filepath, ">", DATA_PATH] - cmd = " ".join(cmd) - LOGGER.info("cmd: %s", cmd) - subprocess.call(cmd, shell=True) - -def run_io(): - download() - iou = IoUtils(opt={"chunk_lines": 10000, "num_threads": 8}) - iou.convert_stream_to_h5(DATA_PATH, 5, PROCESSED_DATA_DIR) - - -def run_lda(): - opt = { - "data_path": DATA_PATH, - "processed_data_dir": PROCESSED_DATA_DIR, - # "skip_preprocess":True, - } - lda = CuLDA(opt) - lda.train_model() - lda.save_model(LDA_PATH) - h5f = h5py.File(LDA_PATH, "r") - beta = h5f["beta"][:] - word_list = h5f["keys"][:] - num_topics = h5f["alpha"].shape[0] - for i in range(num_topics): - print("=" * 50) - print(f"topic {i + 1}") - words = np.argsort(-beta.T[i])[:10] - print("-" * 50) - for j in range(TOPK): - word = word_list[words[j]].decode("utf8") - prob = beta[words[j], i] - print(f"rank {j + 1}. word: {word}, prob: {prob}") - h5f.close() - -def run_w2v(): - opt = { - # "c_log_level": 3, - "data_path": DATA_PATH, - "processed_data_dir": PROCESSED_DATA_DIR, - # "skip_preprocess":True, - } - w2v = CuW2V(opt) - w2v.train_model() - -if __name__ == "__main__": - fire.Fire() diff --git a/examples/example_lda.py b/examples/example_lda.py new file mode 100644 index 0000000..a276643 --- /dev/null +++ b/examples/example_lda.py @@ -0,0 +1,189 @@ +# Copyright (c) 2021 Jisang Yoon +# All rights reserved. +# +# This source code is licensed under the Apache 2.0 license found in the +# LICENSE file in the root directory of this source tree. + +# pylint: disable=no-name-in-module,logging-format-truncated +# pylint: disable=too-few-public-methods +import os +from os.path import join as pjoin +import time +import pickle +import subprocess + +import tqdm +import fire +import wget +import h5py +import numpy as np +import pandas as pd + +# import gensim +from gensim.models.ldamulticore import LdaMulticore + +from cusim import aux, CuLDA + +LOGGER = aux.get_logger() +# DATASET = "nips" +DATASET = "nytimes" +DIR_PATH = "./res" +BASE_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/" \ + "bag-of-words/" + +def download(): + if not os.path.exists(DIR_PATH): + os.makedirs(DIR_PATH, exist_ok=True) + + if os.path.exists(pjoin(DIR_PATH, f"docword.{DATASET}.txt")): + LOGGER.info("path %s already exists", + pjoin(DIR_PATH, f"docword.{DATASET}.txt")) + return + + # download docword + filename = f"docword.{DATASET}.txt.gz" + out_path = pjoin(DIR_PATH, filename) + LOGGER.info("download %s to %s", BASE_URL + filename, out_path) + wget.download(BASE_URL + filename, out=out_path) + print() + + # decompress + cmd = ["gunzip", "-c", out_path, ">", + pjoin(DIR_PATH, f"docword.{DATASET}.txt")] + cmd = " ".join(cmd) + subprocess.call(cmd, shell=True) + os.remove(pjoin(DIR_PATH, filename)) + + # download vocab + filename = f"vocab.{DATASET}.txt" + out_path = pjoin(DIR_PATH, filename) + LOGGER.info("download %s to %s", BASE_URL + filename, out_path) + wget.download(BASE_URL + filename, out=out_path) + print() + +def run_cusim(): + download() + data_path = pjoin(DIR_PATH, f"docword.{DATASET}.txt") + keys_path = pjoin(DIR_PATH, f"vocab.{DATASET}.txt") + processed_data_path = pjoin(DIR_PATH, f"docword.{DATASET}.h5") + opt = { + "data_path": data_path, + "processed_data_path": processed_data_path, + "keys_path": keys_path, + "num_topics": 50, + "num_iters_in_e_step": 10, + "reuse_gamma": True, + # "skip_preprocess": os.path.exists(processed_data_path), + } + start = time.time() + lda = CuLDA(opt) + lda.train_model() + el0 = time.time() - start + LOGGER.info("elapsed for training LDA using cusim: %.4e sec", el0) + h5_model_path = pjoin(DIR_PATH, "cusim.lda.model.h5") + lda.save_h5_model(h5_model_path) + show_cusim_topics(h5_model_path) + return el0 + +def show_cusim_topics(h5_model_path, topk=10): + h5f = h5py.File(h5_model_path, "r") + beta = h5f["beta"][:, :].T + keys = h5f["keys"][:] + show_topics(beta, keys, topk, "cusim.topics.txt") + +def build_gensim_corpus(): + corpus_path = pjoin(DIR_PATH, f"docword.{DATASET}.pk") + if os.path.exists(corpus_path): + LOGGER.info("load corpus from %s", corpus_path) + with open(corpus_path, "rb") as fin: + ret = pickle.loads(fin.read()) + return ret + + # get corpus for gensim lda + data_path = pjoin(DIR_PATH, f"docword.{DATASET}.txt") + LOGGER.info("build corpus from %s", data_path) + docs, doc, curid = [], [], -1 + with open(data_path, "r") as fin: + for idx, line in tqdm.tqdm(enumerate(fin)): + if idx < 3: + continue + docid, wordid, count = line.strip().split() + # zero-base id + docid, wordid, count = int(docid) - 1, int(wordid) - 1, float(count) + if 0 <= curid < docid: + docs.append(doc) + doc = [] + doc.append((wordid, count)) + curid = docid + docs.append(doc) + LOGGER.info("save corpus to %s", corpus_path) + with open(corpus_path, "wb") as fout: + fout.write(pickle.dumps(docs, 2)) + return docs + +def run_gensim(): + docs = build_gensim_corpus() + keys_path = pjoin(DIR_PATH, f"vocab.{DATASET}.txt") + LOGGER.info("load vocab from %s", keys_path) + id2word = {} + with open(keys_path, "rb") as fin: + for idx, line in enumerate(fin): + id2word[idx] = line.strip() + + start = time.time() + lda = LdaMulticore(docs, num_topics=50, workers=None, + id2word=id2word, iterations=10) + el0 = time.time() - start + LOGGER.info("elapsed for training lda using gensim: %.4e sec", el0) + model_path = pjoin(DIR_PATH, "gensim.lda.model") + LOGGER.info("save gensim lda model to %s", model_path) + lda.save(model_path) + show_gensim_topics(model_path) + return el0 + +def show_gensim_topics(model_path=None, topk=10): + # load beta + model_path = model_path or pjoin(DIR_PATH, "gensim.lda.model") + LOGGER.info("load gensim lda model from %s", model_path) + lda = LdaMulticore.load(model_path) + beta = lda.state.get_lambda() + beta /= np.sum(beta, axis=1)[:, None] + + # load keys + keys_path = pjoin(DIR_PATH, f"vocab.{DATASET}.txt") + LOGGER.info("load vocab from %s", keys_path) + with open(keys_path, "rb") as fin: + keys = [line.strip() for line in fin] + show_topics(beta, keys, topk, "gensim.topics.txt") + +def show_topics(beta, keys, topk, result_path): + LOGGER.info("save results to %s (topk: %d)", result_path, topk) + fout = open(result_path, "w") + for idx in range(beta.shape[0]): + print("=" * 50) + fout.write("=" * 50 + "\n") + print(f"topic {idx + 1}") + fout.write(f"topic {idx + 1}" + "\n") + print("-" * 50) + fout.write("-" * 50 + "\n") + _beta = beta[idx, :] + indices = np.argsort(-_beta)[:topk] + for rank, wordid in enumerate(indices): + word = keys[wordid].decode("utf8") + prob = _beta[wordid] + print(f"rank {rank + 1}. {word}: {prob}") + fout.write(f"rank {rank + 1}. {word}: {prob}" + "\n") + fout.close() + + +def run_experiments(): + training_time = {"attr": "training time (sec)"} + training_time["gensim (8 vpus)"] = run_gensim() + training_time["cusim"] = run_cusim() + df0 = pd.DataFrame([training_time]) + df0.set_index("attr", inplace=True) + print(df0.to_markdown()) + + +if __name__ == "__main__": + fire.Fire() diff --git a/examples/example_w2v.py b/examples/example_w2v.py new file mode 100644 index 0000000..77de2c4 --- /dev/null +++ b/examples/example_w2v.py @@ -0,0 +1,163 @@ +# Copyright (c) 2021 Jisang Yoon +# All rights reserved. +# +# This source code is licensed under the Apache 2.0 license found in the +# LICENSE file in the root directory of this source tree. + +# pylint: disable=no-name-in-module,logging-format-truncated +# pylint: disable=too-few-public-methods +import os +import time +import subprocess + +import tqdm +import fire +import pandas as pd + +import gensim +from gensim import downloader as api +from gensim.test.utils import datapath + +import nltk +from nltk.tokenize import RegexpTokenizer + +from cusim import aux, CuW2V + + +LOGGER = aux.get_logger() +DOWNLOAD_PATH = "./res" +DATASET = "quora-duplicate-questions" +DATA_PATH = f"./res/{DATASET}.stream.txt" +PROCESSED_DATA_DIR = "./res/{DATASET}-processed" +CUSIM_MODEL = "./res/cusim.w2v.model" +GENSIM_MODEL = "./res/gensim.w2v.model" + + +# common hyperparameters +MIN_COUNT = 5 +LEARNING_RATE = 0.001 +NEG_SIZE = 10 +NUM_DIMS = 100 +CBOW_MEAN = False +EPOCHS = 10 + + +def download(): + if os.path.exists(DATA_PATH): + LOGGER.info("%s already exists", DATA_PATH) + return + if not os.path.exists(DOWNLOAD_PATH): + os.makedirs(DOWNLOAD_PATH, exist_ok=True) + api.BASE_DIR = DOWNLOAD_PATH + filepath = api.load(DATASET, return_path=True) + LOGGER.info("filepath: %s", filepath) + cmd = ["gunzip", "-c", filepath, ">", DATA_PATH] + cmd = " ".join(cmd) + LOGGER.info("cmd: %s", cmd) + subprocess.call(cmd, shell=True) + preprocess_data() + +def preprocess_data(): + tokenizer = RegexpTokenizer(r'\w+') + nltk.download("wordnet") + lemmatizer = nltk.stem.wordnet.WordNetLemmatizer() + fout = open(DATA_PATH + ".tmp", "wb") + with open(DATA_PATH, "rb") as fin: + for line in tqdm.tqdm(fin): + line = line.decode("utf8").strip() + line = preprocess_line(line, tokenizer, lemmatizer) + fout.write((line + "\n").encode("utf8")) + fout.close() + os.rename(DATA_PATH + ".tmp", DATA_PATH) + +def preprocess_line(line, tokenizer, lemmatizer): + line = line.lower() + line = tokenizer.tokenize(line) + line = [token for token in line + if not token.isnumeric() and len(token) > 1] + line = [lemmatizer.lemmatize(token) for token in line] + return " ".join(line) + +def run_cusim(skip_gram=False, hierarchical_softmax=False): + download() + opt = { + "data_path": DATA_PATH, + "processed_data_dir": PROCESSED_DATA_DIR, + # "skip_preprocess": os.path.exists(PROCESSED_DATA_DIR), + "num_dims": NUM_DIMS, + "epochs": EPOCHS, + "word_min_count": MIN_COUNT, + "lr": 0.001, + "io": { + "lower": False + }, + "neg": 0 if hierarchical_softmax else NEG_SIZE, + "skip_gram": skip_gram, + "cbow_mean": CBOW_MEAN, + } + start = time.time() + w2v = CuW2V(opt) + w2v.train_model() + elapsed = time.time() - start + LOGGER.info("elapsed for cusim w2v training: %.4e sec", elapsed) + w2v.save_word2vec_format(CUSIM_MODEL, binary=False) + return elapsed, evaluate_w2v_model(CUSIM_MODEL) + +def run_gensim(skip_gram=False, hierarchical_softmax=False, workers=8): + download() + start = time.time() + model = gensim.models.Word2Vec(corpus_file=DATA_PATH, workers=workers, + sg=skip_gram, hs=hierarchical_softmax, + min_alpha=LEARNING_RATE, min_count=MIN_COUNT, + alpha=LEARNING_RATE, negative=NEG_SIZE, + iter=EPOCHS, cbow_mean=CBOW_MEAN, + size=NUM_DIMS) + elapsed = time.time() - start + LOGGER.info("elapsed for gensim w2v training: %.4e sec", elapsed) + model.wv.save_word2vec_format(GENSIM_MODEL, binary=False) + LOGGER.info("gensim w2v model is saved to %s", GENSIM_MODEL) + return elapsed, evaluate_w2v_model(GENSIM_MODEL) + +def evaluate_w2v_model(model=GENSIM_MODEL): + LOGGER.info("load word2vec format model from %s", model) + model = gensim.models.KeyedVectors.load_word2vec_format(model) + results = model.wv.evaluate_word_pairs(datapath("wordsim353.tsv"), + case_insensitive=False) + LOGGER.info("evaluation results: %s", results) + return results + +# gpu model variable is for being displayed in markdown +# please put the real gpu modelname +def run_experiments(skip_gram=False, hierarchical_softmax=False, + gpu_model="NVIDIA T4"): + training_time = {"attr": "training time (sec)"} + pearson = {"attr": "pearson"} + spearman = {"attr": "spearman"} + for i in [1, 2, 4, 8]: + elapsed, evals = run_gensim(skip_gram, hierarchical_softmax, i) + training_time[f"{i} workers (gensim)"] = elapsed + pearson[f"{i} workers (gensim)"] = evals[0][0] + spearman[f"{i} workers (gensim)"] = evals[1][0] + elapsed, evals = run_cusim(skip_gram, hierarchical_softmax) + gpu_title = f"{gpu_model} (cusim)" + training_time[gpu_title] = elapsed + pearson[gpu_title] = evals[0][0] + spearman[gpu_title] = evals[1][0] + df0 = pd.DataFrame([training_time, pearson, spearman]) + df0.set_index("attr", inplace=True) + print(df0.to_markdown()) + +# gpu model variable is for being displayed in markdown +# please put the real gpu modelname +def run_various_experiments(gpu_model="NVIDIA T4"): + for sg0 in [True, False]: + for hs0 in [True, False]: + print("=" * 100) + LOGGER.info("setting: %s, %s", + "skip gram" if sg0 else "cbow", + "hierarchical softmax" if hs0 else "negative sampling") + run_experiments(sg0, hs0, gpu_model) + + +if __name__ == "__main__": + fire.Fire() diff --git a/examples/gensim.topics.txt b/examples/gensim.topics.txt new file mode 100644 index 0000000..9dad07e --- /dev/null +++ b/examples/gensim.topics.txt @@ -0,0 +1,650 @@ +================================================== +topic 1 +-------------------------------------------------- +rank 1. anthrax: 0.00685924245044589 +rank 2. drug: 0.00645965151488781 +rank 3. cell: 0.005628286395221949 +rank 4. disease: 0.004471580497920513 +rank 5. research: 0.004018326755613089 +rank 6. patient: 0.0039655729196965694 +rank 7. scientist: 0.003750465577468276 +rank 8. human: 0.0035445094108581543 +rank 9. doctor: 0.0033765030093491077 +rank 10. official: 0.003303991165012121 +================================================== +topic 2 +-------------------------------------------------- +rank 1. church: 0.003922631964087486 +rank 2. priest: 0.0034748397301882505 +rank 3. bishop: 0.0029971767216920853 +rank 4. country: 0.0026398426853120327 +rank 5. children: 0.00246992870233953 +rank 6. zzz_church: 0.0023986543528735638 +rank 7. member: 0.001985707087442279 +rank 8. friend: 0.0019694953225553036 +rank 9. school: 0.001964277820661664 +rank 10. family: 0.0019087349064648151 +================================================== +topic 3 +-------------------------------------------------- +rank 1. company: 0.016529614105820656 +rank 2. zzz_enron: 0.009667424485087395 +rank 3. companies: 0.008695845492184162 +rank 4. percent: 0.006433089263737202 +rank 5. market: 0.005390912294387817 +rank 6. million: 0.005133198108524084 +rank 7. billion: 0.004916781093925238 +rank 8. business: 0.004866308066993952 +rank 9. industry: 0.0043991943821311 +rank 10. firm: 0.003601119853556156 +================================================== +topic 4 +-------------------------------------------------- +rank 1. employees: 0.0041818018071353436 +rank 2. million: 0.0033566232305020094 +rank 3. percent: 0.0033026104792952538 +rank 4. airport: 0.0030971942469477654 +rank 5. business: 0.0030223086941987276 +rank 6. federal: 0.0028500158805400133 +rank 7. passenger: 0.0027126194909214973 +rank 8. company: 0.0027043502777814865 +rank 9. law: 0.002688799751922488 +rank 10. job: 0.0026725251227617264 +================================================== +topic 5 +-------------------------------------------------- +rank 1. million: 0.005916924215853214 +rank 2. percent: 0.00542123056948185 +rank 3. company: 0.004422706551849842 +rank 4. online: 0.004064733628183603 +rank 5. com: 0.0037389120552688837 +rank 6. site: 0.003357403911650181 +rank 7. money: 0.0030307266861200333 +rank 8. web: 0.0026095015928149223 +rank 9. internet: 0.0025709569454193115 +rank 10. business: 0.0024067554622888565 +================================================== +topic 6 +-------------------------------------------------- +rank 1. government: 0.0027651293203234673 +rank 2. official: 0.002383399987593293 +rank 3. women: 0.0023563546128571033 +rank 4. right: 0.0021000944543629885 +rank 5. attack: 0.00208303309045732 +rank 6. president: 0.002063552848994732 +rank 7. part: 0.002001311630010605 +rank 8. military: 0.0019593038596212864 +rank 9. home: 0.0019395806593820453 +rank 10. died: 0.0018545138882473111 +================================================== +topic 7 +-------------------------------------------------- +rank 1. zzz_al_gore: 0.00505444873124361 +rank 2. campaign: 0.004116986878216267 +rank 3. president: 0.003524355124682188 +rank 4. percent: 0.0024283165112137794 +rank 5. need: 0.002368952613323927 +rank 6. zzz_china: 0.0022500273771584034 +rank 7. zzz_bush: 0.002205689437687397 +rank 8. right: 0.0021621366031467915 +rank 9. zzz_george_bush: 0.002156849019229412 +rank 10. government: 0.002107091713696718 +================================================== +topic 8 +-------------------------------------------------- +rank 1. official: 0.0031164248939603567 +rank 2. book: 0.0027842181734740734 +rank 3. care: 0.002501435810700059 +rank 4. children: 0.0024145750794559717 +rank 5. patient: 0.0023914289195090532 +rank 6. job: 0.0022276851814240217 +rank 7. million: 0.002185588236898184 +rank 8. law: 0.0021591810509562492 +rank 9. percent: 0.0021459299605339766 +rank 10. high: 0.0020923956762999296 +================================================== +topic 9 +-------------------------------------------------- +rank 1. bill: 0.007710717618465424 +rank 2. zzz_bush: 0.007315776310861111 +rank 3. zzz_senate: 0.006834658794105053 +rank 4. zzz_white_house: 0.006023857276886702 +rank 5. president: 0.005438356660306454 +rank 6. campaign: 0.00514236930757761 +rank 7. republican: 0.004330878611654043 +rank 8. zzz_congress: 0.004327962175011635 +rank 9. law: 0.004060880281031132 +rank 10. election: 0.0037999162450432777 +================================================== +topic 10 +-------------------------------------------------- +rank 1. patient: 0.004987091291695833 +rank 2. percent: 0.003388918936252594 +rank 3. women: 0.003210175782442093 +rank 4. doctor: 0.0030205147340893745 +rank 5. need: 0.0028614287730306387 +rank 6. school: 0.0028380276635289192 +rank 7. children: 0.002690003952011466 +rank 8. problem: 0.002651546150445938 +rank 9. high: 0.0025971722789108753 +rank 10. help: 0.002436841605231166 +================================================== +topic 11 +-------------------------------------------------- +rank 1. com: 0.012774643488228321 +rank 2. beach: 0.006349315866827965 +rank 3. palm: 0.006187785882502794 +rank 4. daily: 0.0060830800794065 +rank 5. question: 0.0057131643407046795 +rank 6. american: 0.005116782616823912 +rank 7. statesman: 0.0048037427477538586 +rank 8. information: 0.004445109516382217 +rank 9. newspaper: 0.0036021184641867876 +rank 10. holiday: 0.0035788556560873985 +================================================== +topic 12 +-------------------------------------------------- +rank 1. film: 0.006853350438177586 +rank 2. million: 0.003383357310667634 +rank 3. movie: 0.002598291728645563 +rank 4. show: 0.0025286595337092876 +rank 5. zzz_new_york: 0.0020529397297650576 +rank 6. home: 0.0020189047791063786 +rank 7. official: 0.001957985572516918 +rank 8. president: 0.0018404715228825808 +rank 9. company: 0.0018034027889370918 +rank 10. look: 0.0017407435225322843 +================================================== +topic 13 +-------------------------------------------------- +rank 1. run: 0.011265823617577553 +rank 2. inning: 0.008100989274680614 +rank 3. hit: 0.007574521936476231 +rank 4. season: 0.007538564968854189 +rank 5. team: 0.006613561417907476 +rank 6. yankees: 0.00625609653070569 +rank 7. game: 0.006213172804564238 +rank 8. manager: 0.005222634878009558 +rank 9. home: 0.005023866891860962 +rank 10. baseball: 0.00478346599265933 +================================================== +topic 14 +-------------------------------------------------- +rank 1. case: 0.005160974804311991 +rank 2. bin: 0.004822859540581703 +rank 3. police: 0.004709980916231871 +rank 4. laden: 0.0045293704606592655 +rank 5. court: 0.004371668212115765 +rank 6. lawyer: 0.00434792460873723 +rank 7. death: 0.004024742636829615 +rank 8. official: 0.003718850202858448 +rank 9. terrorist: 0.003696390660479665 +rank 10. prosecutor: 0.0036732915323227644 +================================================== +topic 15 +-------------------------------------------------- +rank 1. book: 0.0033520206343382597 +rank 2. show: 0.002301784697920084 +rank 3. million: 0.002270190278068185 +rank 4. women: 0.002245655283331871 +rank 5. group: 0.002187345875427127 +rank 6. home: 0.0019888021051883698 +rank 7. high: 0.00198025512509048 +rank 8. children: 0.0019427805673331022 +rank 9. president: 0.0017906598513945937 +rank 10. word: 0.0017840730724856257 +================================================== +topic 16 +-------------------------------------------------- +rank 1. team: 0.01496296189725399 +rank 2. game: 0.01422667596489191 +rank 3. season: 0.012049943208694458 +rank 4. play: 0.009543812833726406 +rank 5. player: 0.009395435452461243 +rank 6. point: 0.00872756727039814 +rank 7. games: 0.00866071879863739 +rank 8. coach: 0.005120040383189917 +rank 9. win: 0.004865164402872324 +rank 10. shot: 0.004554057028144598 +================================================== +topic 17 +-------------------------------------------------- +rank 1. show: 0.0024334085173904896 +rank 2. school: 0.002219858579337597 +rank 3. boy: 0.0020450616721063852 +rank 4. ago: 0.001936619752086699 +rank 5. women: 0.0019191682804375887 +rank 6. family: 0.0019174017943441868 +rank 7. father: 0.0019156094640493393 +rank 8. official: 0.00186788325663656 +rank 9. look: 0.001812569797039032 +rank 10. part: 0.0017829277785494924 +================================================== +topic 18 +-------------------------------------------------- +rank 1. cup: 0.007968481630086899 +rank 2. food: 0.005801460240036249 +rank 3. minutes: 0.005306020844727755 +rank 4. water: 0.005087833385914564 +rank 5. add: 0.004999660421162844 +rank 6. oil: 0.004342195112258196 +rank 7. tablespoon: 0.004145510494709015 +rank 8. teaspoon: 0.003376882988959551 +rank 9. pepper: 0.003355829743668437 +rank 10. sugar: 0.003274987917393446 +================================================== +topic 19 +-------------------------------------------------- +rank 1. music: 0.005913920234888792 +rank 2. book: 0.004398913588374853 +rank 3. song: 0.003664076328277588 +rank 4. show: 0.002688957843929529 +rank 5. film: 0.0024080141447484493 +rank 6. play: 0.0022673725616186857 +rank 7. album: 0.002084113657474518 +rank 8. band: 0.0019060482736676931 +rank 9. character: 0.0019055294105783105 +rank 10. musical: 0.0018600845942273736 +================================================== +topic 20 +-------------------------------------------------- +rank 1. percent: 0.0024182084016501904 +rank 2. political: 0.0023989675100892782 +rank 3. president: 0.002392817521467805 +rank 4. right: 0.0022518527694046497 +rank 5. zzz_clinton: 0.002106384839862585 +rank 6. zzz_united_states: 0.002106096362695098 +rank 7. government: 0.0020546542946249247 +rank 8. part: 0.0019975078757852316 +rank 9. american: 0.0019730974454432726 +rank 10. million: 0.0019538013730198145 +================================================== +topic 21 +-------------------------------------------------- +rank 1. percent: 0.004818295128643513 +rank 2. company: 0.00477081723511219 +rank 3. million: 0.003893114859238267 +rank 4. show: 0.0033499926794320345 +rank 5. market: 0.003019515657797456 +rank 6. fund: 0.0029496813658624887 +rank 7. stock: 0.002816412365064025 +rank 8. companies: 0.002782258205115795 +rank 9. com: 0.0027674154844135046 +rank 10. money: 0.002466329839080572 +================================================== +topic 22 +-------------------------------------------------- +rank 1. team: 0.010046757757663727 +rank 2. player: 0.006969128269702196 +rank 3. play: 0.006642984692007303 +rank 4. season: 0.006582103203982115 +rank 5. game: 0.006233640015125275 +rank 6. yard: 0.005785048473626375 +rank 7. coach: 0.005579683464020491 +rank 8. football: 0.005469046533107758 +rank 9. quarterback: 0.004959654062986374 +rank 10. zzz_nfl: 0.004103266634047031 +================================================== +topic 23 +-------------------------------------------------- +rank 1. look: 0.004422602243721485 +rank 2. show: 0.003441872540861368 +rank 3. film: 0.0031025675125420094 +rank 4. movie: 0.002993824426084757 +rank 5. women: 0.00254439958371222 +rank 6. designer: 0.002364467130973935 +rank 7. art: 0.002262679161503911 +rank 8. black: 0.0022530003916472197 +rank 9. fashion: 0.0021751534659415483 +rank 10. friend: 0.002020617015659809 +================================================== +topic 24 +-------------------------------------------------- +rank 1. percent: 0.003318313742056489 +rank 2. market: 0.003031886648386717 +rank 3. los: 0.0021878154948353767 +rank 4. home: 0.002176032168790698 +rank 5. fax: 0.0020547923631966114 +rank 6. need: 0.0019452706910669804 +rank 7. las: 0.0019168389262631536 +rank 8. children: 0.0018388490425422788 +rank 9. million: 0.0017693588742986321 +rank 10. point: 0.001726859831251204 +================================================== +topic 25 +-------------------------------------------------- +rank 1. president: 0.004011120647192001 +rank 2. zzz_michael_bloomberg: 0.003995958250015974 +rank 3. children: 0.003151001175865531 +rank 4. court: 0.002441431861370802 +rank 5. million: 0.0024222510401159525 +rank 6. family: 0.0022462978959083557 +rank 7. right: 0.0022452594712376595 +rank 8. case: 0.0021755550988018513 +rank 9. member: 0.002173624001443386 +rank 10. child: 0.0021611815318465233 +================================================== +topic 26 +-------------------------------------------------- +rank 1. student: 0.012113531120121479 +rank 2. school: 0.005448043812066317 +rank 3. zzz_al_gore: 0.0027674699667841196 +rank 4. president: 0.0027605986688286066 +rank 5. program: 0.0027198668103665113 +rank 6. zzz_bush: 0.00268101179972291 +rank 7. official: 0.002358316443860531 +rank 8. million: 0.0023511676117777824 +rank 9. college: 0.00226805848069489 +rank 10. high: 0.002222060691565275 +================================================== +topic 27 +-------------------------------------------------- +rank 1. show: 0.01301606371998787 +rank 2. network: 0.004612243268638849 +rank 3. film: 0.003694617422297597 +rank 4. television: 0.003542348276823759 +rank 5. zzz_nbc: 0.003481118241325021 +rank 6. series: 0.0031934629660099745 +rank 7. viewer: 0.002787014702335 +rank 8. movie: 0.0027608510572463274 +rank 9. zzz_abc: 0.002647952176630497 +rank 10. zzz_cb: 0.0026217657141387463 +================================================== +topic 28 +-------------------------------------------------- +rank 1. fight: 0.007432610262185335 +rank 2. race: 0.0052270242013037205 +rank 3. won: 0.00500878831371665 +rank 4. horse: 0.0040362621657550335 +rank 5. win: 0.0035624431911855936 +rank 6. horses: 0.003421240486204624 +rank 7. winner: 0.0032423168886452913 +rank 8. million: 0.0032216913532465696 +rank 9. zzz_kentucky_derby: 0.003024221630766988 +rank 10. trainer: 0.0029377234168350697 +================================================== +topic 29 +-------------------------------------------------- +rank 1. building: 0.005186538677662611 +rank 2. flight: 0.004052834585309029 +rank 3. million: 0.003262787824496627 +rank 4. space: 0.0030079486314207315 +rank 5. station: 0.002787536708638072 +rank 6. official: 0.0026234739925712347 +rank 7. plan: 0.0026129535399377346 +rank 8. project: 0.0025549777783453465 +rank 9. airport: 0.00252013118006289 +rank 10. miles: 0.0024732924066483974 +================================================== +topic 30 +-------------------------------------------------- +rank 1. official: 0.007196392398327589 +rank 2. government: 0.006008050870150328 +rank 3. palestinian: 0.005985938478261232 +rank 4. attack: 0.005630761384963989 +rank 5. zzz_bush: 0.005420268513262272 +rank 6. military: 0.005230060312896967 +rank 7. zzz_united_states: 0.005217625759541988 +rank 8. leader: 0.00501753855496645 +rank 9. war: 0.004987193737179041 +rank 10. zzz_afghanistan: 0.004758686758577824 +================================================== +topic 31 +-------------------------------------------------- +rank 1. air: 0.0029095665086060762 +rank 2. bird: 0.0020796286407858133 +rank 3. home: 0.0020551159977912903 +rank 4. right: 0.0019946428947150707 +rank 5. night: 0.0019745323807001114 +rank 6. found: 0.0019662429112941027 +rank 7. women: 0.0019117469200864434 +rank 8. small: 0.0018042228184640408 +rank 9. film: 0.0018023115117102861 +rank 10. weather: 0.001782107399776578 +================================================== +topic 32 +-------------------------------------------------- +rank 1. home: 0.003720843931660056 +rank 2. right: 0.003542491467669606 +rank 3. family: 0.0031214465852826834 +rank 4. school: 0.00276115071028471 +rank 5. women: 0.002711761509999633 +rank 6. night: 0.0023221999872475863 +rank 7. play: 0.002255885163322091 +rank 8. left: 0.0022035983856767416 +rank 9. team: 0.002107130130752921 +rank 10. father: 0.0020948820747435093 +================================================== +topic 33 +-------------------------------------------------- +rank 1. school: 0.003237304277718067 +rank 2. president: 0.0027905527967959642 +rank 3. public: 0.002769605955109 +rank 4. family: 0.002466334495693445 +rank 5. play: 0.0022557631600648165 +rank 6. friend: 0.002244712086394429 +rank 7. called: 0.0021952360402792692 +rank 8. show: 0.0021873449441045523 +rank 9. children: 0.0019429487874731421 +rank 10. right: 0.0019146130653098226 +================================================== +topic 34 +-------------------------------------------------- +rank 1. zzz_tiger_wood: 0.005789353512227535 +rank 2. tour: 0.004391436465084553 +rank 3. par: 0.004167190752923489 +rank 4. golf: 0.004120441619306803 +rank 5. player: 0.003632240230217576 +rank 6. round: 0.003603136632591486 +rank 7. course: 0.0034811885561794043 +rank 8. million: 0.0032766603399068117 +rank 9. right: 0.003120653098449111 +rank 10. tournament: 0.003016760805621743 +================================================== +topic 35 +-------------------------------------------------- +rank 1. law: 0.005112105049192905 +rank 2. court: 0.004475763533264399 +rank 3. case: 0.004344474989920855 +rank 4. officer: 0.003633111482486129 +rank 5. official: 0.003590918844565749 +rank 6. priest: 0.0029575068037956953 +rank 7. federal: 0.002699311124160886 +rank 8. lawyer: 0.002647568704560399 +rank 9. right: 0.0025485807564109564 +rank 10. public: 0.002452900167554617 +================================================== +topic 36 +-------------------------------------------------- +rank 1. school: 0.020861370489001274 +rank 2. student: 0.010099130682647228 +rank 3. teacher: 0.005499639548361301 +rank 4. percent: 0.004801126196980476 +rank 5. program: 0.004208665806800127 +rank 6. company: 0.0036289861891418695 +rank 7. million: 0.0030905508901923895 +rank 8. high: 0.003034914145246148 +rank 9. system: 0.0026615734677761793 +rank 10. group: 0.002539578592404723 +================================================== +topic 37 +-------------------------------------------------- +rank 1. film: 0.004587217234075069 +rank 2. plane: 0.003981141839176416 +rank 3. movie: 0.0031528673134744167 +rank 4. computer: 0.0028157581109553576 +rank 5. hour: 0.0026088894810527563 +rank 6. pilot: 0.0026044598780572414 +rank 7. company: 0.0025764324236661196 +rank 8. site: 0.002284643007442355 +rank 9. look: 0.0021962597966194153 +rank 10. business: 0.0021854061633348465 +================================================== +topic 38 +-------------------------------------------------- +rank 1. zzz_george_bush: 0.011270474642515182 +rank 2. zzz_bush: 0.008161350153386593 +rank 3. drug: 0.005321393720805645 +rank 4. president: 0.00409182533621788 +rank 5. plan: 0.00371949071995914 +rank 6. government: 0.003578079864382744 +rank 7. campaign: 0.00345935788936913 +rank 8. political: 0.0033851286862045527 +rank 9. administration: 0.0032937484793365 +rank 10. official: 0.0030149330850690603 +================================================== +topic 39 +-------------------------------------------------- +rank 1. percent: 0.0035957021173089743 +rank 2. million: 0.0032125902362167835 +rank 3. campaign: 0.002461759140715003 +rank 4. official: 0.0022447840310633183 +rank 5. need: 0.0022418724838644266 +rank 6. ago: 0.0022254257928580046 +rank 7. school: 0.0022246893495321274 +rank 8. public: 0.0021837460808455944 +rank 9. group: 0.0020542009733617306 +rank 10. election: 0.0019751274958252907 +================================================== +topic 40 +-------------------------------------------------- +rank 1. car: 0.010070767253637314 +rank 2. company: 0.0029257673304528 +rank 3. model: 0.0026363697834312916 +rank 4. vehicles: 0.002321602776646614 +rank 5. system: 0.0022617375943809748 +rank 6. truck: 0.0022347569465637207 +rank 7. light: 0.0022089716512709856 +rank 8. look: 0.0021158005110919476 +rank 9. vehicle: 0.0021055899560451508 +rank 10. wheel: 0.0020712334662675858 +================================================== +topic 41 +-------------------------------------------------- +rank 1. web: 0.006063418462872505 +rank 2. computer: 0.005808345973491669 +rank 3. www: 0.00550216156989336 +rank 4. com: 0.005284811370074749 +rank 5. mail: 0.004705057479441166 +rank 6. site: 0.004636435769498348 +rank 7. program: 0.003150786040350795 +rank 8. information: 0.0031408770009875298 +rank 9. home: 0.002967627253383398 +rank 10. user: 0.0028863276820629835 +================================================== +topic 42 +-------------------------------------------------- +rank 1. com: 0.005292237736284733 +rank 2. official: 0.00293516693636775 +rank 3. need: 0.002405548235401511 +rank 4. mail: 0.0023109321482479572 +rank 5. right: 0.0023065898567438126 +rank 6. home: 0.002236895263195038 +rank 7. team: 0.002150715794414282 +rank 8. run: 0.001997484592720866 +rank 9. school: 0.0019720534328371286 +rank 10. question: 0.0019104363163933158 +================================================== +topic 43 +-------------------------------------------------- +rank 1. race: 0.00657246820628643 +rank 2. racing: 0.004118914250284433 +rank 3. right: 0.0029221170116215944 +rank 4. medal: 0.0028882403858006 +rank 5. track: 0.002390378387644887 +rank 6. event: 0.0023623635061085224 +rank 7. team: 0.002281986875459552 +rank 8. car: 0.0022783183958381414 +rank 9. won: 0.0022399136796593666 +rank 10. zzz_olympic: 0.002192797837778926 +================================================== +topic 44 +-------------------------------------------------- +rank 1. percent: 0.017496131360530853 +rank 2. stock: 0.010200390592217445 +rank 3. million: 0.009964891709387302 +rank 4. market: 0.007223550695925951 +rank 5. company: 0.005689692217856646 +rank 6. economy: 0.005342676304280758 +rank 7. investor: 0.004899188876152039 +rank 8. companies: 0.0047439876943826675 +rank 9. quarter: 0.004699032753705978 +rank 10. fund: 0.0044333296827971935 +================================================== +topic 45 +-------------------------------------------------- +rank 1. con: 0.006490298546850681 +rank 2. book: 0.006109805777668953 +rank 3. una: 0.004393845796585083 +rank 4. car: 0.0037981965579092503 +rank 5. race: 0.0036588881630450487 +rank 6. driver: 0.003497923957183957 +rank 7. mas: 0.0033254865556955338 +rank 8. por: 0.0031546419486403465 +rank 9. dice: 0.003066850360482931 +rank 10. como: 0.0023456900380551815 +================================================== +topic 46 +-------------------------------------------------- +rank 1. tax: 0.014006761834025383 +rank 2. billion: 0.006483058910816908 +rank 3. percent: 0.005937786307185888 +rank 4. plan: 0.005340836010873318 +rank 5. million: 0.005249205976724625 +rank 6. cut: 0.00519448472186923 +rank 7. money: 0.004712596535682678 +rank 8. market: 0.004287587013095617 +rank 9. taxes: 0.004086908418685198 +rank 10. companies: 0.004031222313642502 +================================================== +topic 47 +-------------------------------------------------- +rank 1. room: 0.0036632674746215343 +rank 2. wine: 0.003140499349683523 +rank 3. home: 0.0024764847476035357 +rank 4. family: 0.0023022012319415808 +rank 5. house: 0.0022613140754401684 +rank 6. small: 0.0020653954707086086 +rank 7. restaurant: 0.0020270454697310925 +rank 8. town: 0.00199923780746758 +rank 9. food: 0.001922571798786521 +rank 10. wines: 0.0019065839005634189 +================================================== +topic 48 +-------------------------------------------------- +rank 1. zzz_microsoft: 0.0044301655143499374 +rank 2. court: 0.004337427671998739 +rank 3. case: 0.0033778140787035227 +rank 4. system: 0.0029870739672333 +rank 5. zzz_pete_sampras: 0.002648326102644205 +rank 6. set: 0.0023190395440906286 +rank 7. right: 0.002282836241647601 +rank 8. team: 0.002235522260889411 +rank 9. law: 0.002207053592428565 +rank 10. company: 0.002188159618526697 +================================================== +topic 49 +-------------------------------------------------- +rank 1. right: 0.0032724656630307436 +rank 2. million: 0.0028670921456068754 +rank 3. part: 0.002584064844995737 +rank 4. group: 0.0022626426070928574 +rank 5. car: 0.0022204353008419275 +rank 6. official: 0.002139380434527993 +rank 7. law: 0.0021126968786120415 +rank 8. family: 0.0019641611725091934 +rank 9. school: 0.0019462114432826638 +rank 10. help: 0.0019059423357248306 +================================================== +topic 50 +-------------------------------------------------- +rank 1. zzz_bush: 0.00572725897654891 +rank 2. zzz_george_bush: 0.0036787609569728374 +rank 3. president: 0.0033892381470650434 +rank 4. percent: 0.003073032945394516 +rank 5. file: 0.0030222907662391663 +rank 6. campaign: 0.0030060207936912775 +rank 7. official: 0.0028273393400013447 +rank 8. plan: 0.0027233336586505175 +rank 9. court: 0.0026171435602009296 +rank 10. decision: 0.0025652681943029165 diff --git a/examples/requirements.txt b/examples/requirements.txt index 728d5c2..ed7d08e 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -1,2 +1,7 @@ fire -gensim +gensim==3.8.3 +nltk +tqdm +wget +pandas +tabulate diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ca606fb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,6 @@ +[build-system] +requires = [ + "setuptools>=1.3.2", + "numpy", + "pybind11" +] diff --git a/requirements.txt b/requirements.txt index bfe001f..537bb27 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ +h5py jsmin numpy -pandas +scipy pybind11 protobuf==3.10.0 grpcio-tools==1.27.1