Check shape on transformer cache
graemenail committed Aug 24, 2022
1 parent e88c1aa commit 33fd8d7
Showing 1 changed file with 11 additions and 11 deletions.
src/models/transformer.h (22 changes: 11 additions & 11 deletions)

@@ -28,7 +28,7 @@ class Transformer : public EncoderOrDecoderBase {

 protected:
   using Base::options_; using Base::inference_; using Base::batchIndex_; using Base::graph_;
-  std::unordered_map<std::string, Expr> cache_; // caching transformation of the encoder that should not be created again
+  std::unordered_map<std::string, std::pair<Shape, Expr>> cache_; // caching transformation of the encoder that should not be created again
   mutable/*lazy*/ std::vector<float> sinusoidalEmbeddingsFreq_, sinusoidalEmbeddingsOffs_; // cached contributions to sinusoidal embeddings

   // attention weights produced by step()
@@ -279,32 +279,32 @@ class Transformer : public EncoderOrDecoderBase {
    // Caching transformation of the encoder that should not be created again.
    // @TODO: set this automatically by memoizing encoder context and
    //        memoization propagation (short-term)
-    if (cache                                                                          // if caching
-        && cache_.count(prefix + "_keys") > 0                                          // and the keys expression has been seen
-        && cache_[prefix + "_keys"]->shape().elements() == keys->shape().elements()) { // and the underlying element size did not change
-      kh = cache_[prefix + "_keys"];                                                   // then return cached tensor
+    if (cache                                                 // if caching
+        && cache_.count(prefix + "_keys") > 0                 // and the keys expression has been seen
+        && cache_[prefix + "_keys"].first == keys->shape()) { // and the cached shape matches
+      kh = cache_[prefix + "_keys"].second;                   // then return cached tensor
    }
    else {
      auto Wk = graph_->param(prefix + "_Wk", {dimModel, dimModel}, inits::glorotUniform());
      auto bk = graph_->param(prefix + "_bk", {1, dimModel}, inits::zeros());

      kh = affine(keys, Wk, bk);     // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
      kh = SplitHeads(kh, dimHeads); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
-      cache_[prefix + "_keys"] = kh;
+      cache_[prefix + "_keys"] = std::make_pair(keys->shape(), kh);
    }

    Expr vh;
-    if (cache
-        && cache_.count(prefix + "_values") > 0
-        && cache_[prefix + "_values"]->shape().elements() == values->shape().elements()) {
-      vh = cache_[prefix + "_values"];
+    if (cache
+        && cache_.count(prefix + "_values") > 0
+        && cache_[prefix + "_values"].first == values->shape()) {
+      vh = cache_[prefix + "_values"].second;
    } else {
      auto Wv = graph_->param(prefix + "_Wv", {dimModel, dimModel}, inits::glorotUniform());
      auto bv = graph_->param(prefix + "_bv", {1, dimModel}, inits::zeros());

      vh = affine(values, Wv, bv);
      vh = SplitHeads(vh, dimHeads); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
-      cache_[prefix + "_values"] = vh;
+      cache_[prefix + "_values"] = std::make_pair(values->shape(), vh);
    }

    int dimBeam = q->shape()[-4];
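
The motivation for the stronger check: two batches can contain the same number of elements while having different shapes (for example 2×3 versus 3×2), in which case the old elements() comparison would report a false cache hit and silently reuse a stale key/value projection. Below is a minimal standalone sketch of that failure mode; the Shape struct, Expr alias, and "enc_keys" entry are hypothetical stand-ins for illustration, not Marian's actual types.

#include <cassert>
#include <functional>
#include <numeric>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for Marian's Shape: a list of dimensions.
struct Shape {
  std::vector<int> dims;
  int elements() const {
    return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int>());
  }
  bool operator==(const Shape& other) const { return dims == other.dims; }
};

int main() {
  using Expr = std::string; // stand-in for a cached expression handle
  std::unordered_map<std::string, std::pair<Shape, Expr>> cache;

  Shape cached{{2, 3}};   // shape previously seen, e.g. 2 sentences of length 3
  Shape incoming{{3, 2}}; // new batch, e.g. 3 sentences of length 2

  cache["enc_keys"] = std::make_pair(cached, Expr("kh"));

  // Old check: element counts match (6 == 6), so the stale projection
  // would have been reused despite the layout change.
  assert(cached.elements() == incoming.elements());

  // New check: comparing full shapes correctly invalidates the entry.
  assert(!(cache["enc_keys"].first == incoming));

  return 0;
}

Storing the shape alongside the cached expression, as this commit does, makes the validity test exact at the cost of keeping one Shape copy per cache entry.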
