Merge pull request #6 from js1010/task/add-benchmark

Task/add benchmark

js1010 authored Feb 15, 2021
2 parents ec13e7c + 5898669 commit 469d0ff
Showing 28 changed files with 2,362 additions and 289 deletions.
14 changes: 14 additions & 0 deletions MANIFEST.in
@@ -0,0 +1,14 @@
include cuda_setup.py
include requirements.txt
include pyproject.toml
recursive-include cpp/src/cuw2v/ *.cu
recursive-include cpp/src/culda/ *.cu
recursive-include cpp/src/ioutils/ *.cc
recursive-include cpp/include/cuw2v/ *.cuh
recursive-include cpp/include/cuw2v/ *.hpp
recursive-include cpp/include/culda/ *.cuh
recursive-include cpp/include/culda/ *.hpp
recursive-include cpp/include/ioutils/ *.cuh
recursive-include cpp/include/ioutils/ *.hpp
recursive-include 3rd/json11/ *
recursive-include 3rd/spdlog/ *
65 changes: 65 additions & 0 deletions README.md
@@ -1,5 +1,10 @@
### Introduction

This project speeds up various ML models (e.g. topic modeling, word embedding, etc.) with CUDA. Think of it as a GPU-accelerated counterpart to [gensim](https://github.com/RaRe-Technologies/gensim). As a starting step, I implemented the most widely used word embedding model, [word2vec](https://arxiv.org/pdf/1301.3781.pdf), and the most representative topic model, [LDA (Latent Dirichlet Allocation)](https://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf).

### How to install

- install from source

```shell
# clone repo and submodules
@@ -14,3 +19,63 @@ python -m grpc_tools.protoc --python_out cusim/ --proto_path cusim/proto/ config
# install
python setup.py install
```

- pip installation will be available soon

### How to use

- `examples/example_w2v.py`, `examples/example_lda.py`, and `examples/README.md` are helpful for understanding the usage; a minimal usage sketch also follows below.
- Parameter descriptions can be found in `cusim/proto/config.proto`.
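
A minimal sketch of the intended workflow, distilled from the example scripts. The class name, option keys, and method name here are illustrative assumptions rather than verified API; treat `examples/example_w2v.py` as authoritative.

```python
# Hypothetical usage sketch -- names are assumptions; see examples/ for real code.
from cusim import CuW2V  # assumed export

opt = {
    "data_path": "./quora.h5",  # assumed key: path to the preprocessed corpus
    "num_dims": 100,            # assumed key: embedding dimensionality
    "skip_gram": True,          # assumed key: skip-gram vs CBOW
    "neg": 10,                  # assumed key: negative samples (0 => hierarchical softmax)
}
w2v = CuW2V(opt)
w2v.train_model()  # assumed method: runs training on the GPU
```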

### Performance

- An [AWS g4dn.2xlarge instance](https://aws.amazon.com/ec2/instance-types/g4/) was used for the experiments (one NVIDIA T4 GPU with 8 vCPUs, Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz).
- Results can be reproduced by simply running `examples/example_w2v.py` and `examples/example_lda.py`.
- To evaluate the w2v model, I used the `evaluate_word_pairs` function ([ref link](https://radimrehurek.com/gensim/auto_examples/tutorials/run_word2vec.html#evaluating)) in gensim. Note that better performance on the WS-353 test set does not necessarily mean the model will work better in applications, as described at the link. Still, it is good to have a quantitative measure, and fast training time is at least a very objective measure of performance. (The evaluation call is sketched right after these bullets.)
- I trained the W2V model on the `quora-duplicate-questions` dataset from the gensim downloader API on GPU with cusim and compared the performance (both speed and model quality) with gensim.
- To evaluate the LDA model, I found no good way to measure the quality of training results quantitatively, but we can inspect the model by looking at the top words of each topic. We can also compare training time quantitatively.
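
The gensim call behind the pearson/spearman columns, as a short sketch (the word-pair file ships with gensim's test data; the model path is hypothetical):

```python
# Sketch of the evaluation used above (see the gensim tutorial ref link).
from gensim.models.word2vec import Word2Vec
from gensim.test.utils import datapath

model = Word2Vec.load("w2v.model")  # hypothetical path to a trained model
pearson, spearman, oov_ratio = model.wv.evaluate_word_pairs(
    datapath("wordsim353.tsv"))     # WS-353 word-pair similarity test set
print(pearson, spearman, oov_ratio)
```
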
- W2V (skip gram, hierarchical softmax)

| attr                | 1 worker (gensim)    | 2 workers (gensim)   | 4 workers (gensim)   | 8 workers (gensim)   | NVIDIA T4 (cusim)   |
|:--------------------|---------------------:|---------------------:|---------------------:|---------------------:|--------------------:|
| training time (sec) | 892.596 | 544.212 | 310.727 | 226.472 | **16.162** |
| pearson | 0.487832 | 0.487696 | 0.482821 | 0.487136 | **0.492101** |
| spearman | 0.500846 | 0.506214 | 0.501048 | **0.506718** | 0.479468 |

- W2V (skip gram, negative sampling)

| attr                | 1 worker (gensim)    | 2 workers (gensim)   | 4 workers (gensim)   | 8 workers (gensim)   | NVIDIA T4 (cusim)   |
|:--------------------|---------------------:|---------------------:|---------------------:|---------------------:|--------------------:|
| training time (sec) | 586.545 | 340.489 | 220.804 | 146.23 | **33.9173** |
| pearson | 0.354448 | 0.353952 | 0.352398 | 0.352925 | **0.360436** |
| spearman | 0.369146 | 0.369365 | **0.370565** | 0.365822 | 0.355204 |

- W2V (CBOW, hierarchical softmax)

| attr                | 1 worker (gensim)    | 2 workers (gensim)   | 4 workers (gensim)   | 8 workers (gensim)   | NVIDIA T4 (cusim)   |
|:--------------------|---------------------:|---------------------:|---------------------:|---------------------:|--------------------:|
| training time (sec) | 250.135 | 155.121 | 103.57 | 73.8073 | **6.20787** |
| pearson | 0.309651 | 0.321803 | 0.324854 | 0.314255 | **0.480298** |
| spearman | 0.294047 | 0.308723 | 0.318293 | 0.300591 | **0.480971** |

- W2V (CBOW, negative sampling)

| attr                | 1 worker (gensim)    | 2 workers (gensim)   | 4 workers (gensim)   | 8 workers (gensim)   | NVIDIA T4 (cusim)   |
|:--------------------|---------------------:|---------------------:|---------------------:|---------------------:|--------------------:|
| training time (sec) | 176.923 | 100.369 | 69.7829 | 49.9274 | **9.90391** |
| pearson | 0.18772 | 0.193152 | 0.204509 | 0.187924 | **0.368202** |
| spearman | 0.243975 | 0.24587 | 0.260531 | 0.237441 | **0.358042** |

- LDA (`nytimes` dataset from https://archive.ics.uci.edu/ml/datasets/bag+of+words)
- I found that setting the `workers` variable in gensim's LdaMulticore does not work properly (it uses all the cores in the instance anyway), so I just compared the speed between cusim with a single GPU and gensim with 8 vCPUs.
- One can compare the modeling quality by looking at `examples/cusim.topics.txt` and `examples/gensim.topics.txt`.

| attr                | gensim (8 vCPUs) | cusim (NVIDIA T4) |
|:--------------------|-----------------:|------------------:|
| training time (sec) | 447.376 | **76.6972** |

### Future tasks

- support half precision
- support multi-device training (a multi-device implementation of the LDA model should not be that hard, while multi-device training of w2v may require some consideration)
- implement other models such as FastText, BERT, etc.
107 changes: 73 additions & 34 deletions cpp/include/culda/cuda_lda_kernels.cuh
@@ -26,26 +26,36 @@ float Digamma(float x) {
}

__global__ void EstepKernel(
const int* cols, const int* indptr, const bool* vali,
const int num_cols, const int num_indptr,
const int* cols, const int* indptr,
const bool* vali, const float* counts,
const bool init_gamma, const int num_cols, const int num_indptr,
const int num_topics, const int num_iters,
float* gamma, float* new_gamma, float* phi,
const float* alpha, const float* beta,
float* grad_alpha, float* new_beta,
float* train_losses, float* vali_losses, int* mutex) {
float* gamma, float* grad_alpha, float* new_beta,
float* train_losses, float* vali_losses, int* locks) {

// storage for block
float* _gamma = gamma + num_topics * blockIdx.x;
float* _new_gamma = new_gamma + num_topics * blockIdx.x;
float* _phi = phi + num_topics * blockIdx.x;
extern __shared__ float shared_memory[];
float* _new_gamma = &shared_memory[0];
float* _phi = &shared_memory[num_topics];
float* _loss_vec = &shared_memory[num_topics * 2];
float* _vali_phi_sum = &shared_memory[num_topics * 3];

float* _grad_alpha = grad_alpha + num_topics * blockIdx.x;

for (int i = blockIdx.x; i < num_indptr; i += gridDim.x) {
int beg = indptr[i], end = indptr[i + 1];
// initialize gamma
for (int j = threadIdx.x; j < num_topics; j += blockDim.x)
_gamma[j] = alpha[j] + (end - beg) / num_topics;
float* _gamma = gamma + num_topics * i;
if (init_gamma) {
for (int j = threadIdx.x; j < num_topics; j += blockDim.x) {
_gamma[j] = alpha[j] + (end - beg) / num_topics;
}
}
__syncthreads();

// initialize phi sum for validation data for computing vali loss
for (int j = threadIdx.x; j < num_topics; j += blockDim.x)
_vali_phi_sum[j] = 0.0f;

// iterate E step
for (int j = 0; j < num_iters; ++j) {
@@ -58,7 +68,7 @@
for (int k = beg; k < end; ++k) {
const int w = cols[k];
const bool _vali = vali[k];

const float c = counts[k];
// compute phi
if (not _vali or j + 1 == num_iters) {
for (int l = threadIdx.x; l < num_topics; l += blockDim.x)
@@ -70,37 +80,52 @@

for (int l = threadIdx.x; l < num_topics; l += blockDim.x) {
_phi[l] /= phi_sum;
if (not _vali) _new_gamma[l] += _phi[l];

// update gamma for train data and phi_sum for computing loss
if (_vali)
_vali_phi_sum[l] += _phi[l] * c;
else
_new_gamma[l] += _phi[l] * c;

}
__syncthreads();
}

if (j + 1 == num_iters) {
// write access of w th vector of new_beta
if (threadIdx.x == 0) {
while (atomicCAS(&mutex[w], 0, 1)) {}
}
// update beta for train data
if (not _vali) {
// write access of w th vector of new_beta
if (threadIdx.x == 0) {
while (atomicCAS(&locks[w], 0, 1)) {}
}

__syncthreads();
__syncthreads();
for (int l = threadIdx.x; l < num_topics; l += blockDim.x)
new_beta[w * num_topics + l] += _phi[l] * c;
__syncthreads();

// release lock
if (threadIdx.x == 0) locks[w] = 0;
__syncthreads();
}

// compute loss and reset shared mem
// see Eq (15) in https://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf
for (int l = threadIdx.x; l < num_topics; l += blockDim.x) {
if (j + 1 == num_iters) {
if (not _vali) new_beta[w * num_topics + l] += _phi[l];
_phi[l] *= beta[w * num_topics + l];
}
_loss_vec[l] = logf(fmaxf(beta[w * num_topics + l], EPS));
_loss_vec[l] -= logf(fmaxf(_phi[l], EPS));
_loss_vec[l] *= _phi[l];
}
__syncthreads();

// release lock
if (threadIdx.x == 0) mutex[w] = 0;
__syncthreads();

float p = fmaxf(EPS, ReduceSum(_phi, num_topics));
float _loss = ReduceSum(_loss_vec, num_topics) * c;
if (threadIdx.x == 0) {
if (_vali)
vali_losses[blockIdx.x] += logf(p);
if (_vali)
vali_losses[blockIdx.x] += _loss;
else
train_losses[blockIdx.x] += logf(p);
}
train_losses[blockIdx.x] += _loss;
}
__syncthreads();

}
__syncthreads();
}
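
For reference, Eq. (15) cited in the comments is the per-document ELBO from Blei et al. (2003). A sketch of the two terms this kernel accumulates, in that paper's notation ($c_n$ is the count of word $n$, $w_n$ its vocabulary id, $\Psi$ the digamma function):

```latex
% Terms of Eq. (15) accumulated per document:
L \supseteq \sum_n c_n \sum_k \phi_{nk}
      \bigl( \log \beta_{k, w_n} - \log \phi_{nk} \bigr)
    + \sum_k \Bigl( \sum_n c_n \phi_{nk} \Bigr)
      \bigl( \Psi(\gamma_k) - \Psi(\textstyle\sum_j \gamma_j) \bigr)
```

The first term is what `_loss_vec` computes per word above; the second is added at the end of the kernel, where `_new_gamma` (which holds $\sum_n c_n \phi_{nk}$) and `_vali_phi_sum` are scaled by `Elogthetad` before the final reduction.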
@@ -110,9 +135,23 @@ __global__ void EstepKernel(
_gamma[k] = _new_gamma[k] + alpha[k];
__syncthreads();
}

// update gradient of alpha and loss from E[log(theta)]
float gamma_sum = ReduceSum(_gamma, num_topics);
for (int j = threadIdx.x; j < num_topics; j += blockDim.x)
_grad_alpha[j] += (Digamma(_gamma[j]) - Digamma(gamma_sum));
for (int j = threadIdx.x; j < num_topics; j += blockDim.x) {
float Elogthetad = Digamma(_gamma[j]) - Digamma(gamma_sum);
_grad_alpha[j] += Elogthetad;
_new_gamma[j] *= Elogthetad;
_vali_phi_sum[j] *= Elogthetad;
}

// see Eq (15) in https://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf
float train_loss = ReduceSum(_new_gamma, num_topics);
float vali_loss = ReduceSum(_vali_phi_sum, num_topics);
if (threadIdx.x == 0) {
train_losses[blockIdx.x] += train_loss;
vali_losses[blockIdx.x] += vali_loss;
}

__syncthreads();
}
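
To make the E-step above easier to follow end to end, here is a hedged single-document NumPy mirror. The `phi` update line corresponds to a hunk collapsed above, so it is an assumption (the standard variational update `phi ∝ beta[w] * exp(digamma(gamma))`); everything else mirrors lines shown in the diff.

```python
# Single-document E-step sketch (for intuition; not the library's code).
import numpy as np
from scipy.special import digamma

def estep_doc(w, c, alpha, beta, num_iters):
    """w: word ids, c: word counts, alpha: (T,), beta: (num_words, T)."""
    num_topics = alpha.shape[0]
    gamma = alpha + len(w) / num_topics           # kernel: alpha[j] + (end - beg) / num_topics
    for _ in range(num_iters):
        new_gamma = np.zeros(num_topics)
        e_log_theta = digamma(gamma) - digamma(gamma.sum())
        for wi, ci in zip(w, c):
            phi = beta[wi] * np.exp(e_log_theta)  # assumed update (collapsed hunk)
            phi /= phi.sum()                      # kernel: _phi[l] /= phi_sum
            new_gamma += ci * phi                 # kernel: _new_gamma[l] += _phi[l] * c
        gamma = alpha + new_gamma                 # kernel: _gamma[k] = _new_gamma[k] + alpha[k]
    return gamma
```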
10 changes: 6 additions & 4 deletions cpp/include/culda/culda.hpp
@@ -65,8 +65,11 @@ class CuLDA {
void LoadModel(float* alpha, float* beta,
float* grad_alpha, float* new_beta, const int num_words);
std::pair<float, float> FeedData(
const int* indices, const int* indptr, const bool* vali,
const int num_indices, const int num_indptr, const int num_iters);
const int* indices, const int* indptr,
const bool* vali, const float* counts,
float* gamma, const bool init_gamma,
const int num_indices, const int num_indptr,
const int num_iters);
void Pull();
void Push();
int GetBlockCnt();
@@ -78,8 +81,7 @@ class CuLDA {
std::unique_ptr<CuSimLogger> logger_container_;
thrust::device_vector<float> dev_alpha_, dev_beta_;
thrust::device_vector<float> dev_grad_alpha_, dev_new_beta_;
thrust::device_vector<float> dev_gamma_, dev_new_gamma_, dev_phi_;
thrust::device_vector<int> dev_mutex_;
thrust::device_vector<int> dev_locks_;

float *alpha_, *beta_, *grad_alpha_, *new_beta_;
int block_cnt_, block_dim_;
6 changes: 4 additions & 2 deletions cpp/include/cuw2v/cuda_w2v_base_kernels.cuh
@@ -6,14 +6,16 @@
#pragma once
#include "utils/cuda_utils_kernels.cuh"

#define MAX_EXP 20

namespace cusim {


__inline__ __device__
void PositiveFeedback(const float* vec1, float* vec2, float* grad,
float& loss_nume, float& loss_deno, const int num_dims, const float lr) {
static __shared__ float g;
float dot = Dot(vec1, vec2, num_dims);
float dot = fmaxf(-MAX_EXP, fminf(MAX_EXP, Dot(vec1, vec2, num_dims)));
if (threadIdx.x == 0) {
float exp_dot = expf(-dot);
g = exp_dot / (1 + exp_dot) * lr;
@@ -32,7 +34,7 @@ __inline__ __device__
void NegativeFeedback(const float* vec1, float* vec2, float* grad,
float& loss_nume, float& loss_deno, const int num_dims, const float lr) {
static __shared__ float g;
float dot = Dot(vec1, vec2, num_dims);
float dot = fmaxf(-MAX_EXP, fminf(MAX_EXP, Dot(vec1, vec2, num_dims)));
if (threadIdx.x == 0) {
float exp_dot = expf(dot);
g = exp_dot / (1 + exp_dot) * lr;
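
The `MAX_EXP` clamp introduced above bounds the dot product before `expf`. A small NumPy illustration of the failure mode it prevents (a sketch, not library code):

```python
# float32 exp overflows to inf near x ~= 88.7; an unbounded dot product would
# poison the logistic gradient with inf/nan. Clamping to +/-MAX_EXP (20 in the
# kernel above) keeps the sigmoid finite and merely saturated.
import numpy as np

dot = np.float32(120.0)                 # pathological dot product
print(np.exp(dot))                      # inf (with an overflow warning)

dot = np.clip(dot, -20.0, 20.0)         # mirrors fmaxf(-MAX_EXP, fminf(MAX_EXP, .))
print(1.0 / (1.0 + np.exp(-dot)))       # ~1.0, finite
```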
18 changes: 9 additions & 9 deletions cpp/include/cuw2v/cuda_w2v_hs_kernels.cuh
@@ -39,11 +39,11 @@ __global__ void W2VHsSgKernel(
__syncthreads();
int beg2 = max(beg, j - window_size + reduced_windows);
int end2 = min(end, j + window_size - reduced_windows + 1);
float* _emb_in = emb_in + num_dims * cols[j];
for (int k = beg2; k < end2; ++k) {
if (k == j) continue;
int beg3 = hs_indptr[cols[k]];
int end3 = hs_indptr[cols[k] + 1];
float* _emb_in = emb_in + num_dims * cols[k];
int beg3 = hs_indptr[cols[j]];
int end3 = hs_indptr[cols[j] + 1];
for (int l = beg3; l < end3; ++l) {
if (codes[l]) {
PositiveFeedback(_emb_in, emb_out + num_dims * points[l],
@@ -55,7 +55,7 @@ __global__ void W2VHsSgKernel(
__syncthreads();
}
for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
emb_in[num_dims * cols[j] + l] += grad[l];
_emb_in[l] += grad[l];
grad[l] = 0.0f;
}
__syncthreads();
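
The fix in this hunk swaps the two word roles: the input vector now comes from the context word `cols[k]`, while the Huffman path (`hs_indptr`, `codes`, `points`) comes from the center word `cols[j]`. A hedged NumPy paraphrase of one (center, context) pair update, with the `PositiveFeedback`/`NegativeFeedback` math from earlier folded into a single logistic step:

```python
# Paraphrase of the fixed skip-gram HS inner loop (illustration only).
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-np.clip(x, -20, 20)))  # same MAX_EXP clamp

def hs_pair_update(emb_in, emb_out, hs_indptr, points, codes,
                   w_center, w_context, lr=0.025):
    vec_in = emb_in[w_context]            # fixed: context word supplies the input vector
    grad = np.zeros_like(vec_in)
    for l in range(hs_indptr[w_center], hs_indptr[w_center + 1]):
        node = points[l]                  # inner node on the *center* word's Huffman path
        # codes[l] == 1 -> positive feedback (pull), 0 -> negative (push)
        g = lr * (float(codes[l]) - sigmoid(vec_in @ emb_out[node]))
        grad += g * emb_out[node]
        emb_out[node] += g * vec_in
    vec_in += grad                        # applied once per pair, as in the kernel
```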
@@ -70,7 +70,7 @@ __global__ void W2VHsCbowKernel(
const int num_indptr, const int num_dims, const int window_size, default_random_engine* rngs,
float* emb_in, float* emb_out,
float* loss_nume, float* loss_deno,
const bool use_mean, const float lr) {
const bool cbow_mean, const float lr) {

default_random_engine& rng = rngs[blockIdx.x];
float& _loss_nume = loss_nume[blockIdx.x];
@@ -98,15 +98,15 @@
grad[k] = 0.0f;
cbow[k] = 0.0f;
}

// compute cbow
for (int k = beg2; k < end2; ++k) {
if (k == j) continue;
for (int l = threadIdx.x; l < num_dims; l += blockDim.x) {
cbow[l] += emb_in[num_dims * cols[k] + l];
}
}
if (use_mean) {
if (cbow_mean) {
for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
cbow[k] /= (end2 - beg2 - 1);
}
@@ -126,8 +126,8 @@
__syncthreads();
}

// normalize grad if use_mean = true
if (use_mean) {
// normalize grad if cbow_mean = true
if (cbow_mean) {
for (int k = threadIdx.x; k < num_dims; k += blockDim.x) {
grad[k] /= (end2 - beg2 - 1);
}
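
The `use_mean` → `cbow_mean` rename above matches gensim's name for the same option. What it controls, as a small sketch mirroring the kernel's context aggregation (illustration only):

```python
# cbow_mean=True averages the context vectors (and the gradient) instead of summing.
import numpy as np

def cbow_context(emb_in, context_ids, cbow_mean=True):
    cbow = emb_in[context_ids].sum(axis=0)    # kernel: cbow[l] += emb_in[num_dims * cols[k] + l]
    if cbow_mean:
        cbow /= len(context_ids)              # kernel: cbow[k] /= (end2 - beg2 - 1)
    return cbow
```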