Diffstat (limited to 'src/models/transformer.h')
-rw-r--r--   src/models/transformer.h   69
1 file changed, 35 insertions(+), 34 deletions(-)
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 7ec40dc5..af877600 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -5,6 +5,7 @@
 
 #include "marian.h"
 
+#include "common/hash.h"
 #include "layers/constructors.h"
 #include "models/decoder.h"
 #include "models/encoder.h"
@@ -28,7 +29,7 @@ class Transformer : public EncoderOrDecoderBase {
 protected:
   using Base::options_; using Base::inference_; using Base::batchIndex_; using Base::graph_;
-  std::unordered_map<std::string, Expr> cache_;  // caching transformation of the encoder that should not be created again
+  std::unordered_map<std::pair<std::string, Shape>, Expr> cache_;  // caching transformation of the encoder that should not be created again
   mutable/*lazy*/ std::vector<float> sinusoidalEmbeddingsFreq_, sinusoidalEmbeddingsOffs_; // cached contributions to sinusoidal embeddings
 
   bool depthScaling_{false}; // As recommended in the GPT-2 paper, down-scale layer weights by a factor of 1 / sqrt(depth);
@@ -40,16 +41,16 @@ protected:
   std::vector<Expr> alignments_; // [max tgt len or 1][beam depth, max src length, batch size, 1]
 
   // @TODO: make this go away
-  template <typename T> 
-  T opt(const char* const key) const { Ptr<Options> options = options_; return options->get<T>(key); } 
+  template <typename T>
+  T opt(const char* const key) const { Ptr<Options> options = options_; return options->get<T>(key); }
 
-  template <typename T> 
-  T opt(const std::string& key) const { return opt<T>(key.c_str()); } 
+  template <typename T>
+  T opt(const std::string& key) const { return opt<T>(key.c_str()); }
 
-  template <typename T> 
+  template <typename T>
   T opt(const char* const key, const T& def) const { Ptr<Options> options = options_; return options->get<T>(key, def); }
 
-  template <typename T> 
+  template <typename T>
   T opt(const std::string& key, const T& def) const { opt<T>(key.c_str(), def); }
 
 public:
@@ -256,7 +257,7 @@ public:
 
     // take softmax along src sequence axis (-1)
     auto weights = softmax(z); // [-4: beam depth * batch size, -3: num heads, -2: max tgt length, -1: max src length]
-    
+
     if(saveAttentionWeights)
       collectOneHead(weights, dimBeam);
 
@@ -289,26 +290,26 @@ public:
     // Caching transformation of the encoder that should not be created again.
     // @TODO: set this automatically by memoizing encoder context and
     // memoization propagation (short-term)
-    if (cache  // if caching
-        && cache_.count(prefix + "_keys") > 0  // and the keys expression has been seen
-        && cache_[prefix + "_keys"]->shape().elements() == keys->shape().elements()) { // and the underlying element size did not change
-      kh = cache_[prefix + "_keys"];  // then return cached tensor
-    }
-    else {
+    std::pair<std::unordered_map<std::pair<std::string, Shape>, Expr>::iterator, bool> cache_result;
+    if (cache
+        && !((cache_result = cache_.insert(std::pair<std::pair<std::string, Shape>, Expr>({prefix + "_keys", keys->shape()}, kh))).second)
+       ) {
+      kh = cache_result.first->second;
+    } else {
       int dimKeys = keys->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
       auto Wk = graph_->param(prefix + "_Wk", {dimKeys, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
       auto bk = graph_->param(prefix + "_bk", {1, dimModel}, inits::zeros());
 
      kh = affine(keys, Wk, bk);     // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
      kh = SplitHeads(kh, dimHeads); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
-      cache_[prefix + "_keys"] = kh;
+      if (cache) cache_result.first->second = kh;
     }
 
     Expr vh;
-    if (cache 
-        && cache_.count(prefix + "_values") > 0 
-        && cache_[prefix + "_values"]->shape().elements() == values->shape().elements()) {
-      vh = cache_[prefix + "_values"];
+    if (cache
+        && !((cache_result = cache_.insert(std::pair<std::pair<std::string, Shape>, Expr>({prefix + "_values", values->shape()}, vh))).second)
+       ) {
+      vh = cache_result.first->second;
     } else {
       int dimValues = values->shape()[-1]; // different than dimModel when using lemma and factors combined with concatenation
       auto Wv = graph_->param(prefix + "_Wv", {dimValues, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
@@ -316,7 +317,7 @@ public:
       vh = affine(values, Wv, bv); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
       vh = SplitHeads(vh, dimHeads);
-      cache_[prefix + "_values"] = vh;
+      if (cache) cache_result.first->second = vh;
     }
 
     int dimBeam = q->shape()[-4];
@@ -377,7 +378,7 @@ public:
 
     // multi-head self-attention over previous input
     output = MultiHead(prefix, dimModel, dimHeads, output, keys, values, mask, cache, saveAttentionWeights);
-    
+
     auto opsPost = opt<std::string>("transformer-postprocess");
     output = postProcess(prefix + "_Wo", opsPost, output, input, dropProb);
 
@@ -558,7 +559,7 @@ public:
     auto embeddingLayer = getEmbeddingLayer(opt<bool>("ulr", false));
     std::tie(batchEmbeddings, batchMask) = embeddingLayer->apply((*batch)[batchIndex_]);
     batchEmbeddings = addSpecialEmbeddings(batchEmbeddings, /*start=*/0, batch);
-    
+
     // reorganize batch and timestep
     batchEmbeddings = atleast_nd(batchEmbeddings, 4); // [beam depth=1, max length, batch size, vector dim]
     batchMask = atleast_nd(batchMask, 4);             // [beam depth=1, max length, batch size, vector dim=1]
@@ -593,7 +594,7 @@ public:
     }
 
     // this allows to run a final layernorm operation after going through the transformer layer stack.
-    // By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da) 
+    // By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da)
     // it is recommended to normalize here. Can also be used to add a skip connection from the very bottom if requested.
     auto opsTop = opt<std::string>("transformer-postprocess-top", "");
     layer = postProcess(prefix_ + "_top", opsTop, layer, prevLayer, dropProb);
@@ -622,14 +623,14 @@ public:
                                    int beamSize) const override {
     // @TODO: code duplication with DecoderState only because of isBatchMajor=true, should rather be a contructor argument of DecoderState?
-    
+
     std::vector<Ptr<EncoderState>> newEncStates;
-    for(auto& es : encStates_) 
-      // If the size of the batch dimension of the encoder state context changed, subselect the correct batch entries 
+    for(auto& es : encStates_)
+      // If the size of the batch dimension of the encoder state context changed, subselect the correct batch entries
       newEncStates.push_back(es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));
 
     // Create hypothesis-selected state based on current state and hyp indices
-    auto selectedState = New<TransformerState>(states_.select(hypIndices, beamSize, /*isBatchMajor=*/true), logProbs_, newEncStates, batch_); 
+    auto selectedState = New<TransformerState>(states_.select(hypIndices, beamSize, /*isBatchMajor=*/true), logProbs_, newEncStates, batch_);
 
     // Set the same target token position as the current state
     // @TODO: This is the same as in base function.
@@ -763,8 +764,8 @@ public:
     // This would happen if something goes wrong during batch pruning.
     ABORT_IF(encoderContext->shape()[-3] != dimBatch,
-             "Context and query batch dimension do not match {} != {}", 
-             encoderContext->shape()[-3], 
+             "Context and query batch dimension do not match {} != {}",
+             encoderContext->shape()[-3],
              dimBatch);
 
     // LayerAttention expects mask in a different layout
@@ -871,7 +872,7 @@ public:
     }
 
     // This allows to run a final layernorm operation after going through the transformer layer stack.
-    // By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da) 
+    // By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da)
     // it is recommended to normalize here. Can also be used to add a skip connection from the very bottom if requested.
     auto opsTop = opt<std::string>("transformer-postprocess-top", "");
     query = postProcess(prefix_ + "_top", opsTop, query, prevQuery, dropProb);
@@ -884,7 +885,7 @@ public:
     if(shortlist_)
       output_->setShortlist(shortlist_);
     auto logits = output_->applyAsLogits(decoderContext); // [-4: beam depth=1, -3: max length, -2: batch size, -1: vocab or shortlist dim]
-    
+
     // return unormalized(!) probabilities
     Ptr<DecoderState> nextState;
     if (opt<std::string>("transformer-decoder-autoreg", "self-attention") == "rnn") {
@@ -909,9 +910,9 @@ public:
     output_->clear();
     cache_.clear();
     alignments_.clear();
-    perLayerRnn_.clear(); // this needs to be cleared between batches. 
-    // @TODO: figure out how to detect stale nodes i.e. nodes that are referenced, 
-    // but where underlying memory has been deallocated by dropping all tensors 
+    perLayerRnn_.clear(); // this needs to be cleared between batches.
+    // @TODO: figure out how to detect stale nodes i.e. nodes that are referenced,
+    // but where underlying memory has been deallocated by dropping all tensors
     // from a TensorAllocator object. This can happen during ExpressionGraph::clear()
   }
 };
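
Two aspects of the cache_ change are worth spelling out. Keying the map by a (name, Shape) pair instead of the name alone makes the old elements()-count comparison unnecessary: a shape change now simply produces a different key, hence a cache miss. And since std::unordered_map has no built-in hash for std::pair, the newly included common/hash.h presumably supplies one (which is why cache_ can be declared without an explicit hasher argument). The sketch below shows the usual hash-combine construction in a self-contained form; hashCombine and PairHash are illustrative names, not Marian's actual API, and an int stands in for the Shape type.

    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <utility>

    // Boost-style hash combiner: folds one value's hash into a running seed.
    template <class T>
    void hashCombine(std::size_t& seed, const T& v) {
      seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    }

    // Hasher for pair keys; a stand-in for what common/hash.h presumably enables.
    struct PairHash {
      template <class A, class B>
      std::size_t operator()(const std::pair<A, B>& p) const {
        std::size_t seed = 0;
        hashCombine(seed, p.first);
        hashCombine(seed, p.second);
        return seed;
      }
    };

    int main() {
      // An int stands in for Shape; the real map is unordered_map<pair<string, Shape>, Expr>.
      std::unordered_map<std::pair<std::string, int>, float, PairHash> cache;
      cache[{"encoder_l1_self_keys", 512}]  = 1.0f; // same name, different shape...
      cache[{"encoder_l1_self_keys", 1024}] = 2.0f; // ...yields a distinct entry
      std::cout << cache.size() << "\n";            // prints 2
    }

(Passing PairHash explicitly avoids specializing std::hash for a pair of standard types, which would be undefined behavior in this toy; the real code can legally specialize std::hash because Shape is a user-defined type.)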
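The rewritten lookup also replaces the former count() plus operator[] double probe with a single insert() call: insert() returns a pair<iterator, bool>, so a false .second means the key was already cached and the iterator points at the stored expression, while a true .second means a placeholder slot was just inserted and is overwritten once the projection has been computed. Short-circuit evaluation ensures no insertion happens when caching is disabled, which is why the later write-back is guarded by if (cache). A minimal sketch of that idiom, with a hypothetical getOrCompute and an int-valued cache standing in for the Expr machinery:

    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <utility>

    // A minimal cache standing in for Transformer::cache_.
    static std::unordered_map<std::string, int> cache;

    // Mirrors the get-or-compute idiom from the diff: one insert() call
    // performs both the membership test and the slot reservation.
    int getOrCompute(const std::string& key, bool useCache) {
      std::pair<std::unordered_map<std::string, int>::iterator, bool> r;
      int value = 0; // placeholder inserted next to the key, like the empty Expr kh/vh
      if (useCache && !(r = cache.insert({key, value})).second) {
        value = r.first->second;              // key already present: reuse cached value
      } else {
        value = static_cast<int>(key.size()); // stand-in for the expensive affine/SplitHeads work
        if (useCache)
          r.first->second = value;            // fill the placeholder slot inserted above
      }
      return value;
    }

    int main() {
      std::cout << getOrCompute("encoder_l1_self_keys", true) << "\n"; // computed
      std::cout << getOrCompute("encoder_l1_self_keys", true) << "\n"; // served from cache
    }

Compared with the old count()/operator[] sequence, this hashes the key once instead of two or three times per call, at the cost of constructing the key pair up front even on a hit.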