diff options
author | Marcin Junczys-Dowmunt <marcinjd@microsoft.com> | 2021-03-18 06:41:24 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <marcinjd@microsoft.com> | 2021-03-18 06:41:24 +0300 |
commit | 272096c1d188dcd0ec33ba349bab5955c497876a (patch) | |
tree | 93adf1d89b1900a3d017acf8b5fc48c6f518ccf5 /src | |
parent | 77c3e356a47113f661dda794b815e84561ca93f5 (diff) | |
parent | 8f73923d3134f4799497b7e880963336b8fe4d6b (diff) |
sync public and internal master
Diffstat (limited to 'src')
82 files changed, 2844 insertions, 1166 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b47663b4..64b86a69 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -40,6 +40,7 @@ set(MARIAN_SOURCES data/corpus_sqlite.cpp data/corpus_nbest.cpp data/text_input.cpp + data/shortlist.cpp 3rd_party/cnpy/cnpy.cpp 3rd_party/ExceptionWithCallStack.cpp @@ -72,6 +73,9 @@ set(MARIAN_SOURCES layers/loss.cpp layers/weight.cpp layers/lsh.cpp + layers/embedding.cpp + layers/output.cpp + layers/logits.cpp rnn/cells.cpp rnn/attention.cpp @@ -84,6 +88,7 @@ set(MARIAN_SOURCES models/model_factory.cpp models/encoder_decoder.cpp models/transformer_stub.cpp + models/costs.cpp rescorer/score_collector.cpp embedder/vector_collector.cpp @@ -103,10 +108,15 @@ set(MARIAN_SOURCES training/validator.cpp training/communicator.cpp - # this is only compiled to catch build errors, but not linked + # this is only compiled to catch build errors microsoft/quicksand.cpp microsoft/cosmos.cpp + # copied from quicksand to be able to read binary shortlist + microsoft/shortlist/utils/Converter.cpp + microsoft/shortlist/utils/StringUtils.cpp + microsoft/shortlist/utils/ParameterTree.cpp + $<TARGET_OBJECTS:libyaml-cpp> $<TARGET_OBJECTS:SQLiteCpp> $<TARGET_OBJECTS:pathie-cpp> diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index 3baa13ea..3baa13ea 100755..100644 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp diff --git a/src/common/definitions.h b/src/common/definitions.h index d2cf8aa4..d2cf8aa4 100755..100644 --- a/src/common/definitions.h +++ b/src/common/definitions.h diff --git a/src/common/file_stream.cpp b/src/common/file_stream.cpp index 78cbb12f..78cbb12f 100755..100644 --- a/src/common/file_stream.cpp +++ b/src/common/file_stream.cpp diff --git a/src/common/io_item.h b/src/common/io_item.h index d86c01ac..d86c01ac 100755..100644 --- a/src/common/io_item.h +++ b/src/common/io_item.h diff --git a/src/common/options.h b/src/common/options.h index 08c6a3ca..08c6a3ca 100755..100644 --- a/src/common/options.h +++ b/src/common/options.h diff --git a/src/common/timer.cpp b/src/common/timer.cpp new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/common/timer.cpp diff --git a/src/common/utils.cpp b/src/common/utils.cpp index 72624041..72624041 100755..100644 --- a/src/common/utils.cpp +++ b/src/common/utils.cpp diff --git a/src/data/batch.h b/src/data/batch.h index 3c592b31..3c592b31 100755..100644 --- a/src/data/batch.h +++ b/src/data/batch.h diff --git a/src/data/corpus.cpp b/src/data/corpus.cpp index e8ce850b..e8ce850b 100755..100644 --- a/src/data/corpus.cpp +++ b/src/data/corpus.cpp diff --git a/src/data/corpus_base.cpp b/src/data/corpus_base.cpp index 5be4298b..5be4298b 100755..100644 --- a/src/data/corpus_base.cpp +++ b/src/data/corpus_base.cpp diff --git a/src/data/factored_vocab.cpp b/src/data/factored_vocab.cpp index 818f3788..17a5bfb7 100755..100644 --- a/src/data/factored_vocab.cpp +++ b/src/data/factored_vocab.cpp @@ -546,7 +546,6 @@ void FactoredVocab::constructNormalizationInfoForVocab() { /*virtual*/ void FactoredVocab::transcodeToShortlistInPlace(WordIndex* ptr, size_t num) const { for (; num-- > 0; ptr++) { auto word = Word::fromWordIndex(*ptr); - auto wordString = word2string(word); auto lemmaIndex = getFactor(word, 0) + groupRanges_[0].first; *ptr = (WordIndex)lemmaIndex; } diff --git a/src/data/factored_vocab.h b/src/data/factored_vocab.h index 215e92f0..215e92f0 100755..100644 --- a/src/data/factored_vocab.h +++ b/src/data/factored_vocab.h diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp new file mode 100644 index 00000000..6f551262 --- /dev/null +++ b/src/data/shortlist.cpp @@ -0,0 +1,153 @@ +#include "data/shortlist.h" +#include "microsoft/shortlist/utils/ParameterTree.h" + +namespace marian { +namespace data { + +// cast current void pointer to T pointer and move forward by num elements +template <typename T> +const T* get(const void*& current, size_t num = 1) { + const T* ptr = (const T*)current; + current = (const T*)current + num; + return ptr; +} + +QuicksandShortlistGenerator::QuicksandShortlistGenerator(Ptr<Options> options, + Ptr<const Vocab> srcVocab, + Ptr<const Vocab> trgVocab, + size_t srcIdx, + size_t /*trgIdx*/, + bool /*shared*/) + : options_(options), + srcVocab_(srcVocab), + trgVocab_(trgVocab), + srcIdx_(srcIdx) { + std::vector<std::string> vals = options_->get<std::vector<std::string>>("shortlist"); + + ABORT_IF(vals.empty(), "No path to filter path given"); + std::string fname = vals[0]; + + auto firstNum = vals.size() > 1 ? std::stoi(vals[1]) : 0; + auto bestNum = vals.size() > 2 ? std::stoi(vals[2]) : 0; + float threshold = vals.size() > 3 ? std::stof(vals[3]) : 0; + + if(firstNum != 0 || bestNum != 0 || threshold != 0) { + LOG(warn, "You have provided additional parameters for the Quicksand shortlist, but they are ignored."); + } + + mmap_ = mio::mmap_source(fname); // memory-map the binary file once + const void* current = mmap_.data(); // pointer iterator over binary file + + // compare magic number in binary file to make sure we are reading the right thing + const int32_t MAGIC_NUMBER = 1234567890; + int32_t header_magic_number = *get<int32_t>(current); + ABORT_IF(header_magic_number != MAGIC_NUMBER, "Trying to mmap Quicksand shortlist but encountered wrong magic number"); + + auto config = ::quicksand::ParameterTree::FromBinaryReader(current); + use16bit_ = config->GetBoolReq("use_16_bit"); + + LOG(info, "[data] Mapping Quicksand shortlist from {}", fname); + + idSize_ = sizeof(int32_t); + if (use16bit_) { + idSize_ = sizeof(uint16_t); + } + + // mmap the binary shortlist pieces + numDefaultIds_ = *get<int32_t>(current); + defaultIds_ = get<int32_t>(current, numDefaultIds_); + numSourceIds_ = *get<int32_t>(current); + sourceLengths_ = get<int32_t>(current, numSourceIds_); + sourceOffsets_ = get<int32_t>(current, numSourceIds_); + numShortlistIds_ = *get<int32_t>(current); + sourceToShortlistIds_ = get<uint8_t>(current, idSize_ * numShortlistIds_); + + // display parameters + LOG(info, + "[data] Quicksand shortlist has {} source ids, {} default ids and {} shortlist ids", + numSourceIds_, + numDefaultIds_, + numShortlistIds_); +} + +Ptr<Shortlist> QuicksandShortlistGenerator::generate(Ptr<data::CorpusBatch> batch) const { + auto srcBatch = (*batch)[srcIdx_]; + auto maxShortlistSize = trgVocab_->size(); + + std::unordered_set<int32_t> indexSet; + for(int32_t i = 0; i < numDefaultIds_ && i < maxShortlistSize; ++i) { + int32_t id = defaultIds_[i]; + indexSet.insert(id); + } + + // State + std::vector<std::pair<const uint8_t*, int32_t>> curShortlists(maxShortlistSize); + auto curShortlistIt = curShortlists.begin(); + + // Because we might fill up our shortlist before reaching max_shortlist_size, we fill the shortlist in order of rank. + // E.g., first rank of word 0, first rank of word 1, ... second rank of word 0, ... + int32_t maxLength = 0; + for (Word word : srcBatch->data()) { + int32_t sourceId = (int32_t)word.toWordIndex(); + srcVocab_->transcodeToShortlistInPlace((WordIndex*)&sourceId, 1); + + if (sourceId < numSourceIds_) { // if it's a valid source id + const uint8_t* curShortlistIds = sourceToShortlistIds_ + idSize_ * sourceOffsets_[sourceId]; // start position for mapping + int32_t length = sourceLengths_[sourceId]; // how many mappings are there + curShortlistIt->first = curShortlistIds; + curShortlistIt->second = length; + curShortlistIt++; + + if (length > maxLength) + maxLength = length; + } + } + + // collect the actual shortlist mappings + for (int32_t i = 0; i < maxLength && indexSet.size() < maxShortlistSize; i++) { + for (int32_t j = 0; j < curShortlists.size() && indexSet.size() < maxShortlistSize; j++) { + int32_t length = curShortlists[j].second; + if (i < length) { + const uint8_t* source_shortlist_ids_bytes = curShortlists[j].first; + int32_t id = 0; + if (use16bit_) { + const uint16_t* source_shortlist_ids = reinterpret_cast<const uint16_t*>(source_shortlist_ids_bytes); + id = (int32_t)source_shortlist_ids[i]; + } + else { + const int32_t* source_shortlist_ids = reinterpret_cast<const int32_t*>(source_shortlist_ids_bytes); + id = source_shortlist_ids[i]; + } + indexSet.insert(id); + } + } + } + + // turn into vector and sort (selected indices) + std::vector<WordIndex> indices; + indices.reserve(indexSet.size()); + for(auto i : indexSet) + indices.push_back((WordIndex)i); + + std::sort(indices.begin(), indices.end()); + return New<Shortlist>(indices); +} + +Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options, + Ptr<const Vocab> srcVocab, + Ptr<const Vocab> trgVocab, + size_t srcIdx, + size_t trgIdx, + bool shared) { + std::vector<std::string> vals = options->get<std::vector<std::string>>("shortlist"); + ABORT_IF(vals.empty(), "No path to shortlist given"); + std::string fname = vals[0]; + if(filesystem::Path(fname).extension().string() == ".bin") { + return New<QuicksandShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared); + } else { + return New<LexicalShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared); + } +} + +} // namespace data +} // namespace marian diff --git a/src/data/shortlist.h b/src/data/shortlist.h index 395bcfee..ab6a087b 100644 --- a/src/data/shortlist.h +++ b/src/data/shortlist.h @@ -5,6 +5,7 @@ #include "common/file_stream.h" #include "data/corpus_base.h" #include "data/types.h" +#include "mio/mio.hpp" #include <random> #include <unordered_map> @@ -292,5 +293,51 @@ public: } }; +/* +Legacy binary shortlist for Microsoft-internal use. +*/ +class QuicksandShortlistGenerator : public ShortlistGenerator { +private: + Ptr<Options> options_; + Ptr<const Vocab> srcVocab_; + Ptr<const Vocab> trgVocab_; + + size_t srcIdx_; + + mio::mmap_source mmap_; + + // all the quicksand bits go here + bool use16bit_{false}; + int32_t numDefaultIds_; + int32_t idSize_; + const int32_t* defaultIds_{nullptr}; + int32_t numSourceIds_{0}; + const int32_t* sourceLengths_{nullptr}; + const int32_t* sourceOffsets_{nullptr}; + int32_t numShortlistIds_{0}; + const uint8_t* sourceToShortlistIds_{nullptr}; + +public: + QuicksandShortlistGenerator(Ptr<Options> options, + Ptr<const Vocab> srcVocab, + Ptr<const Vocab> trgVocab, + size_t srcIdx = 0, + size_t trgIdx = 1, + bool shared = false); + + virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override; +}; + +/* +Shortlist factory to create correct type of shortlist. Currently assumes everything is a text shortlist +unless the extension is *.bin for which the Microsoft legacy binary shortlist is used. +*/ +Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options, + Ptr<const Vocab> srcVocab, + Ptr<const Vocab> trgVocab, + size_t srcIdx = 0, + size_t trgIdx = 1, + bool shared = false); + } // namespace data } // namespace marian diff --git a/src/data/vocab.cpp b/src/data/vocab.cpp index 07ac479e..07ac479e 100755..100644 --- a/src/data/vocab.cpp +++ b/src/data/vocab.cpp diff --git a/src/data/vocab.h b/src/data/vocab.h index 9a40ba16..9a40ba16 100755..100644 --- a/src/data/vocab.h +++ b/src/data/vocab.h diff --git a/src/data/vocab_base.h b/src/data/vocab_base.h index 8c214c97..8c214c97 100755..100644 --- a/src/data/vocab_base.h +++ b/src/data/vocab_base.h diff --git a/src/functional/operators.h b/src/functional/operators.h index a14f153f..a14f153f 100755..100644 --- a/src/functional/operators.h +++ b/src/functional/operators.h diff --git a/src/functional/shape.h b/src/functional/shape.h index fd354e1e..fd354e1e 100755..100644 --- a/src/functional/shape.h +++ b/src/functional/shape.h diff --git a/src/functional/tensor.h b/src/functional/tensor.h index f5549c60..f5549c60 100755..100644 --- a/src/functional/tensor.h +++ b/src/functional/tensor.h diff --git a/src/functional/tmp.h b/src/functional/tmp.h index a83c0ff4..a83c0ff4 100755..100644 --- a/src/functional/tmp.h +++ b/src/functional/tmp.h diff --git a/src/graph/auto_tuner.h b/src/graph/auto_tuner.h index 01f33085..01f33085 100755..100644 --- a/src/graph/auto_tuner.h +++ b/src/graph/auto_tuner.h diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h index ca0739e4..ca0739e4 100755..100644 --- a/src/graph/expression_operators.h +++ b/src/graph/expression_operators.h diff --git a/src/graph/node.cpp b/src/graph/node.cpp index 257a639f..257a639f 100755..100644 --- a/src/graph/node.cpp +++ b/src/graph/node.cpp diff --git a/src/graph/node_initializers.cpp b/src/graph/node_initializers.cpp index 4e39d1bf..4e39d1bf 100755..100644 --- a/src/graph/node_initializers.cpp +++ b/src/graph/node_initializers.cpp diff --git a/src/graph/node_initializers.h b/src/graph/node_initializers.h index 7cdb4183..7cdb4183 100755..100644 --- a/src/graph/node_initializers.h +++ b/src/graph/node_initializers.h diff --git a/src/layers/constructors.h b/src/layers/constructors.h index a2c38197..9e9de207 100755..100644 --- a/src/layers/constructors.h +++ b/src/layers/constructors.h @@ -1,7 +1,9 @@ #pragma once +#include "layers/embedding.h" #include "layers/factory.h" #include "layers/generic.h" +#include "layers/output.h" namespace marian { namespace mlp { @@ -43,6 +45,7 @@ struct LogitLayerFactory : public Factory { // @TODO: In the long run, I hope we can get rid of the abstract factories altogether. class OutputFactory : public LogitLayerFactory { using LogitLayerFactory::LogitLayerFactory; + protected: std::string tiedTransposedName_; Ptr<data::Shortlist> shortlist_; @@ -53,9 +56,7 @@ public: return Accumulator<OutputFactory>(*this); } - void setShortlist(Ptr<data::Shortlist> shortlist) { - shortlist_ = shortlist; - } + void setShortlist(Ptr<data::Shortlist> shortlist) { shortlist_ = shortlist; } Ptr<IUnaryLogitLayer> construct(Ptr<ExpressionGraph> graph) override { auto output = New<Output>(graph, options_); @@ -87,8 +88,7 @@ protected: std::vector<Ptr<IUnaryLayer>> layers_; public: - MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options) - : graph_(graph), options_(options) {} + MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {} Expr apply(const std::vector<Expr>& av) override { Expr output; @@ -104,46 +104,53 @@ public: } Logits applyAsLogits(const std::vector<Expr>& av) override { - // same as apply() except for the last layer, we invoke applyAsLogits(), which has a different return type + // same as apply() except for the last layer, we invoke applyAsLogits(), which has a different + // return type auto lastLayer = std::dynamic_pointer_cast<IUnaryLogitLayer>(layers_.back()); - ABORT_IF(!lastLayer, "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer"); - if (layers_.size() == 1) { - if (av.size() == 1) + ABORT_IF( + !lastLayer, + "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer"); + if(layers_.size() == 1) { + if(av.size() == 1) return lastLayer->applyAsLogits(av[0]); else return lastLayer->applyAsLogits(av); - } - else { + } else { Expr output; - if (av.size() == 1) + if(av.size() == 1) output = layers_[0]->apply(av[0]); else output = layers_[0]->apply(av); - for (size_t i = 1; i < layers_.size() - 1; ++i) + for(size_t i = 1; i < layers_.size() - 1; ++i) output = layers_[i]->apply(output); return lastLayer->applyAsLogits(output); } } - Expr apply(Expr e) override { return apply(std::vector<Expr>{ e }); } - Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{ e }); } + Expr apply(Expr e) override { return apply(std::vector<Expr>{e}); } + Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{e}); } void push_back(Ptr<IUnaryLayer> layer) { layers_.push_back(layer); } void push_back(Ptr<IUnaryLogitLayer> layer) { layers_.push_back(layer); } void setShortlist(Ptr<data::Shortlist> shortlist) override final { auto p = tryAsHasShortlist(); - ABORT_IF(!p, "setShortlist() called on an MLP with an output layer that does not support short lists"); + ABORT_IF( + !p, + "setShortlist() called on an MLP with an output layer that does not support short lists"); p->setShortlist(shortlist); } void clear() override final { auto p = tryAsHasShortlist(); - if (p) + if(p) p->clear(); } + private: - Ptr<IHasShortList> tryAsHasShortlist() const { return std::dynamic_pointer_cast<IHasShortList>(layers_.back()); } + Ptr<IHasShortList> tryAsHasShortlist() const { + return std::dynamic_pointer_cast<IHasShortList>(layers_.back()); + } }; /** @@ -152,6 +159,7 @@ private: */ class MLPFactory : public Factory { using Factory::Factory; + private: std::vector<Ptr<LayerFactory>> layers_; @@ -175,23 +183,27 @@ public: // which will go away if we get rid of the abstract factories, and instead just construct // all layers immediately, which is my long-term goal for Marian. private: - template<class WrappedFactory> + template <class WrappedFactory> class AsLayerFactory : public LayerFactory { - WrappedFactory us; + WrappedFactory us; + public: - AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {} - Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final { - auto p = std::static_pointer_cast<IUnaryLayer>(us.construct(graph)); - ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one"); - return p; - } + AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {} + Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final { + auto p = std::static_pointer_cast<IUnaryLayer>(us.construct(graph)); + ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one"); + return p; + } }; - template<class WrappedFactory> - static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) { return wrapped; } + template <class WrappedFactory> + static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) { + return wrapped; + } + public: Accumulator<MLPFactory> push_back(const Accumulator<OutputFactory>& lf) { push_back(AsLayerFactory<OutputFactory>(lf)); - //layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf))); + // layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf))); return Accumulator<MLPFactory>(*this); } }; diff --git a/src/layers/embedding.cpp b/src/layers/embedding.cpp new file mode 100644 index 00000000..92c4ad6d --- /dev/null +++ b/src/layers/embedding.cpp @@ -0,0 +1,194 @@ +#include "embedding.h" +#include "data/factored_vocab.h" + +namespace marian { + +Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options) + : LayerBase(graph, options), inference_(opt<bool>("inference")) { + std::string name = opt<std::string>("prefix"); + int dimVoc = opt<int>("dimVocab"); + int dimEmb = opt<int>("dimEmb"); + + bool fixed = opt<bool>("fixed", false); + + factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", "")); + if(factoredVocab_) { + dimVoc = (int)factoredVocab_->factorVocabSize(); + LOG_ONCE(info, "[embedding] Factored embeddings enabled"); + } + + // Embedding layer initialization should depend only on embedding size, hence fanIn=false + auto initFunc = inits::glorotUniform( + /*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length + + if(options_->has("embFile")) { + std::string file = opt<std::string>("embFile"); + if(!file.empty()) { + bool norm = opt<bool>("normalization", false); + initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm); + } + } + + E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed); +} + +// helper to embed a sequence of words (given as indices) via factored embeddings +Expr Embedding::multiRows(const Words& data, float dropProb) const { + auto graph = E_->graph(); + auto factoredData = factoredVocab_->csr_rows(data); + // multi-hot factor vectors are represented as a sparse CSR matrix + // [row index = word position index] -> set of factor indices for word at this position + ABORT_IF(factoredData.shape + != Shape({(int)factoredData.offsets.size() - 1 /*=rows of CSR*/, E_->shape()[0]}), + "shape mismatch??"); + // the CSR matrix is passed in pieces + auto weights = graph->constant({(int)factoredData.weights.size()}, + inits::fromVector(factoredData.weights)); + auto indices = graph->constant( + {(int)factoredData.indices.size()}, inits::fromVector(factoredData.indices), Type::uint32); + auto offsets = graph->constant( + {(int)factoredData.offsets.size()}, inits::fromVector(factoredData.offsets), Type::uint32); + // apply dropout + // We apply it to the weights, i.e. factors get dropped out separately, but always as entire + // vectors. + if(!inference_) + weights = dropout(weights, dropProb); + // perform the product + return csr_dot(factoredData.shape, weights, indices, offsets, E_); +} + +std::tuple<Expr /*embeddings*/, Expr /*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const +/*override final*/ { + auto graph = E_->graph(); + int dimBatch = (int)subBatch->batchSize(); + int dimEmb = E_->shape()[-1]; + int dimWidth = (int)subBatch->batchWidth(); + + // factored embeddings: + // - regular: + // - y = x @ E x:[B x 1ofV] ; E:[V x D] ; y:[B x D] + // - factored: + // - u = x @ M one-hot to U-dimensional multi-hot (all factors in one concatenated space) + // - each row of M contains the set of factors for one word => we want a CSR matrix + // - y = (x @ M) @ E (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D] + // - first compute x @ M on the CPU + // - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()): + // - shape (U, specifically) not actually needed here + // - foreach input x[i] + // - locate row M[i,*] + // - copy through its index values (std::vector<push_back>) + // - create a matching ones vector (we can keep growing) + // - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x) + // - CSR matrix product with E + // - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU) + // - double-check if all dimensions are specified. Probably not for transpose (which would + // be like csc_dot()). + // - weighting: + // - core factors' gradients are sums over all words that use the factors; + // - core factors' embeddings move very fast + // - words will need to make up for the move; rare words cannot + // - so, we multiply each factor with 1/refCount + // - core factors get weighed down a lot + // - no impact on gradients, as Adam makes up for it; embeddings still move fast just as + // before + // - but forward pass weighs them down, so that all factors are in a similar numeric range + // - if it is required to be in a different range, the embeddings can still learn that, but + // more slowly + + auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb}); +#if 1 + auto batchMask = graph->constant({dimWidth, dimBatch, 1}, inits::fromVector(subBatch->mask())); +#else // @TODO: this is dead code now, get rid of it + // experimental: hide inline-fix source tokens from cross attention + auto batchMask + = graph->constant({dimWidth, dimBatch, 1}, + inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed())); +#endif + // give the graph inputs readable names for debugging and ONNX + batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask"); + + return std::make_tuple(batchEmbeddings, batchMask); +} + +Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ { + if(factoredVocab_) { + Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E] + selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] + // selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { + // selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout + return selectedEmbs; + } else + return applyIndices(toWordIndexVector(words), shape); +} + +Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const +/*override final*/ { + ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary"); + auto embIdxExpr = E_->graph()->indices(embIdx); + embIdxExpr->set_name("data_" + + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index? + auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E] + selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] + // @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() + // (test that separately) + if(!inference_) + selectedEmbs = dropout( + selectedEmbs, options_->get<float>("dropout", 0.0f), {selectedEmbs->shape()[-3], 1, 1}); + return selectedEmbs; +} + +// standard encoder word embeddings +/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createEmbeddingLayer() const { + // clang-format off + auto options = New<Options>( + "dimVocab", opt<std::vector<int>>("dim-vocabs")[batchIndex_], + "dimEmb", opt<int>("dim-emb"), + "dropout", dropoutEmbeddings_, + "inference", inference_, + "prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb" + : prefix_ + "_Wemb", + "fixed", embeddingFix_, + "vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings + // clang-format on + if(options_->hasAndNotEmpty("embedding-vectors")) { + auto embFiles = opt<std::vector<std::string>>("embedding-vectors"); + options->set( + "embFile", embFiles[batchIndex_], "normalization", opt<bool>("embedding-normalization")); + } + return New<Embedding>(graph_, options); +} + +// ULR word embeddings +/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createULREmbeddingLayer() const { + // clang-format off + return New<ULREmbedding>(graph_, New<Options>( + "dimSrcVoc", opt<std::vector<int>>("dim-vocabs")[0], // ULR multi-lingual src + "dimTgtVoc", opt<std::vector<int>>("dim-vocabs")[1], // ULR monon tgt + "dimUlrEmb", opt<int>("ulr-dim-emb"), + "dimEmb", opt<int>("dim-emb"), + "ulr-dropout", opt<float>("ulr-dropout"), + "dropout", dropoutEmbeddings_, + "inference", inference_, + "ulrTrainTransform", opt<bool>("ulr-trainable-transformation"), + "ulrQueryFile", opt<std::string>("ulr-query-vectors"), + "ulrKeysFile", opt<std::string>("ulr-keys-vectors") + )); + // clang-format on +} + +// get embedding layer for this encoder or decoder +// This is lazy mostly because the constructors of the consuming objects are not +// guaranteed presently to have access to their graph. +Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::getEmbeddingLayer(bool ulr) const { + if(embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy + if(embeddingLayers_.size() <= batchIndex_) + embeddingLayers_.resize(batchIndex_ + 1); + if(ulr) + embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR + else + embeddingLayers_[batchIndex_] = createEmbeddingLayer(); + } + return embeddingLayers_[batchIndex_]; +} + +} // namespace marian diff --git a/src/layers/embedding.h b/src/layers/embedding.h new file mode 100644 index 00000000..2fa7b78d --- /dev/null +++ b/src/layers/embedding.h @@ -0,0 +1,157 @@ +#pragma once +#include "generic.h" +#include "marian.h" + +namespace marian { + +class FactoredVocab; + +// A regular embedding layer. +// Note that this also applies dropout if the option is passed (pass 0 when in inference mode). +// It is best to not use Embedding directly, but rather via getEmbeddingLayer() in +// EncoderDecoderLayerBase, which knows to pass on all required parameters from options. +class Embedding : public LayerBase, public IEmbeddingLayer { + Expr E_; + Ptr<FactoredVocab> factoredVocab_; + Expr multiRows(const Words& data, float dropProb) const; + bool inference_{false}; + +public: + Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options); + + std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply( + Ptr<data::SubBatch> subBatch) const override final; + + Expr apply(const Words& words, const Shape& shape) const override final; + + Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final; +}; + +class ULREmbedding : public LayerBase, public IEmbeddingLayer { + std::vector<Expr> ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members + bool inference_{false}; + +public: + ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options) + : LayerBase(graph, options), inference_(opt<bool>("inference")) { + std::string name = "url_embed"; // opt<std::string>("prefix"); + int dimKeys = opt<int>("dimTgtVoc"); + int dimQueries = opt<int>("dimSrcVoc"); + int dimEmb = opt<int>("dimEmb"); + int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size + bool fixed = opt<bool>("fixed", false); + + // Embedding layer initialization should depend only on embedding size, hence fanIn=false + auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); + + std::string queryFile = opt<std::string>("ulrQueryFile"); + std::string keyFile = opt<std::string>("ulrKeysFile"); + bool trainTrans = opt<bool>("ulrTrainTransform", false); + if(!queryFile.empty() && !keyFile.empty()) { + initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false); + name = "ulr_query"; + fixed = true; + auto query_embed = graph_->param(name, {dimQueries, dimUlrEmb}, initFunc, fixed); + ulrEmbeddings_.push_back(query_embed); + // keys embeds + initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false); + name = "ulr_keys"; + fixed = true; + auto key_embed = graph_->param(name, {dimKeys, dimUlrEmb}, initFunc, fixed); + ulrEmbeddings_.push_back(key_embed); + // actual trainable embedding + initFunc = inits::glorotUniform(); + name = "ulr_embed"; + fixed = false; + auto ulr_embed = graph_->param(name, {dimKeys, dimEmb}, initFunc, fixed); // note the reverse dim + ulrEmbeddings_.push_back(ulr_embed); + // init trainable src embedding + name = "ulr_src_embed"; + auto ulr_src_embed = graph_->param(name, {dimQueries, dimEmb}, initFunc, fixed); + ulrEmbeddings_.push_back(ulr_src_embed); + // ulr transformation matrix + // initFunc = inits::eye(1.f); // identity matrix - is it ok to init wiht identity or shall + // we make this to the fixed case only + if(trainTrans) { + initFunc = inits::glorotUniform(); + fixed = false; + } else { + initFunc = inits::eye(); // identity matrix + fixed = true; + } + name = "ulr_transform"; + auto ulrTransform = graph_->param(name, {dimUlrEmb, dimUlrEmb}, initFunc, fixed); + ulrEmbeddings_.push_back(ulrTransform); + + initFunc = inits::fromValue( + 1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no + // universal embeddings - should be zero for top freq only + fixed = true; + name = "ulr_shared"; + auto share_embed = graph_->param(name, {dimQueries, 1}, initFunc, fixed); + ulrEmbeddings_.push_back(share_embed); + } + } + + std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply( + Ptr<data::SubBatch> subBatch) const override final { + auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb + auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb + auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb + auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb + auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb + auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1 + int dimBatch = (int)subBatch->batchSize(); + int dimEmb = uniEmbed->shape()[-1]; + int dimWords = (int)subBatch->batchWidth(); + // D = K.A.QT + // dimm(K) = univ_tok_vocab*uni_embed_size + // dim A = uni_embed_size*uni_embed_size + // dim Q: uni_embed_size * total_merged_vocab_size + // dim D = univ_tok_vocab * total_merged_vocab_size + // note all above can be precombuted and serialized if A is not trainiable and during decoding + // (TBD) here we need to handle the mini-batch extract raws corresponding to Xs in this + // minibatch from Q + auto embIdx = toWordIndexVector(subBatch->data()); + auto queryEmbeddings = rows(queryEmbed, embIdx); + auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings + auto alpha = rows(ulrSharable, embIdx); // extract sharable flags + auto qt = dot(queryEmbeddings, ulrTransform, false, false); // A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb + auto sqrtDim = std::sqrt((float)queryEmbeddings->shape()[-1]); + qt = qt / sqrtDim; // normalize accordin to embed size to avoid dot prodcut growing large in + // magnitude with larger embeds sizes + auto z = dot(qt, keyEmbed, false, true); // query-key similarity + float dropProb = this->options_->get<float>("ulr-dropout", 0.0f); // default no dropout + if(!inference_) + z = dropout(z, dropProb); + + float tau + = this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default no temperature + // temperature in softmax is to control randomness of predictions + // high temperature Softmax outputs are more close to each other + // low temperatures the softmax become more similar to "hardmax" + auto weights = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ?? + auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE + auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast + auto batchEmbeddings = reshape(chosenEmbeddings_mix, {dimWords, dimBatch, dimEmb}); + auto graph = ulrEmbeddings_.front()->graph(); + auto batchMask = graph->constant({dimWords, dimBatch, 1}, inits::fromVector(subBatch->mask())); + if(!inference_) + batchEmbeddings = dropout(batchEmbeddings, + options_->get<float>("dropout-embeddings", 0.0f), + {batchEmbeddings->shape()[-3], 1, 1}); + return std::make_tuple(batchEmbeddings, batchMask); + } + + Expr apply(const Words& words, const Shape& shape) const override final { + return applyIndices(toWordIndexVector(words), shape); + } + + Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final { + embIdx; + shape; + ABORT("not implemented"); // @TODO: implement me + } +}; + +} // namespace marian diff --git a/src/layers/factory.h b/src/layers/factory.h index f9e4ddf9..f9e4ddf9 100755..100644 --- a/src/layers/factory.h +++ b/src/layers/factory.h diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp index d44f4020..8e2ecfd7 100755..100644 --- a/src/layers/generic.cpp +++ b/src/layers/generic.cpp @@ -1,609 +1,10 @@ #include "marian.h" -#include "layers/generic.h" +#include "data/factored_vocab.h" #include "layers/constructors.h" +#include "layers/generic.h" #include "layers/loss.h" -#include "data/factored_vocab.h" -#include "rnn/types.h" // for State::select() -#include "models/states.h" // for EncoderState #include "layers/lsh.h" +#include "models/states.h" // for EncoderState -namespace marian { - Logits::Logits(Expr logits) : Logits(New<RationalLoss>(logits, nullptr)) {} // single-output constructor from Expr only (RationalLoss has no count) - - Ptr<ExpressionGraph> Logits::graph() const { - ABORT_IF(logits_.empty(), "Empty logits object??"); - return logits_.front()->loss()->graph(); - } - - // This function assumes that the object holds one or more factor logits. - // It applies the supplied loss function to each, and then returns the aggregate loss over all factors. - Expr Logits::applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/, Expr/*indices*/)>& lossFn) const { - LOG_ONCE(info, "[logits] Applying loss function for {} factor(s)", logits_.size()); - ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); - - auto firstLogits = logits_.front()->loss(); - ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(), - "Labels not matching logits shape ({} != {}, {})??", - labels.size() * firstLogits->shape()[-1], - firstLogits->shape().elements(), - firstLogits->shape()); - - // base case (no factors) - if (!factoredVocab_) { - ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); - return lossFn(firstLogits, indices(toWordIndexVector(labels))); - } - - auto numGroups = factoredVocab_->getNumGroups(); - - // split labels into individual factor labels - auto allMaskedFactoredLabels = factorizeWords(labels); // [numGroups][labels.size()] = [numGroups][B... flattened] - - //Expr indices = this->indices(toWordIndexVector(labels)); - // accumulate all CEs for all words that have the factor - // Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors. - Expr loss; - for (size_t g = 0; g < numGroups; g++) { - if (!logits_[g]) - continue; // empty factor --@TODO: use an array of indices of non-empty logits_[] - const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask) - auto factorIndices = indices (maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply - auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor - auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet) - // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next. - auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1] - if(loss) - factorLoss = cast(factorLoss, loss->value_type()); - factorLoss = factorLoss * cast(reshape(factorMask, factorLoss->shape()), factorLoss->value_type()); // mask out factor for words that do not have that factor - loss = loss ? (loss + factorLoss) : factorLoss; // [B... x 1] - } - return loss; - } - - // This function assumes this object holds a single factor that represents a rational loss (with count). - //Ptr<RationalLoss> Logits::getRationalLoss() const { - // ABORT_IF(logits_.size() != 1 || factoredVocab_, "getRationalLoss() cannot be used on multi-factor outputs"); - // ABORT_IF(!logits_.front()->count(), "getRationalLoss() used on rational loss without count"); - // return logits_.front(); - //} - - // get logits for one factor group - // For groupIndex == 0, the function also requires the shortlist if there is one. - Expr Logits::getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist /*= nullptr*/, const std::vector<IndexType>& hypIndices /*= {}*/, size_t beamSize /*= 0*/) const { - ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); - - auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab] - - // normalize for decoding: - // - all secondary factors: subtract their max - // - lemma: add all maxes of applicable factors - if (groupIndex > 0) { - sel = sel - max(sel, -1); - } - else { - auto numGroups = getNumFactorGroups(); - for (size_t g = 1; g < numGroups; g++) { - auto factorMaxima = max(logits_[g]->loss(), -1); // we cast since loss is likely ce-loss which has type float32 - auto factorMasks = constant(getFactorMasks(g, shortlist ? shortlist->indices() : std::vector<WordIndex>())); - sel = sel + cast(factorMaxima, sel->value_type()) * cast(factorMasks, sel->value_type()); // those lemmas that don't have a factor get multiplied with 0 - } - } - - // if selIdx are given, then we must reshuffle accordingly - if (!hypIndices.empty()) // use the same function that shuffles decoder state - sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false); - - return sel; - } - - // used for breakDown() only - // Index is flattened - Tensor Logits::getFactoredLogitsTensor(size_t groupIndex) const { - ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); - return logits_[groupIndex]->loss()->val(); - } - - // This function assumes that the object holds one or more factor logits, which are summed up - // into output-vocab logits according to the factored model (with correct normalization of factors). - // This is infeasible for realistic factor sets, and therefore only implemented for 1 factor. - // @TODO: remove altogether - Expr Logits::getLogits() const { - ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); - if (!factoredVocab_) { - ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); - return getFactoredLogits(0); - } - -#ifdef FACTOR_FULL_EXPANSION - // compute normalized factor log probs - std::vector<Expr> logProbs(logits_.size()); - for (size_t g = 0; g < logits_.size(); g++) - logProbs[g] = logsoftmax(logits_[g]->loss()); - auto y = concatenate(logProbs, /*axis=*/ -1); - - // sum up the unit logits across factors for each target word - auto graph = y->graph(); - auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U] - y = dot_csr( - y, // [B x U] - factorMatrix.shape, - graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)), - graph->constant({(int)factorMatrix.indices.size()}, inits::fromVector(factorMatrix.indices), Type::uint32), - graph->constant({(int)factorMatrix.offsets.size()}, inits::fromVector(factorMatrix.offsets), Type::uint32), - /*transB=*/ true); // -> [B x V] - - // mask out gaps - auto gapLogMask = factoredVocab_->getGapLogMask(); // [V] - y = y + graph->constant({ (int)gapLogMask.size() }, inits::fromVector(gapLogMask)); - - return y; -#else - ABORT("getLogits() no longer supported for actual factored vocab"); // because it is infeasible -#endif - } - - void Logits::MaskedFactorIndices::push_back(size_t factorIndex) { - bool isValid = FactoredVocab::isFactorValid(factorIndex); - indices.push_back(isValid ? (WordIndex)factorIndex : 0); - masks.push_back((float)isValid); - } - - std::vector<Logits::MaskedFactorIndices> Logits::factorizeWords(const Words& words) const { // [numGroups][words.size()] -> breaks encoded Word into individual factor indices - if (!factoredVocab_) { - ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); - return {MaskedFactorIndices(words)}; - } - auto numGroups = factoredVocab_->getNumGroups(); - std::vector<MaskedFactorIndices> res(numGroups); - for (size_t g = 0; g < numGroups; g++) { - auto& resg = res[g]; - resg.reserve(words.size()); - for (const auto& word : words) - resg.push_back(factoredVocab_->getFactor(word, g)); - } - return res; - } - - //// use first factor of each word to determine whether it has a specific factor - //std::vector<float> Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0 for words that do have this factor; else 0 - // std::vector<float> res; - // res.reserve(words.size()); - // for (const auto& word : words) { - // auto lemma = factoredVocab_->getFactor(word, 0); - // res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup)); - // } - // return res; - //} - - // return a vector of 1 or 0 indicating for each lemma whether it has a specific factor - // If 'indices' is given, then return the masks for the indices; otherwise for all lemmas - std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0 - size_t n = indices.empty() ? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first) : indices.size(); - std::vector<float> res; - res.reserve(n); - // @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this into FactoredVocab - for (size_t i = 0; i < n; i++) { - auto lemma = indices.empty() ? i : (indices[i] - factoredVocab_->getGroupRange(0).first); - res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup)); - } - return res; - } - - Logits Logits::applyUnaryFunction(const std::function<Expr(Expr)>& f) const { // clone this but apply f to all loss values - std::vector<Ptr<RationalLoss>> newLogits; - for (const auto& l : logits_) - newLogits.emplace_back(New<RationalLoss>(f(l->loss()), l->count())); - return Logits(std::move(newLogits), factoredVocab_); - } - - Logits Logits::applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const { - std::vector<Ptr<RationalLoss>> newLogits; - bool first = true; - for (const auto& l : logits_) { - newLogits.emplace_back(New<RationalLoss>((first?f1:fother)(l->loss()), l->count())); // f1 for first, fother for all others - first = false; - } - return Logits(std::move(newLogits), factoredVocab_); - } - - // @TODO: code dup with above; we can merge it into applyToRationalLoss() - Logits Logits::withCounts(const Expr& count) const { // create new Logits with 'count' implanted into all logits_ - std::vector<Ptr<RationalLoss>> newLogits; - for (const auto& l : logits_) - newLogits.emplace_back(New<RationalLoss>(l->loss(), count)); - return Logits(std::move(newLogits), factoredVocab_); - } - - namespace mlp { - /*private*/ void Output::lazyConstruct(int inputDim) { - // We must construct lazily since we won't know tying nor input dim in constructor. - if (Wt_) - return; - - // this option is only set in the decoder - if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) { - auto k = opt<std::vector<int>>("output-approx-knn")[0]; - auto nbits = opt<std::vector<int>>("output-approx-knn")[1]; - lsh_ = New<LSH>(k, nbits); - } - - auto name = options_->get<std::string>("prefix"); - auto numOutputClasses = options_->get<int>("dim"); - - factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", "")); - if (factoredVocab_) { - numOutputClasses = (int)factoredVocab_->factorVocabSize(); - LOG_ONCE(info, "[embedding] Factored outputs enabled"); - } - - if(tiedParam_) { - Wt_ = tiedParam_; - } else { - if (graph_->get(name + "_W")) { // support of legacy models that did not transpose - Wt_ = graph_->param(name + "_W", {inputDim, numOutputClasses}, inits::glorotUniform(true, false)); - isLegacyUntransposedW = true; - } - else // this is the regular case: - Wt_ = graph_->param(name + "_Wt", {numOutputClasses, inputDim}, inits::glorotUniform(false, true)); - } - - if(hasBias_) - b_ = graph_->param(name + "_b", {1, numOutputClasses}, inits::zeros()); - - /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0); - ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary"); - if (lemmaDimEmb > 0) { // > 0 means to embed the (expected) word with a different embedding matrix -#define HARDMAX_HACK -#ifdef HARDMAX_HACK - lemmaDimEmb = lemmaDimEmb & 0xfffffffe; // hack to select hard-max: use an odd number -#endif - auto range = factoredVocab_->getGroupRange(0); - auto lemmaVocabDim = (int)(range.second - range.first); - auto initFunc = inits::glorotUniform(/*fanIn=*/true, /*fanOut=*/false); // -> embedding vectors have roughly unit length - lemmaEt_ = graph_->param(name + "_lemmaEt", {lemmaDimEmb, lemmaVocabDim}, initFunc); // [L x U] L=lemmaDimEmb; transposed for speed - } - } - - Logits Output::applyAsLogits(Expr input) /*override final*/ { - lazyConstruct(input->shape()[-1]); - - auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) { - if(b) - return affine(x, W, b, transA, transB); - else - return dot(x, W, transA, transB); - }; - - auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) { - if(lsh_) { - ABORT_IF( transA, "Transposed query not supported for LSH"); - ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH"); - return lsh_->apply(x, W, b); // knows how to deal with undefined bias - } else { - return affineOrDot(x, W, b, transA, transB); - } - }; - - if (shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one batch, then clear()ed - cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices()); - if(hasBias_) - cachedShortb_ = index_select(b_ , -1, shortlist_->indices()); - } - - if (factoredVocab_) { - auto graph = input->graph(); - - // project each factor separately - auto numGroups = factoredVocab_->getNumGroups(); - std::vector<Ptr<RationalLoss>> allLogits(numGroups, nullptr); // (note: null entries for absent factors) - Expr input1 = input; // [B... x D] - Expr Plemma = nullptr; // used for lemmaDimEmb=-1 - Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3 - for (size_t g = 0; g < numGroups; g++) { - auto range = factoredVocab_->getGroupRange(g); - if (g > 0 && range.first == range.second) // empty entry - continue; - ABORT_IF(g > 0 && range.first != factoredVocab_->getGroupRange(g-1).second, "Factor groups must be consecutive (group {} vs predecessor)", g); - // slice this group's section out of W_ - Expr factorWt, factorB; - if (g == 0 && shortlist_) { - factorWt = cachedShortWt_; - factorB = cachedShortb_; - } - else { - factorWt = slice(Wt_, isLegacyUntransposedW ? -1 : 0, Slice((int)range.first, (int)range.second)); - if(hasBias_) - factorB = slice(b_, -1, Slice((int)range.first, (int)range.second)); - } - /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0); - if ((lemmaDimEmb == -2 || lemmaDimEmb == -3) && g > 0) { // -2/-3 means a gated transformer-like structure (-3 = hard-max) - LOG_ONCE(info, "[embedding] using lemma conditioning with gate"); - // this mimics one transformer layer - // - attention over two inputs: - // - e = current lemma. We use the original embedding vector; specifically, expectation over all lemmas. - // - input = hidden state FF(h_enc+h_dec) - // - dot-prod attention to allow both sides to influence (unlike our recurrent self-attention) - // - multi-head to allow for multiple conditions to be modeled - // - add & norm, for gradient flow and scaling - // - FF layer --this is expensive; it is per-factor - // multi-head attention - int inputDim = input->shape()[-1]; - int heads = 8; - auto name = options_->get<std::string>("prefix") + "_factor" + std::to_string(g); - auto Wq = graph_->param(name + "_Wq", { inputDim, inputDim }, inits::glorotUniform()); - auto Wk = graph_->param(name + "_Wk", { inputDim, inputDim }, inits::glorotUniform()); - auto Wv = graph_->param(name + "_Wv", { inputDim, inputDim }, inits::glorotUniform()); - auto toMultiHead = [&](Expr x, int heads) { - const auto& shape = x->shape(); - int inputDim = shape[-1]; - int otherDim = shape.elements() / inputDim; - ABORT_IF(inputDim / heads * heads != inputDim, "inputDim ({}) must be multiple of number of heads ({})", inputDim, heads); - return reshape(x, { otherDim, heads, 1, inputDim / heads }); - }; - input1 = inputLemma; - auto qm = toMultiHead(dot(input1, Wq), heads); // [B... x H x D/H] projected query - auto kdm = toMultiHead(dot(input1 - input, Wk), heads); // [B... x H x D/H] the two data vectors projected as keys. Use diff and sigmoid, instead of softmax. - auto vem = toMultiHead(dot(input1, Wv), heads); // [B... x H x D/H] one of the two data vectors projected as values - auto vim = toMultiHead(dot( input, Wv), heads); // [B... x H x D/H] the other - auto zm = bdot(qm, kdm, false, true); // [B... x H x 1] - auto sm = sigmoid(zm); // [B... x H x 1] - auto rm = sm * (vem - vim) + vim; // [B... x H x D/H] - auto r = reshape(rm, input->shape()); // [B... x D] - // add & norm - input1 = r + input1; - input1 = layerNorm(input1, name + "_att"); - // FF layer - auto ffnDropProb = 0.1f; // @TODO: get as a parameter - auto ffnDim = inputDim * 2; // @TODO: get as a parameter - auto f = denseInline(input1, name + "_ffn", /*suffix=*/"1", ffnDim, inits::glorotUniform(), (ActivationFunction*)relu, ffnDropProb); - f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim); - // add & norm - input1 = f + input1; - input1 = layerNorm(input1, name + "_ffn"); - } - // @TODO: b_ should be a vector, not a matrix; but shotlists use cols() in, which requires a matrix - Expr factorLogits; - if(g == 0) - factorLogits = affineOrLSH(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits - else - factorLogits = affineOrDot(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits - - // optionally add lemma-dependent bias - if (Plemma) { // [B... x U0] - int lemmaVocabDim = Plemma->shape()[-1]; - int factorVocabDim = factorLogits->shape()[-1]; - auto name = options_->get<std::string>("prefix"); - Expr lemmaBt = graph_->param(name + "_lemmaBt_" + std::to_string(g), {factorVocabDim, lemmaVocabDim}, inits::zeros()); // [U x U0] U0=#lemmas one bias per class per lemma - auto b = dot(Plemma, lemmaBt, false, true); // [B... x U] - factorLogits = factorLogits + b; - } - allLogits[g] = New<RationalLoss>(factorLogits, nullptr); - // optionally add a soft embedding of lemma back to create some lemma dependency - // @TODO: if this works, move it into lazyConstruct - if (lemmaDimEmb == -2 && g == 0) { // -2 means a gated transformer-like structure - LOG_ONCE(info, "[embedding] using lemma conditioning with gate, soft-max version"); - // get expected lemma embedding vector - auto factorLogSoftmax = logsoftmax(factorLogits); // [B... x U] note: with shortlist, this is not the full lemma set - auto factorSoftmax = exp(factorLogSoftmax); - inputLemma = dot(factorSoftmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D] - } - else if (lemmaDimEmb == -3 && g == 0) { // same as -2 except with hard max - LOG_ONCE(info, "[embedding] using lemma conditioning with gate, hard-max version"); - // get max-lemma embedding vector - auto maxVal = max(factorLogits, -1); // [B... x U] note: with shortlist, this is not the full lemma set - auto factorHardmax = eq(factorLogits, maxVal); - inputLemma = dot(factorHardmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D] - } - else if (lemmaDimEmb == -1 && g == 0) { // -1 means learn a lemma-dependent bias - ABORT_IF(shortlist_, "Lemma-dependent bias with short list is not yet implemented"); - LOG_ONCE(info, "[embedding] using lemma-dependent bias"); - auto factorLogSoftmax = logsoftmax(factorLogits); // (we do that again later, CSE will kick in) - auto z = /*stopGradient*/(factorLogSoftmax); - Plemma = exp(z); // [B... x U] - } - else if (lemmaDimEmb > 0 && g == 0) { // > 0 means learn a re-embedding matrix - LOG_ONCE(info, "[embedding] enabled re-embedding of lemma, at dim {}", lemmaDimEmb); - // compute softmax. We compute logsoftmax() separately because this way, computation will be reused later via CSE - auto factorLogSoftmax = logsoftmax(factorLogits); - auto factorSoftmax = exp(factorLogSoftmax); -#ifdef HARDMAX_HACK - bool hardmax = (lemmaDimEmb & 1) != 0; // odd value triggers hardmax for now (for quick experimentation) - if (hardmax) { - lemmaDimEmb = lemmaDimEmb & 0xfffffffe; - LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb); - auto maxVal = max(factorSoftmax, -1); - factorSoftmax = eq(factorSoftmax, maxVal); - } -#endif - // re-embedding lookup, soft-indexed by softmax - if (shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix - cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices()); - auto e = dot(factorSoftmax, cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_, false, true); // [B... x L] - // project it back to regular hidden dim - int inputDim = input1->shape()[-1]; - auto name = options_->get<std::string>("prefix"); - // note: if the lemmaEt[:,w] have unit length (var = 1/L), then lemmaWt @ lemmaEt is also length 1 - Expr lemmaWt = inputDim == lemmaDimEmb ? nullptr : graph_->param(name + "_lemmaWt", { inputDim, lemmaDimEmb }, inits::glorotUniform()); // [D x L] D=hidden-vector dimension - auto f = lemmaWt ? dot(e, lemmaWt, false, true) : e; // [B... x D] - // augment the original hidden vector with this additional information - input1 = input1 + f; - } - } - return Logits(std::move(allLogits), factoredVocab_); - } else if (shortlist_) { - return Logits(affineOrLSH(input, cachedShortWt_, cachedShortb_, false, /*transB=*/isLegacyUntransposedW ? false : true)); - } else { - return Logits(affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true)); - } - } - } - - Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options) - : LayerBase(graph, options), inference_(opt<bool>("inference")) { - std::string name = opt<std::string>("prefix"); - int dimVoc = opt<int>("dimVocab"); - int dimEmb = opt<int>("dimEmb"); - - bool fixed = opt<bool>("fixed", false); - - factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", "")); - if (factoredVocab_) { - dimVoc = (int)factoredVocab_->factorVocabSize(); - LOG_ONCE(info, "[embedding] Factored embeddings enabled"); - } - - // Embedding layer initialization should depend only on embedding size, hence fanIn=false - auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length - - if (options_->has("embFile")) { - std::string file = opt<std::string>("embFile"); - if (!file.empty()) { - bool norm = opt<bool>("normalization", false); - initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm); - } - } - - E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed); - } - - // helper to embed a sequence of words (given as indices) via factored embeddings - Expr Embedding::multiRows(const Words& data, float dropProb) const { - auto graph = E_->graph(); - auto factoredData = factoredVocab_->csr_rows(data); - // multi-hot factor vectors are represented as a sparse CSR matrix - // [row index = word position index] -> set of factor indices for word at this position - ABORT_IF(factoredData.shape != Shape({(int)factoredData.offsets.size()-1/*=rows of CSR*/, E_->shape()[0]}), "shape mismatch??"); - // the CSR matrix is passed in pieces - auto weights = graph->constant({ (int)factoredData.weights.size() }, inits::fromVector(factoredData.weights)); - auto indices = graph->constant({ (int)factoredData.indices.size() }, inits::fromVector(factoredData.indices), Type::uint32); - auto offsets = graph->constant({ (int)factoredData.offsets.size() }, inits::fromVector(factoredData.offsets), Type::uint32); - // apply dropout - // We apply it to the weights, i.e. factors get dropped out separately, but always as entire vectors. - if(!inference_) - weights = dropout(weights, dropProb); - // perform the product - return csr_dot(factoredData.shape, weights, indices, offsets, E_); - } - - std::tuple<Expr/*embeddings*/, Expr/*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const /*override final*/ { - auto graph = E_->graph(); - int dimBatch = (int)subBatch->batchSize(); - int dimEmb = E_->shape()[-1]; - int dimWidth = (int)subBatch->batchWidth(); - - // factored embeddings: - // - regular: - // - y = x @ E x:[B x 1ofV] ; E:[V x D] ; y:[B x D] - // - factored: - // - u = x @ M one-hot to U-dimensional multi-hot (all factors in one concatenated space) - // - each row of M contains the set of factors for one word => we want a CSR matrix - // - y = (x @ M) @ E (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D] - // - first compute x @ M on the CPU - // - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()): - // - shape (U, specifically) not actually needed here - // - foreach input x[i] - // - locate row M[i,*] - // - copy through its index values (std::vector<push_back>) - // - create a matching ones vector (we can keep growing) - // - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x) - // - CSR matrix product with E - // - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU) - // - double-check if all dimensions are specified. Probably not for transpose (which would be like csc_dot()). - // - weighting: - // - core factors' gradients are sums over all words that use the factors; - // - core factors' embeddings move very fast - // - words will need to make up for the move; rare words cannot - // - so, we multiply each factor with 1/refCount - // - core factors get weighed down a lot - // - no impact on gradients, as Adam makes up for it; embeddings still move fast just as before - // - but forward pass weighs them down, so that all factors are in a similar numeric range - // - if it is required to be in a different range, the embeddings can still learn that, but more slowly - - auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb}); -#if 1 - auto batchMask = graph->constant({dimWidth, dimBatch, 1}, - inits::fromVector(subBatch->mask())); -#else // @TODO: this is dead code now, get rid of it - // experimental: hide inline-fix source tokens from cross attention - auto batchMask = graph->constant({dimWidth, dimBatch, 1}, - inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed())); -#endif - // give the graph inputs readable names for debugging and ONNX - batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask"); - - return std::make_tuple(batchEmbeddings, batchMask); - } - - Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ { - if (factoredVocab_) { - Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E] - selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] - //selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout - return selectedEmbs; - } - else - return applyIndices(toWordIndexVector(words), shape); - } - - Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const /*override final*/ { - ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary"); - auto embIdxExpr = E_->graph()->indices(embIdx); - embIdxExpr->set_name("data_" + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index? - auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E] - selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E] - // @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() (test that separately) - if(!inference_) - selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 }); - return selectedEmbs; - } - - // standard encoder word embeddings - /*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createEmbeddingLayer() const { - auto options = New<Options>( - "dimVocab", opt<std::vector<int>>("dim-vocabs")[batchIndex_], - "dimEmb", opt<int>("dim-emb"), - "dropout", dropoutEmbeddings_, - "inference", inference_, - "prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb" : prefix_ + "_Wemb", - "fixed", embeddingFix_, - "vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings - if(options_->hasAndNotEmpty("embedding-vectors")) { - auto embFiles = opt<std::vector<std::string>>("embedding-vectors"); - options->set( - "embFile", embFiles[batchIndex_], - "normalization", opt<bool>("embedding-normalization")); - } - return New<Embedding>(graph_, options); - } - - // ULR word embeddings - /*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createULREmbeddingLayer() const { - return New<ULREmbedding>(graph_, New<Options>( - "dimSrcVoc", opt<std::vector<int>>("dim-vocabs")[0], // ULR multi-lingual src - "dimTgtVoc", opt<std::vector<int>>("dim-vocabs")[1], // ULR monon tgt - "dimUlrEmb", opt<int>("ulr-dim-emb"), - "dimEmb", opt<int>("dim-emb"), - "ulr-dropout", opt<float>("ulr-dropout"), - "dropout", dropoutEmbeddings_, - "inference", inference_, - "ulrTrainTransform", opt<bool>("ulr-trainable-transformation"), - "ulrQueryFile", opt<std::string>("ulr-query-vectors"), - "ulrKeysFile", opt<std::string>("ulr-keys-vectors"))); - } - - // get embedding layer for this encoder or decoder - // This is lazy mostly because the constructors of the consuming objects are not - // guaranteed presently to have access to their graph. - Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::getEmbeddingLayer(bool ulr) const { - if (embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy - if (embeddingLayers_.size() <= batchIndex_) - embeddingLayers_.resize(batchIndex_ + 1); - if (ulr) - embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR - else - embeddingLayers_[batchIndex_] = createEmbeddingLayer(); - } - return embeddingLayers_[batchIndex_]; - } -} // namespace marian +namespace marian {} // namespace marian diff --git a/src/layers/generic.h b/src/layers/generic.h index f47bb45e..89f5c1e9 100755..100644 --- a/src/layers/generic.h +++ b/src/layers/generic.h @@ -5,12 +5,14 @@ #include "data/shortlist.h" #include "layers/factory.h" -namespace marian { namespace mlp { - /** - * @brief Activation functions - */ - enum struct act : int { linear, tanh, sigmoid, ReLU, LeakyReLU, PReLU, swish }; -}} +namespace marian { +namespace mlp { +/** + * @brief Activation functions + */ +enum struct act : int { linear, tanh, sigmoid, ReLU, LeakyReLU, PReLU, swish }; +} // namespace mlp +} // namespace marian namespace marian { @@ -23,8 +25,7 @@ protected: Ptr<Options> options_; public: - LayerBase(Ptr<ExpressionGraph> graph, Ptr<Options> options) - : graph_(graph), options_(options) {} + LayerBase(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {} template <typename T> T opt(const std::string key) const { @@ -42,7 +43,7 @@ struct IUnaryLayer { virtual ~IUnaryLayer() {} virtual Expr apply(Expr) = 0; virtual Expr apply(const std::vector<Expr>& es) { - ABORT_IF(es.size() > 1, "Not implemented"); // simple stub + ABORT_IF(es.size() > 1, "Not implemented"); // simple stub return apply(es.front()); } }; @@ -54,7 +55,8 @@ struct IHasShortList { // Embedding from corpus sub-batch to (emb, mask) struct IEmbeddingLayer { - virtual std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const = 0; + virtual std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply( + Ptr<data::SubBatch> subBatch) const = 0; virtual Expr apply(const Words& embIdx, const Shape& shape) const = 0; @@ -63,28 +65,29 @@ struct IEmbeddingLayer { virtual ~IEmbeddingLayer() {} }; -// base class for Encoder and Decoder classes, which have embeddings and a batch index (=stream index) +// base class for Encoder and Decoder classes, which have embeddings and a batch index (=stream +// index) class EncoderDecoderLayerBase : public LayerBase { protected: const std::string prefix_; const bool embeddingFix_; - const float dropoutEmbeddings_; // this drops out full embedding vectors + const float dropoutEmbeddings_; // this drops out full embedding vectors const bool inference_; const size_t batchIndex_; - mutable std::vector<Ptr<IEmbeddingLayer>> embeddingLayers_; // (lazily created) + mutable std::vector<Ptr<IEmbeddingLayer>> embeddingLayers_; // (lazily created) - EncoderDecoderLayerBase(Ptr<ExpressionGraph> graph, - Ptr<Options> options, - const std::string& prefix, + EncoderDecoderLayerBase(Ptr<ExpressionGraph> graph, + Ptr<Options> options, + const std::string& prefix, size_t batchIndex, float dropoutEmbeddings, - bool embeddingFix) : - LayerBase(graph, options), - prefix_(options->get<std::string>("prefix", prefix)), - embeddingFix_(embeddingFix), - dropoutEmbeddings_(dropoutEmbeddings), - inference_(options->get<bool>("inference", false)), - batchIndex_(options->get<size_t>("index", batchIndex)) {} + bool embeddingFix) + : LayerBase(graph, options), + prefix_(options->get<std::string>("prefix", prefix)), + embeddingFix_(embeddingFix), + dropoutEmbeddings_(dropoutEmbeddings), + inference_(options->get<bool>("inference", false)), + batchIndex_(options->get<size_t>("index", batchIndex)) {} virtual ~EncoderDecoderLayerBase() {} @@ -97,78 +100,11 @@ public: Ptr<IEmbeddingLayer> getEmbeddingLayer(bool ulr = false) const; }; -class FactoredVocab; - -// To support factors, any output projection (that is followed by a softmax) must -// retain multiple outputs, one for each factor. Such layer returns not a single Expr, -// but a Logits object that contains multiple. -// This allows to compute softmax values in a factored manner, where we never create -// a fully expanded list of all factor combinations. -class RationalLoss; -class Logits { -public: - Logits() {} - explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor - logits_.push_back(logits); - } - explicit Logits(Expr logits); // single-output constructor from Expr only (RationalLoss has no count) - Logits(std::vector<Ptr<RationalLoss>>&& logits, Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor - : logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {} - Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors - Expr getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist = nullptr, const std::vector<IndexType>& hypIndices = {}, size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle - //Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that - Expr applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/,Expr/*indices*/)>& lossFn) const; - Logits applyUnaryFunction(const std::function<Expr(Expr)>& f) const; // clone this but apply f to all loss values - Logits applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const; // clone this but apply f1 to first and fother to to all other values - - struct MaskedFactorIndices { - std::vector<WordIndex> indices; // factor index, or 0 if masked - std::vector<float> masks; - void reserve(size_t n) { indices.reserve(n); masks.reserve(n); } - void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0 for invalid entries - MaskedFactorIndices() {} - MaskedFactorIndices(const Words& words) { indices = toWordIndexVector(words); } // we can leave masks uninitialized for this special use case - }; - std::vector<MaskedFactorIndices> factorizeWords(const Words& words) const; // breaks encoded Word into individual factor indices - Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only - size_t getNumFactorGroups() const { return logits_.size(); } - bool empty() const { return logits_.empty(); } - Logits withCounts(const Expr& count) const; // create new Logits with 'count' implanted into all logits_ -private: - // helper functions - Ptr<ExpressionGraph> graph() const; - Expr constant(const Shape& shape, const std::vector<float>& data) const { return graph()->constant(shape, inits::fromVector(data)); } - Expr constant(const Shape& shape, const std::vector<uint32_t>& data) const { return graph()->constant(shape, inits::fromVector(data)); } - template<typename T> Expr constant(const std::vector<T>& data) const { return constant(Shape{(int)data.size()}, data); } // same as constant() but assuming vector - Expr indices(const std::vector<uint32_t>& data) const { return graph()->indices(data); } // actually the same as constant(data) for this data type - std::vector<float> getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const; -private: - // members - // @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just by the Expr - std::vector<Ptr<RationalLoss>> logits_; // [group id][B..., num factors in group] - Ptr<FactoredVocab> factoredVocab_; -}; - -// Unary function that returns a Logits object -// Also implements IUnaryLayer, since Logits can be cast to Expr. -// This interface is implemented by all layers that are of the form of a unary function -// that returns multiple logits, to support factors. -struct IUnaryLogitLayer : public IUnaryLayer { - virtual Logits applyAsLogits(Expr) = 0; - virtual Logits applyAsLogits(const std::vector<Expr>& es) { - ABORT_IF(es.size() > 1, "Not implemented"); // simple stub - return applyAsLogits(es.front()); - } - virtual Expr apply(Expr e) override { return applyAsLogits(e).getLogits(); } - virtual Expr apply(const std::vector<Expr>& es) override { return applyAsLogits(es).getLogits(); } -}; - namespace mlp { class Dense : public LayerBase, public IUnaryLayer { public: - Dense(Ptr<ExpressionGraph> graph, Ptr<Options> options) - : LayerBase(graph, options) {} + Dense(Ptr<ExpressionGraph> graph, Ptr<Options> options) : LayerBase(graph, options) {} Expr apply(const std::vector<Expr>& inputs) override { ABORT_IF(inputs.empty(), "No inputs"); @@ -190,21 +126,17 @@ public: if(inputs.size() > 1) num = std::to_string(i); - Expr W = g->param( - name + "_W" + num, {in->shape()[-1], dim}, inits::glorotUniform()); + Expr W = g->param(name + "_W" + num, {in->shape()[-1], dim}, inits::glorotUniform()); Expr b = g->param(name + "_b" + num, {1, dim}, inits::zeros()); if(useLayerNorm) { if(useNematusNorm) { - auto ln_s = g->param( - name + "_ln_s" + num, {1, dim}, inits::fromValue(1.f)); + auto ln_s = g->param(name + "_ln_s" + num, {1, dim}, inits::fromValue(1.f)); auto ln_b = g->param(name + "_ln_b" + num, {1, dim}, inits::zeros()); - outputs.push_back( - layerNorm(affine(in, W, b), ln_s, ln_b, NEMATUS_LN_EPS)); + outputs.push_back(layerNorm(affine(in, W, b), ln_s, ln_b, NEMATUS_LN_EPS)); } else { - auto gamma = g->param( - name + "_gamma" + num, {1, dim}, inits::fromValue(1.0)); + auto gamma = g->param(name + "_gamma" + num, {1, dim}, inits::fromValue(1.0)); outputs.push_back(layerNorm(dot(in, W), gamma, b)); } @@ -231,241 +163,35 @@ public: Expr apply(Expr input) override { return apply(std::vector<Expr>({input})); } }; -} // namespace mlp - -class LSH; - -namespace mlp { - -class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList { -private: - // parameters held by this layer - Expr Wt_; // weight matrix is stored transposed for efficiency - Expr b_; - Expr lemmaEt_; // re-embedding matrix for lemmas [lemmaDimEmb x lemmaVocabSize] - bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form - bool hasBias_{true}; - - Expr cachedShortWt_; // short-listed version, cached (cleared by clear()) - Expr cachedShortb_; // these match the current value of shortlist_ - Expr cachedShortLemmaEt_; - Ptr<FactoredVocab> factoredVocab_; - - // optional parameters set/updated after construction - Expr tiedParam_; - Ptr<data::Shortlist> shortlist_; - Ptr<LSH> lsh_; - - void lazyConstruct(int inputDim); -public: - Output(Ptr<ExpressionGraph> graph, Ptr<Options> options) - : LayerBase(graph, options), - hasBias_{!options->get<bool>("output-omit-bias", false)} { - clear(); - } - - void tieTransposed(Expr tied) { - if (Wt_) - ABORT_IF(tiedParam_.get() != tied.get(), "Tied output projection cannot be changed once weights have been created"); - else - tiedParam_ = tied; - } - - void setShortlist(Ptr<data::Shortlist> shortlist) override final { - if (shortlist_) - ABORT_IF(shortlist.get() != shortlist_.get(), "Output shortlist cannot be changed except after clear()"); - else { - ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_, "No shortlist but cached parameters??"); - shortlist_ = shortlist; - } - // cachedShortWt_ and cachedShortb_ will be created lazily inside apply() - } - - // this is expected to be called in sync with graph->clear(), which invalidates - // cachedShortWt_ etc. in the graph's short-term cache - void clear() override final { - shortlist_ = nullptr; - cachedShortWt_ = nullptr; - cachedShortb_ = nullptr; - cachedShortLemmaEt_ = nullptr; - } - - Logits applyAsLogits(Expr input) override final; -}; - } // namespace mlp -// A regular embedding layer. -// Note that this also applies dropout if the option is passed (pass 0 when in inference mode). -// It is best to not use Embedding directly, but rather via getEmbeddingLayer() in -// EncoderDecoderLayerBase, which knows to pass on all required parameters from options. -class Embedding : public LayerBase, public IEmbeddingLayer { - Expr E_; - Ptr<FactoredVocab> factoredVocab_; - Expr multiRows(const Words& data, float dropProb) const; - bool inference_{false}; - -public: - Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options); - - std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final; - - Expr apply(const Words& words, const Shape& shape) const override final; - - Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final; -}; - -class ULREmbedding : public LayerBase, public IEmbeddingLayer { - std::vector<Expr> ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members - bool inference_{false}; - -public: - ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options) - : LayerBase(graph, options), inference_(opt<bool>("inference")) { - std::string name = "url_embed"; //opt<std::string>("prefix"); - int dimKeys = opt<int>("dimTgtVoc"); - int dimQueries = opt<int>("dimSrcVoc"); - int dimEmb = opt<int>("dimEmb"); - int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size - bool fixed = opt<bool>("fixed", false); - - // Embedding layer initialization should depend only on embedding size, hence fanIn=false - auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); - - std::string queryFile = opt<std::string>("ulrQueryFile"); - std::string keyFile = opt<std::string>("ulrKeysFile"); - bool trainTrans = opt<bool>("ulrTrainTransform", false); - if (!queryFile.empty() && !keyFile.empty()) { - initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false); - name = "ulr_query"; - fixed = true; - auto query_embed = graph_->param(name, { dimQueries, dimUlrEmb }, initFunc, fixed); - ulrEmbeddings_.push_back(query_embed); - // keys embeds - initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false); - name = "ulr_keys"; - fixed = true; - auto key_embed = graph_->param(name, { dimKeys, dimUlrEmb }, initFunc, fixed); - ulrEmbeddings_.push_back(key_embed); - // actual trainable embedding - initFunc = inits::glorotUniform(); - name = "ulr_embed"; - fixed = false; - auto ulr_embed = graph_->param(name, {dimKeys , dimEmb }, initFunc, fixed); // note the reverse dim - ulrEmbeddings_.push_back(ulr_embed); - // init trainable src embedding - name = "ulr_src_embed"; - auto ulr_src_embed = graph_->param(name, { dimQueries, dimEmb }, initFunc, fixed); - ulrEmbeddings_.push_back(ulr_src_embed); - // ulr transformation matrix - //initFunc = inits::eye(1.f); // identity matrix - is it ok to init wiht identity or shall we make this to the fixed case only - if (trainTrans) { - initFunc = inits::glorotUniform(); - fixed = false; - } - else - { - initFunc = inits::eye(); // identity matrix - fixed = true; - } - name = "ulr_transform"; - auto ulrTransform = graph_->param(name, { dimUlrEmb, dimUlrEmb }, initFunc, fixed); - ulrEmbeddings_.push_back(ulrTransform); - - initFunc = inits::fromValue(1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no universal embeddings - should be zero for top freq only - fixed = true; - name = "ulr_shared"; - auto share_embed = graph_->param(name, { dimQueries, 1 }, initFunc, fixed); - ulrEmbeddings_.push_back(share_embed); - } - } - - std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final { - auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb - auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb - auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb - auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb - auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb - auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1 - int dimBatch = (int)subBatch->batchSize(); - int dimEmb = uniEmbed->shape()[-1]; - int dimWords = (int)subBatch->batchWidth(); - // D = K.A.QT - // dimm(K) = univ_tok_vocab*uni_embed_size - // dim A = uni_embed_size*uni_embed_size - // dim Q: uni_embed_size * total_merged_vocab_size - // dim D = univ_tok_vocab * total_merged_vocab_size - // note all above can be precombuted and serialized if A is not trainiable and during decoding (TBD) - // here we need to handle the mini-batch - // extract raws corresponding to Xs in this minibatch from Q - auto embIdx = toWordIndexVector(subBatch->data()); - auto queryEmbeddings = rows(queryEmbed, embIdx); - auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings - auto alpha = rows(ulrSharable, embIdx); // extract sharable flags - auto qt = dot(queryEmbeddings, ulrTransform, false, false); //A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb - auto sqrtDim=std::sqrt((float)queryEmbeddings->shape()[-1]); - qt = qt/sqrtDim; // normalize accordin to embed size to avoid dot prodcut growing large in magnitude with larger embeds sizes - auto z = dot(qt, keyEmbed, false, true); // query-key similarity - float dropProb = this->options_->get<float>("ulr-dropout", 0.0f); // default no dropout - if(!inference_) - z = dropout(z, dropProb); - - float tau = this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default no temperature - // temperature in softmax is to control randomness of predictions - // high temperature Softmax outputs are more close to each other - // low temperatures the softmax become more similar to "hardmax" - auto weights = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ?? - auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE - auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast - auto batchEmbeddings = reshape(chosenEmbeddings_mix, { dimWords, dimBatch, dimEmb }); - auto graph = ulrEmbeddings_.front()->graph(); - auto batchMask = graph->constant({ dimWords, dimBatch, 1 }, - inits::fromVector(subBatch->mask())); - if(!inference_) - batchEmbeddings = dropout(batchEmbeddings, options_->get<float>("dropout-embeddings", 0.0f), {batchEmbeddings->shape()[-3], 1, 1}); - return std::make_tuple(batchEmbeddings, batchMask); - } - - Expr apply(const Words& words, const Shape& shape) const override final { - return applyIndices(toWordIndexVector(words), shape); - } - - Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final { - embIdx; shape; - ABORT("not implemented"); // @TODO: implement me - } -}; - // --- a few layers with built-in parameters created on the fly, without proper object // @TODO: change to a proper layer object // like affine() but with built-in parameters, activation, and dropout -static inline -Expr denseInline(Expr x, - std::string prefix, - std::string suffix, - int outDim, - Ptr<inits::NodeInitializer> initFn = inits::glorotUniform(), - const std::function<Expr(Expr)>& actFn = nullptr, - float dropProb = 0.0f) -{ +static inline Expr denseInline(Expr x, + std::string prefix, + std::string suffix, + int outDim, + Ptr<inits::NodeInitializer> initFn = inits::glorotUniform(), + const std::function<Expr(Expr)>& actFn = nullptr, + float dropProb = 0.0f) { auto graph = x->graph(); - auto W = graph->param(prefix + "_W" + suffix, { x->shape()[-1], outDim }, inits::glorotUniform()); - auto b = graph->param(prefix + "_b" + suffix, { 1, outDim }, inits::zeros()); + auto W = graph->param(prefix + "_W" + suffix, {x->shape()[-1], outDim}, inits::glorotUniform()); + auto b = graph->param(prefix + "_b" + suffix, {1, outDim}, inits::zeros()); x = affine(x, W, b); - if (actFn) + if(actFn) x = actFn(x); - x = dropout(x, dropProb); // @TODO: check for infernce? + x = dropout(x, dropProb); // @TODO: check for infernce? return x; } -static inline -Expr layerNorm(Expr x, std::string prefix, std::string suffix = std::string()) { +static inline Expr layerNorm(Expr x, std::string prefix, std::string suffix = std::string()) { int dimModel = x->shape()[-1]; - auto scale = x->graph()->param(prefix + "_ln_scale" + suffix, { 1, dimModel }, inits::ones()); - auto bias = x->graph()->param(prefix + "_ln_bias" + suffix, { 1, dimModel }, inits::zeros()); + auto scale = x->graph()->param(prefix + "_ln_scale" + suffix, {1, dimModel}, inits::ones()); + auto bias = x->graph()->param(prefix + "_ln_bias" + suffix, {1, dimModel}, inits::zeros()); return marian::layerNorm(x, scale, bias, 1e-6f); } diff --git a/src/layers/guided_alignment.h b/src/layers/guided_alignment.h index f08d3f09..f08d3f09 100755..100644 --- a/src/layers/guided_alignment.h +++ b/src/layers/guided_alignment.h diff --git a/src/layers/logits.cpp b/src/layers/logits.cpp new file mode 100644 index 00000000..8c4d69bd --- /dev/null +++ b/src/layers/logits.cpp @@ -0,0 +1,245 @@ +#include "logits.h" +#include "data/factored_vocab.h" +#include "loss.h" +#include "rnn/types.h" // for State::select() + +namespace marian { +Logits::Logits(Expr logits) + : Logits(New<RationalLoss>(logits, nullptr)) { +} // single-output constructor from Expr only (RationalLoss has no count) + +Ptr<ExpressionGraph> Logits::graph() const { + ABORT_IF(logits_.empty(), "Empty logits object??"); + return logits_.front()->loss()->graph(); +} + +// This function assumes that the object holds one or more factor logits. +// It applies the supplied loss function to each, and then returns the aggregate loss over all +// factors. +Expr Logits::applyLossFunction( + const Words& labels, + const std::function<Expr(Expr /*logits*/, Expr /*indices*/)>& lossFn) const { + LOG_ONCE(info, "[logits] Applying loss function for {} factor(s)", logits_.size()); + ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); + + auto firstLogits = logits_.front()->loss(); + ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(), + "Labels not matching logits shape ({} != {}, {})??", + labels.size() * firstLogits->shape()[-1], + firstLogits->shape().elements(), + firstLogits->shape()); + + // base case (no factors) + if(!factoredVocab_) { + ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); + return lossFn(firstLogits, indices(toWordIndexVector(labels))); + } + + auto numGroups = factoredVocab_->getNumGroups(); + + // split labels into individual factor labels + auto allMaskedFactoredLabels + = factorizeWords(labels); // [numGroups][labels.size()] = [numGroups][B... flattened] + + // Expr indices = this->indices(toWordIndexVector(labels)); + // accumulate all CEs for all words that have the factor + // Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors. + Expr loss; + for(size_t g = 0; g < numGroups; g++) { + if(!logits_[g]) + continue; // empty factor --@TODO: use an array of indices of non-empty logits_[] + // clang-format off + const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask) + auto factorIndices = indices(maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply + auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor + auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet) + // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next. + auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1] + // clang-format on + if(loss) + factorLoss = cast(factorLoss, loss->value_type()); + factorLoss + = factorLoss + * cast( + reshape(factorMask, factorLoss->shape()), + factorLoss->value_type()); // mask out factor for words that do not have that factor + loss = loss ? (loss + factorLoss) : factorLoss; // [B... x 1] + } + return loss; +} + +// This function assumes this object holds a single factor that represents a rational loss (with +// count). +// Ptr<RationalLoss> Logits::getRationalLoss() const { +// ABORT_IF(logits_.size() != 1 || factoredVocab_, "getRationalLoss() cannot be used on +// multi-factor outputs"); ABORT_IF(!logits_.front()->count(), "getRationalLoss() used on rational +// loss without count"); return logits_.front(); +//} + +// get logits for one factor group +// For groupIndex == 0, the function also requires the shortlist if there is one. +Expr Logits::getFactoredLogits(size_t groupIndex, + Ptr<data::Shortlist> shortlist /*= nullptr*/, + const std::vector<IndexType>& hypIndices /*= {}*/, + size_t beamSize /*= 0*/) const { + ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); + + auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab] + + // normalize for decoding: + // - all secondary factors: subtract their max + // - lemma: add all maxes of applicable factors + if(groupIndex > 0) { + sel = sel - max(sel, -1); + } else { + auto numGroups = getNumFactorGroups(); + for(size_t g = 1; g < numGroups; g++) { + auto factorMaxima = max(logits_[g]->loss(), + -1); // we cast since loss is likely ce-loss which has type float32 + auto factorMasks = constant( + getFactorMasks(g, shortlist ? shortlist->indices() : std::vector<WordIndex>())); + sel = sel + + cast(factorMaxima, sel->value_type()) + * cast(factorMasks, sel->value_type()); // those lemmas that don't have a factor + // get multiplied with 0 + } + } + + // if selIdx are given, then we must reshuffle accordingly + if(!hypIndices.empty()) // use the same function that shuffles decoder state + sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false); + + return sel; +} + +// used for breakDown() only +// Index is flattened +Tensor Logits::getFactoredLogitsTensor(size_t groupIndex) const { + ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); + return logits_[groupIndex]->loss()->val(); +} + +// This function assumes that the object holds one or more factor logits, which are summed up +// into output-vocab logits according to the factored model (with correct normalization of factors). +// This is infeasible for realistic factor sets, and therefore only implemented for 1 factor. +// @TODO: remove altogether +Expr Logits::getLogits() const { + ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); + if(!factoredVocab_) { + ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); + return getFactoredLogits(0); + } + +#ifdef FACTOR_FULL_EXPANSION + // compute normalized factor log probs + std::vector<Expr> logProbs(logits_.size()); + for(size_t g = 0; g < logits_.size(); g++) + logProbs[g] = logsoftmax(logits_[g]->loss()); + auto y = concatenate(logProbs, /*axis=*/-1); + + // clang-format off + // sum up the unit logits across factors for each target word + auto graph = y->graph(); + auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U] + y = dot_csr( + y, // [B x U] + factorMatrix.shape, + graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)), + graph->constant({(int)factorMatrix.indices.size()}, inits::fromVector(factorMatrix.indices), Type::uint32), + graph->constant({(int)factorMatrix.offsets.size()}, inits::fromVector(factorMatrix.offsets), Type::uint32), + /*transB=*/true); // -> [B x V] + // clang-format on + + // mask out gaps + auto gapLogMask = factoredVocab_->getGapLogMask(); // [V] + y = y + graph->constant({(int)gapLogMask.size()}, inits::fromVector(gapLogMask)); + + return y; +#else + ABORT("getLogits() no longer supported for actual factored vocab"); // because it is infeasible +#endif +} + +void Logits::MaskedFactorIndices::push_back(size_t factorIndex) { + bool isValid = FactoredVocab::isFactorValid(factorIndex); + indices.push_back(isValid ? (WordIndex)factorIndex : 0); + masks.push_back((float)isValid); +} + +std::vector<Logits::MaskedFactorIndices> Logits::factorizeWords(const Words& words) + const { // [numGroups][words.size()] -> breaks encoded Word into individual factor indices + if(!factoredVocab_) { + ABORT_IF(logits_.size() != 1, "Factors without factor mappings??"); + return {MaskedFactorIndices(words)}; + } + auto numGroups = factoredVocab_->getNumGroups(); + std::vector<MaskedFactorIndices> res(numGroups); + for(size_t g = 0; g < numGroups; g++) { + auto& resg = res[g]; + resg.reserve(words.size()); + for(const auto& word : words) + resg.push_back(factoredVocab_->getFactor(word, g)); + } + return res; +} + +//// use first factor of each word to determine whether it has a specific factor +// std::vector<float> Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0 +// for words that do have this factor; else 0 +// std::vector<float> res; +// res.reserve(words.size()); +// for (const auto& word : words) { +// auto lemma = factoredVocab_->getFactor(word, 0); +// res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup)); +// } +// return res; +//} + +// return a vector of 1 or 0 indicating for each lemma whether it has a specific factor +// If 'indices' is given, then return the masks for the indices; otherwise for all lemmas +std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) + const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0 + size_t n + = indices.empty() + ? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first) + : indices.size(); + std::vector<float> res; + res.reserve(n); + // @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this + // into FactoredVocab + for(size_t i = 0; i < n; i++) { + auto lemma = indices.empty() ? i : (indices[i] - factoredVocab_->getGroupRange(0).first); + res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup)); + } + return res; +} + +Logits Logits::applyUnaryFunction( + const std::function<Expr(Expr)>& f) const { // clone this but apply f to all loss values + std::vector<Ptr<RationalLoss>> newLogits; + for(const auto& l : logits_) + newLogits.emplace_back(New<RationalLoss>(f(l->loss()), l->count())); + return Logits(std::move(newLogits), factoredVocab_); +} + +Logits Logits::applyUnaryFunctions(const std::function<Expr(Expr)>& f1, + const std::function<Expr(Expr)>& fother) const { + std::vector<Ptr<RationalLoss>> newLogits; + bool first = true; + for(const auto& l : logits_) { + newLogits.emplace_back(New<RationalLoss>((first ? f1 : fother)(l->loss()), + l->count())); // f1 for first, fother for all others + first = false; + } + return Logits(std::move(newLogits), factoredVocab_); +} + +// @TODO: code dup with above; we can merge it into applyToRationalLoss() +Logits Logits::withCounts( + const Expr& count) const { // create new Logits with 'count' implanted into all logits_ + std::vector<Ptr<RationalLoss>> newLogits; + for(const auto& l : logits_) + newLogits.emplace_back(New<RationalLoss>(l->loss(), count)); + return Logits(std::move(newLogits), factoredVocab_); +} +} // namespace marian diff --git a/src/layers/logits.h b/src/layers/logits.h new file mode 100644 index 00000000..c61a9e74 --- /dev/null +++ b/src/layers/logits.h @@ -0,0 +1,106 @@ +#pragma once + +#include "data/shortlist.h" +#include "generic.h" +#include "marian.h" + +namespace marian { + +class FactoredVocab; + +// To support factors, any output projection (that is followed by a softmax) must +// retain multiple outputs, one for each factor. Such layer returns not a single Expr, +// but a Logits object that contains multiple. +// This allows to compute softmax values in a factored manner, where we never create +// a fully expanded list of all factor combinations. +class RationalLoss; +class Logits { +public: + Logits() {} + explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor + logits_.push_back(logits); + } + explicit Logits( + Expr logits); // single-output constructor from Expr only (RationalLoss has no count) + Logits(std::vector<Ptr<RationalLoss>>&& logits, + Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor + : logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {} + Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors + Expr getFactoredLogits( + size_t groupIndex, + Ptr<data::Shortlist> shortlist = nullptr, + const std::vector<IndexType>& hypIndices = {}, + size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle + // Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that + Expr applyLossFunction( + const Words& labels, + const std::function<Expr(Expr /*logits*/, Expr /*indices*/)>& lossFn) const; + Logits applyUnaryFunction( + const std::function<Expr(Expr)>& f) const; // clone this but apply f to all loss values + Logits applyUnaryFunctions(const std::function<Expr(Expr)>& f1, + const std::function<Expr(Expr)>& fother) + const; // clone this but apply f1 to first and fother to to all other values + + struct MaskedFactorIndices { + std::vector<WordIndex> indices; // factor index, or 0 if masked + std::vector<float> masks; + void reserve(size_t n) { + indices.reserve(n); + masks.reserve(n); + } + void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0 + // for invalid entries + MaskedFactorIndices() {} + MaskedFactorIndices(const Words& words) { + indices = toWordIndexVector(words); + } // we can leave masks uninitialized for this special use case + }; + std::vector<MaskedFactorIndices> factorizeWords( + const Words& words) const; // breaks encoded Word into individual factor indices + Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only + size_t getNumFactorGroups() const { return logits_.size(); } + bool empty() const { return logits_.empty(); } + Logits withCounts( + const Expr& count) const; // create new Logits with 'count' implanted into all logits_ +private: + // helper functions + Ptr<ExpressionGraph> graph() const; + Expr constant(const Shape& shape, const std::vector<float>& data) const { + return graph()->constant(shape, inits::fromVector(data)); + } + Expr constant(const Shape& shape, const std::vector<uint32_t>& data) const { + return graph()->constant(shape, inits::fromVector(data)); + } + template <typename T> + Expr constant(const std::vector<T>& data) const { + return constant(Shape{(int)data.size()}, data); + } // same as constant() but assuming vector + Expr indices(const std::vector<uint32_t>& data) const { + return graph()->indices(data); + } // actually the same as constant(data) for this data type + std::vector<float> getFactorMasks(size_t factorGroup, + const std::vector<WordIndex>& indices) const; + +private: + // members + // @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just + // by the Expr + std::vector<Ptr<RationalLoss>> logits_; // [group id][B..., num factors in group] + Ptr<FactoredVocab> factoredVocab_; +}; + +// Unary function that returns a Logits object +// Also implements IUnaryLayer, since Logits can be cast to Expr. +// This interface is implemented by all layers that are of the form of a unary function +// that returns multiple logits, to support factors. +struct IUnaryLogitLayer : public IUnaryLayer { + virtual Logits applyAsLogits(Expr) = 0; + virtual Logits applyAsLogits(const std::vector<Expr>& es) { + ABORT_IF(es.size() > 1, "Not implemented"); // simple stub + return applyAsLogits(es.front()); + } + virtual Expr apply(Expr e) override { return applyAsLogits(e).getLogits(); } + virtual Expr apply(const std::vector<Expr>& es) override { return applyAsLogits(es).getLogits(); } +}; + +} // namespace marian diff --git a/src/layers/loss.cpp b/src/layers/loss.cpp index 67d38832..695276af 100755..100644 --- a/src/layers/loss.cpp +++ b/src/layers/loss.cpp @@ -13,26 +13,30 @@ Ptr<LabelwiseLoss> newLoss(Ptr<Options> options, bool inference) { bool wordScores = options->get<bool>("word-scores", false); return New<RescorerLoss>(wordScores); } else if(unlikelihood) { - ABORT_IF(!options->hasAndNotEmpty("data-weighting") - && options->get<std::string>("data-weighting-type") != "word", - "Unlikelihood loss training requires error annotation in form of per-target-label scores"); - return New<SequenceUnlikelihoodLoss>(smoothing, factorWeight); // this is a mix of CE-loss and unlikelihood less depending on values given for data-weighting - } else { // same as ce-mean --@TODO: better check all allowed values, and fail for invalid ones. E.g. what about ce-sum? + ABORT_IF( + !options->hasAndNotEmpty("data-weighting") + && options->get<std::string>("data-weighting-type") != "word", + "Unlikelihood loss training requires error annotation in form of per-target-label scores"); + return New<SequenceUnlikelihoodLoss>( + smoothing, factorWeight); // this is a mix of CE-loss and unlikelihood less depending on + // values given for data-weighting + } else { // same as ce-mean --@TODO: better check all allowed values, and fail for invalid ones. + // E.g. what about ce-sum? return New<CrossEntropyLoss>(smoothing, factorWeight); } } // see loss.h for detailed explanations of each class Ptr<MultiRationalLoss> newMultiLoss(Ptr<Options> options) { - std::string multiLossType = options->get<std::string>("multi-loss-type", "sum"); - if(multiLossType == "sum") // sum of sums - return New<SumMultiRationalLoss>(); - else if(multiLossType == "scaled") // sum of scaled sums, first element is reference scale - return New<ScaledMultiRationalLoss>(); - else if(multiLossType == "mean") // sum of means - return New<MeanMultiRationalLoss>(); - else - ABORT("Unknown multi-loss-type {}", multiLossType); + std::string multiLossType = options->get<std::string>("multi-loss-type", "sum"); + if(multiLossType == "sum") // sum of sums + return New<SumMultiRationalLoss>(); + else if(multiLossType == "scaled") // sum of scaled sums, first element is reference scale + return New<ScaledMultiRationalLoss>(); + else if(multiLossType == "mean") // sum of means + return New<MeanMultiRationalLoss>(); + else + ABORT("Unknown multi-loss-type {}", multiLossType); } } // namespace marian diff --git a/src/layers/loss.h b/src/layers/loss.h index d7bc19e4..c662f991 100755..100644 --- a/src/layers/loss.h +++ b/src/layers/loss.h @@ -1,8 +1,8 @@ #pragma once -#include "graph/expression_operators.h" -#include "layers/generic.h" // for Logits (Frank's factor hack) #include "data/types.h" +#include "graph/expression_operators.h" +#include "layers/logits.h" // for Logits (Frank's factor hack) namespace marian { @@ -22,21 +22,18 @@ namespace marian { */ class RationalLoss { protected: - Expr loss_; // numerator - Expr count_; // denominator + Expr loss_; // numerator + Expr count_; // denominator - RationalLoss() = default; // protected + RationalLoss() = default; // protected public: - RationalLoss(Expr loss, Expr count) - : loss_(loss), count_(count) {} + RationalLoss(Expr loss, Expr count) : loss_(loss), count_(count) {} RationalLoss(Expr loss, float count) - : loss_(loss), - count_(constant_like(loss, inits::fromValue(count))) {} + : loss_(loss), count_(constant_like(loss, inits::fromValue(count))) {} - RationalLoss(const RationalLoss& other) - : loss_(other.loss_), count_(other.count_) {} + RationalLoss(const RationalLoss& other) : loss_(other.loss_), count_(other.count_) {} virtual ~RationalLoss() = default; @@ -50,7 +47,7 @@ public: } template <typename T> - T loss() const { // this will fail if loss is not a single value + T loss() const { // this will fail if loss is not a single value ABORT_IF(!loss_, "Loss has not been defined"); return loss_->val()->scalar<T>(); } @@ -65,7 +62,7 @@ public: } template <typename T> - T count() const { // this will fail if loss is not a single value + T count() const { // this will fail if loss is not a single value ABORT_IF(!count_, "Labels have not been defined"); return count_->val()->scalar<T>(); } @@ -85,21 +82,21 @@ public: * RationalLoss object. */ struct StaticLoss { - float loss; // numerator - float count; // denominator + float loss; // numerator + float count; // denominator StaticLoss() : loss(0.f), count(0.f) {} StaticLoss(const RationalLoss& dynamic) - : loss(dynamic.loss<float>()), count(dynamic.count<float>()) {} + : loss(dynamic.loss<float>()), count(dynamic.count<float>()) {} - StaticLoss operator +(const StaticLoss& other) const { + StaticLoss operator+(const StaticLoss& other) const { StaticLoss res(*this); res += other; return res; } - StaticLoss& operator +=(const StaticLoss& other) { + StaticLoss& operator+=(const StaticLoss& other) { loss = loss + other.loss; count = count + other.count; return *this; @@ -139,32 +136,21 @@ protected: public: MultiRationalLoss() : RationalLoss() {} - MultiRationalLoss(const RationalLoss& rl) : RationalLoss() { - push_back(rl); - } + MultiRationalLoss(const RationalLoss& rl) : RationalLoss() { push_back(rl); } virtual void push_back(const RationalLoss& current) { - loss_ = accumulateLoss(current); - count_ = accumulateCount(current); + loss_ = accumulateLoss(current); + count_ = accumulateCount(current); partialLosses_.push_back(current); } - const RationalLoss& operator[](size_t i) { - return partialLosses_[i]; - } + const RationalLoss& operator[](size_t i) { return partialLosses_[i]; } - auto begin() -> decltype(partialLosses_.begin()) const { - return partialLosses_.begin(); - } + auto begin() -> decltype(partialLosses_.begin()) const { return partialLosses_.begin(); } - auto end() -> decltype(partialLosses_.end()) const { - return partialLosses_.end(); - } - - size_t size() const { - return partialLosses_.size(); - } + auto end() -> decltype(partialLosses_.end()) const { return partialLosses_.end(); } + size_t size() const { return partialLosses_.size(); } }; /** @@ -212,17 +198,19 @@ private: virtual Expr accumulateLoss(const RationalLoss& current) override { if(loss_) { const auto& first = partialLosses_.front(); - return loss_ + current.loss() * first.count() / current.count(); // scale up/down to match scale of first loss + return loss_ + + current.loss() * first.count() + / current.count(); // scale up/down to match scale of first loss } else { - return current.loss(); // first reference loss, keeps to scale with this one + return current.loss(); // first reference loss, keeps to scale with this one } } virtual Expr accumulateCount(const RationalLoss& current) override { if(count_) { - return count_; // Keep first label count // or: count_ + first.count() / current.count(); + return count_; // Keep first label count // or: count_ + first.count() / current.count(); } else { - return current.count(); // This is the first loss + return current.count(); // This is the first loss } } @@ -253,9 +241,10 @@ private: virtual Expr accumulateCount(const RationalLoss& current) override { if(count_) - return count_; // keep the existing '1' + return count_; // keep the existing '1' else - return current.count()->graph()->ones({1}, current.loss()->value_type()); // just '1' as labels are factored into loss_ + return current.count()->graph()->ones( + {1}, current.loss()->value_type()); // just '1' as labels are factored into loss_ } public: @@ -279,18 +268,21 @@ class LabelwiseLoss { protected: std::vector<int> axes_; - virtual Expr compute(Logits logits, const Words& labels, - Expr mask = nullptr, Expr labelWeights = nullptr) = 0; + virtual Expr compute(Logits logits, + const Words& labels, + Expr mask = nullptr, + Expr labelWeights = nullptr) + = 0; // label counts are available, reduce together with loss to obtain counts RationalLoss reduce(Expr loss, Expr labels) { ABORT_IF(!loss, "Loss has not been computed"); ABORT_IF(!labels, "Labels have not been computed"); - Expr lossSum = cast(loss, Type::float32); // accumulate in float32 - Expr labelsSum = cast(labels, Type::float32); // accumulate in float32 + Expr lossSum = cast(loss, Type::float32); // accumulate in float32 + Expr labelsSum = cast(labels, Type::float32); // accumulate in float32 for(int i = 0; i < axes_.size(); ++i) { - lossSum = sum(lossSum, axes_[i]); + lossSum = sum(lossSum, axes_[i]); labelsSum = sum(labelsSum, axes_[i]); } @@ -301,7 +293,7 @@ protected: RationalLoss reduce(Expr loss) { ABORT_IF(!loss, "Loss has not been computed"); - Expr lossSum = cast(loss, Type::float32); + Expr lossSum = cast(loss, Type::float32); for(int i = 0; i < axes_.size(); ++i) lossSum = sum(lossSum, axes_[i]); @@ -311,17 +303,18 @@ protected: } public: - LabelwiseLoss(const std::vector<int>& axes) - : axes_(axes) { } + LabelwiseLoss(const std::vector<int>& axes) : axes_(axes) {} - virtual RationalLoss apply(Logits logits, const Words& labels, - Expr mask = nullptr, Expr labelWeights = nullptr) { + virtual RationalLoss apply(Logits logits, + const Words& labels, + Expr mask = nullptr, + Expr labelWeights = nullptr) { Expr loss = compute(logits, labels, mask, labelWeights); if(mask) - return reduce(loss, mask); // mask can be used as element-wise label count with broadcasting + return reduce(loss, mask); // mask can be used as element-wise label count with broadcasting else - return reduce(loss); // we have no mask, assume all items are labels + return reduce(loss); // we have no mask, assume all items are labels } }; @@ -331,28 +324,34 @@ public: class CrossEntropyLoss : public LabelwiseLoss { public: CrossEntropyLoss(float labelSmoothing, float factorWeight) - : CrossEntropyLoss(/*axes=*/{-2, -3}, labelSmoothing, factorWeight) {} // cross-entropy already reduces over axis -1 + : CrossEntropyLoss(/*axes=*/{-2, -3}, labelSmoothing, factorWeight) { + } // cross-entropy already reduces over axis -1 CrossEntropyLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight) - : LabelwiseLoss(axes), // cross-entropy already reduces over axis -1 - labelSmoothing_(labelSmoothing), factorWeight_(factorWeight) {} + : LabelwiseLoss(axes), // cross-entropy already reduces over axis -1 + labelSmoothing_(labelSmoothing), + factorWeight_(factorWeight) {} virtual ~CrossEntropyLoss() {} -protected: - float labelSmoothing_; // interpolation factor for label smoothing, see below - float factorWeight_; // give extra weight to factors - virtual Expr compute(Logits logits, const Words& labels, - Expr mask = nullptr, Expr labelWeights = nullptr) override { - // logits may be factored; in that case, the getLoss() function computes one loss for each, and sums them up +protected: + float labelSmoothing_; // interpolation factor for label smoothing, see below + float factorWeight_; // give extra weight to factors + + virtual Expr compute(Logits logits, + const Words& labels, + Expr mask = nullptr, + Expr labelWeights = nullptr) override { + // logits may be factored; in that case, the getLoss() function computes one loss for each, and + // sums them up int inFactor = false; auto ce = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) { - logits = atleast_3d(logits); // we always assume a time and batch dimension exists. + logits = atleast_3d(logits); // we always assume a time and batch dimension exists. // for bert training or classification the time dimension is lost. // Here safeguard against 2d classifier output, adds 1 on the left, non-op. - + Expr ce = cross_entropy(logits, indices, inFactor ? 0.f : labelSmoothing_, Type::float32); - if (inFactor && factorWeight_ != 1.0f) { + if(inFactor && factorWeight_ != 1.0f) { LOG_ONCE(info, "scaling factor losses with weight {}", factorWeight_); ce = ce * factorWeight_; } @@ -365,8 +364,10 @@ protected: if(labelWeights) { // We currently do not know how to use target factors and word-level label weights together - bool wordlevel = labelWeights->shape()[-3] > 1; // Time-dimension is not trivially 1, hence we have word-level weights. - ABORT_IF(wordlevel && logits.getNumFactorGroups() > 1, "CE loss with word-level label weights is not implemented for factors"); + bool wordlevel = labelWeights->shape()[-3] + > 1; // Time-dimension is not trivially 1, hence we have word-level weights. + ABORT_IF(wordlevel && logits.getNumFactorGroups() > 1, + "CE loss with word-level label weights is not implemented for factors"); ce = ce * cast(labelWeights, Type::float32); } @@ -374,13 +375,12 @@ protected: } }; - /** * @brief Unlikelihood loss across last axis, summed up over batch and time dimensions. This is an * implementation of sequence-level unlikelihood loss from https://arxiv.org/abs/1908.04319. - * We rely on word-level label weights where 1 is correct and 0 is marking an error. If there are not - * zeros for a sentence it going to be trained with normal CE loss if there is at least one 0 it is going - * to flip over to use SUL for that sentence to penalize the selected word. + * We rely on word-level label weights where 1 is correct and 0 is marking an error. If there are + * not zeros for a sentence it going to be trained with normal CE loss if there is at least one 0 it + * is going to flip over to use SUL for that sentence to penalize the selected word. * * SUL is implemented as: * -log(gather(1 - softmax(logits), -1, indices)) @@ -390,35 +390,45 @@ protected: class SequenceUnlikelihoodLoss : public CrossEntropyLoss { public: SequenceUnlikelihoodLoss(float labelSmoothing, float factorWeight) - : CrossEntropyLoss(labelSmoothing, factorWeight) {} // cross-entropy already reduces over axis -1 + : CrossEntropyLoss(labelSmoothing, factorWeight) { + } // cross-entropy already reduces over axis -1 SequenceUnlikelihoodLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight) - : CrossEntropyLoss(axes, labelSmoothing, factorWeight) {} + : CrossEntropyLoss(axes, labelSmoothing, factorWeight) {} protected: - virtual Expr compute(Logits logits, const Words& labels, - Expr mask = nullptr, Expr labelWeights = nullptr) override { - auto ce = CrossEntropyLoss::compute(logits, labels, mask, /*labelWeights=*/nullptr); // don't pass label-weights to CE + virtual Expr compute(Logits logits, + const Words& labels, + Expr mask = nullptr, + Expr labelWeights = nullptr) override { + auto ce = CrossEntropyLoss::compute( + logits, labels, mask, /*labelWeights=*/nullptr); // don't pass label-weights to CE if(!labelWeights) - return ce; // for validation, @TODO: maybe put rather abort or LOG_ONCE(warn, ...)? + return ce; // for validation, @TODO: maybe put rather abort or LOG_ONCE(warn, ...)? // We currently do not know how to use target factors and word-level label weights together ABORT_IF(logits.getNumFactorGroups() > 1, "Unlikelihood loss is not implemented for factors"); - ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by default 1, which would make this obsolete. - // use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD, mask again to eliminate padding (might be obsolete) + ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by + // default 1, which would make this obsolete. + // use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD, mask + // again to eliminate padding (might be obsolete) auto errorMask = (1.f - cast(labelWeights, Type::float32)) * cast(mask, Type::float32); auto ceUl = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) { return cast(unlikelihood(logits, indices), Type::float32); }); - // compute if want to use CE or UL. If there are no errors train with CE, otherwise train _only on_ the errors with UL. This is the "mixed" training - // schedule from https://arxiv.org/abs/1908.04319. Providing labels with or without error scores we can easily switch between CE and UL. - auto onlyCe = eq(sum(errorMask, /*axis=*/-3), 0.f); // [1, 1, dimBatch, 1] - equal 1 if no errors are present - ceUl = errorMask * ceUl; // don't use for correct label or padding + // compute if want to use CE or UL. If there are no errors train with CE, otherwise train _only + // on_ the errors with UL. This is the "mixed" training schedule from + // https://arxiv.org/abs/1908.04319. Providing labels with or without error scores we can easily + // switch between CE and UL. + auto onlyCe = eq(sum(errorMask, /*axis=*/-3), + 0.f); // [1, 1, dimBatch, 1] - equal 1 if no errors are present + ceUl = errorMask * ceUl; // don't use for correct label or padding - auto cost = onlyCe * ce + (1.f - onlyCe) * ceUl; // ce or unlikelihood part are never simultanously used as cost per batch entry + auto cost = onlyCe * ce + (1.f - onlyCe) * ceUl; // ce or unlikelihood part are never + // simultanously used as cost per batch entry return cost; } @@ -463,7 +473,6 @@ public: } }; - /** * @brief Factory for label-wise loss functions */ diff --git a/src/layers/output.cpp b/src/layers/output.cpp new file mode 100644 index 00000000..1d9c7b4b --- /dev/null +++ b/src/layers/output.cpp @@ -0,0 +1,293 @@ +#include "output.h" +#include "common/timer.h" +#include "data/factored_vocab.h" +#include "layers/loss.h" +#include "layers/lsh.h" + +namespace marian { +namespace mlp { + +/*private*/ void Output::lazyConstruct(int inputDim) { + // We must construct lazily since we won't know tying nor input dim in constructor. + if(Wt_) + return; + + // this option is only set in the decoder + if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) { + auto k = opt<std::vector<int>>("output-approx-knn")[0]; + auto nbits = opt<std::vector<int>>("output-approx-knn")[1]; + lsh_ = New<LSH>(k, nbits); + } + + auto name = options_->get<std::string>("prefix"); + auto numOutputClasses = options_->get<int>("dim"); + + factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", "")); + if(factoredVocab_) { + numOutputClasses = (int)factoredVocab_->factorVocabSize(); + LOG_ONCE(info, "[embedding] Factored outputs enabled"); + } + + if(tiedParam_) { + Wt_ = tiedParam_; + } else { + if(graph_->get(name + "_W")) { // support of legacy models that did not transpose + Wt_ = graph_->param( + name + "_W", {inputDim, numOutputClasses}, inits::glorotUniform(true, false)); + isLegacyUntransposedW = true; + } else // this is the regular case: + Wt_ = graph_->param( + name + "_Wt", {numOutputClasses, inputDim}, inits::glorotUniform(false, true)); + } + + if(hasBias_) + b_ = graph_->param(name + "_b", {1, numOutputClasses}, inits::zeros()); + + /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0); + ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary"); + if(lemmaDimEmb > 0) { // > 0 means to embed the (expected) word with a different embedding matrix +#define HARDMAX_HACK +#ifdef HARDMAX_HACK + lemmaDimEmb = lemmaDimEmb & 0xfffffffe; // hack to select hard-max: use an odd number +#endif + auto range = factoredVocab_->getGroupRange(0); + auto lemmaVocabDim = (int)(range.second - range.first); + auto initFunc = inits::glorotUniform( + /*fanIn=*/true, /*fanOut=*/false); // -> embedding vectors have roughly unit length + lemmaEt_ = graph_->param(name + "_lemmaEt", + {lemmaDimEmb, lemmaVocabDim}, + initFunc); // [L x U] L=lemmaDimEmb; transposed for speed + } +} + +Logits Output::applyAsLogits(Expr input) /*override final*/ { + lazyConstruct(input->shape()[-1]); + + auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) { + if(b) + return affine(x, W, b, transA, transB); + else + return dot(x, W, transA, transB); + }; + + auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) { + if(lsh_) { + ABORT_IF(transA, "Transposed query not supported for LSH"); + ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH"); + return lsh_->apply(x, W, b); // knows how to deal with undefined bias + } else { + return affineOrDot(x, W, b, transA, transB); + } + }; + + if(shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one + // batch, then clear()ed + cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices()); + if(hasBias_) + cachedShortb_ = index_select(b_, -1, shortlist_->indices()); + } + + if(factoredVocab_) { + auto graph = input->graph(); + + // project each factor separately + auto numGroups = factoredVocab_->getNumGroups(); + std::vector<Ptr<RationalLoss>> allLogits(numGroups, + nullptr); // (note: null entries for absent factors) + Expr input1 = input; // [B... x D] + Expr Plemma = nullptr; // used for lemmaDimEmb=-1 + Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3 + for(size_t g = 0; g < numGroups; g++) { + auto range = factoredVocab_->getGroupRange(g); + if(g > 0 && range.first == range.second) // empty entry + continue; + ABORT_IF(g > 0 && range.first != factoredVocab_->getGroupRange(g - 1).second, + "Factor groups must be consecutive (group {} vs predecessor)", + g); + // slice this group's section out of W_ + Expr factorWt, factorB; + if(g == 0 && shortlist_) { + factorWt = cachedShortWt_; + factorB = cachedShortb_; + } else { + factorWt = slice( + Wt_, isLegacyUntransposedW ? -1 : 0, Slice((int)range.first, (int)range.second)); + if(hasBias_) + factorB = slice(b_, -1, Slice((int)range.first, (int)range.second)); + } + /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0); + if((lemmaDimEmb == -2 || lemmaDimEmb == -3) + && g > 0) { // -2/-3 means a gated transformer-like structure (-3 = hard-max) + LOG_ONCE(info, "[embedding] using lemma conditioning with gate"); + // this mimics one transformer layer + // - attention over two inputs: + // - e = current lemma. We use the original embedding vector; specifically, expectation + // over all lemmas. + // - input = hidden state FF(h_enc+h_dec) + // - dot-prod attention to allow both sides to influence (unlike our recurrent + // self-attention) + // - multi-head to allow for multiple conditions to be modeled + // - add & norm, for gradient flow and scaling + // - FF layer --this is expensive; it is per-factor + // multi-head attention + int inputDim = input->shape()[-1]; + int heads = 8; + auto name = options_->get<std::string>("prefix") + "_factor" + std::to_string(g); + auto Wq = graph_->param(name + "_Wq", {inputDim, inputDim}, inits::glorotUniform()); + auto Wk = graph_->param(name + "_Wk", {inputDim, inputDim}, inits::glorotUniform()); + auto Wv = graph_->param(name + "_Wv", {inputDim, inputDim}, inits::glorotUniform()); + auto toMultiHead = [&](Expr x, int heads) { + const auto& shape = x->shape(); + int inputDim = shape[-1]; + int otherDim = shape.elements() / inputDim; + ABORT_IF(inputDim / heads * heads != inputDim, + "inputDim ({}) must be multiple of number of heads ({})", + inputDim, + heads); + return reshape(x, {otherDim, heads, 1, inputDim / heads}); + }; + input1 = inputLemma; + auto qm = toMultiHead(dot(input1, Wq), heads); // [B... x H x D/H] projected query + auto kdm = toMultiHead(dot(input1 - input, Wk), + heads); // [B... x H x D/H] the two data vectors projected as keys. + // Use diff and sigmoid, instead of softmax. + auto vem = toMultiHead( + dot(input1, Wv), + heads); // [B... x H x D/H] one of the two data vectors projected as values + auto vim = toMultiHead(dot(input, Wv), heads); // [B... x H x D/H] the other + auto zm = bdot(qm, kdm, false, true); // [B... x H x 1] + auto sm = sigmoid(zm); // [B... x H x 1] + auto rm = sm * (vem - vim) + vim; // [B... x H x D/H] + auto r = reshape(rm, input->shape()); // [B... x D] + // add & norm + input1 = r + input1; + input1 = layerNorm(input1, name + "_att"); + // FF layer + auto ffnDropProb = 0.1f; // @TODO: get as a parameter + auto ffnDim = inputDim * 2; // @TODO: get as a parameter + auto f = denseInline(input1, + name + "_ffn", + /*suffix=*/"1", + ffnDim, + inits::glorotUniform(), + (ActivationFunction*)relu, + ffnDropProb); + f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim); + // add & norm + input1 = f + input1; + input1 = layerNorm(input1, name + "_ffn"); + } + // @TODO: b_ should be a vector, not a matrix; but shotlists use cols() in, which requires a + // matrix + Expr factorLogits; + if(g == 0) + factorLogits = affineOrLSH( + input1, + factorWt, + factorB, + false, + /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits + else + factorLogits = affineOrDot( + input1, + factorWt, + factorB, + false, + /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits + + // optionally add lemma-dependent bias + if(Plemma) { // [B... x U0] + int lemmaVocabDim = Plemma->shape()[-1]; + int factorVocabDim = factorLogits->shape()[-1]; + auto name = options_->get<std::string>("prefix"); + Expr lemmaBt + = graph_->param(name + "_lemmaBt_" + std::to_string(g), + {factorVocabDim, lemmaVocabDim}, + inits::zeros()); // [U x U0] U0=#lemmas one bias per class per lemma + auto b = dot(Plemma, lemmaBt, false, true); // [B... x U] + factorLogits = factorLogits + b; + } + allLogits[g] = New<RationalLoss>(factorLogits, nullptr); + // optionally add a soft embedding of lemma back to create some lemma dependency + // @TODO: if this works, move it into lazyConstruct + if(lemmaDimEmb == -2 && g == 0) { // -2 means a gated transformer-like structure + LOG_ONCE(info, "[embedding] using lemma conditioning with gate, soft-max version"); + // get expected lemma embedding vector + auto factorLogSoftmax = logsoftmax( + factorLogits); // [B... x U] note: with shortlist, this is not the full lemma set + auto factorSoftmax = exp(factorLogSoftmax); + inputLemma = dot(factorSoftmax, + factorWt, + false, + /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D] + } else if(lemmaDimEmb == -3 && g == 0) { // same as -2 except with hard max + LOG_ONCE(info, "[embedding] using lemma conditioning with gate, hard-max version"); + // get max-lemma embedding vector + auto maxVal = max(factorLogits, + -1); // [B... x U] note: with shortlist, this is not the full lemma set + auto factorHardmax = eq(factorLogits, maxVal); + inputLemma = dot(factorHardmax, + factorWt, + false, + /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D] + } else if(lemmaDimEmb == -1 && g == 0) { // -1 means learn a lemma-dependent bias + ABORT_IF(shortlist_, "Lemma-dependent bias with short list is not yet implemented"); + LOG_ONCE(info, "[embedding] using lemma-dependent bias"); + auto factorLogSoftmax + = logsoftmax(factorLogits); // (we do that again later, CSE will kick in) + auto z = /*stopGradient*/ (factorLogSoftmax); + Plemma = exp(z); // [B... x U] + } else if(lemmaDimEmb > 0 && g == 0) { // > 0 means learn a re-embedding matrix + LOG_ONCE(info, "[embedding] enabled re-embedding of lemma, at dim {}", lemmaDimEmb); + // compute softmax. We compute logsoftmax() separately because this way, computation will be + // reused later via CSE + auto factorLogSoftmax = logsoftmax(factorLogits); + auto factorSoftmax = exp(factorLogSoftmax); +#ifdef HARDMAX_HACK + bool hardmax = (lemmaDimEmb & 1) + != 0; // odd value triggers hardmax for now (for quick experimentation) + if(hardmax) { + lemmaDimEmb = lemmaDimEmb & 0xfffffffe; + LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb); + auto maxVal = max(factorSoftmax, -1); + factorSoftmax = eq(factorSoftmax, maxVal); + } +#endif + // re-embedding lookup, soft-indexed by softmax + if(shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix + cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices()); + auto e = dot(factorSoftmax, + cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_, + false, + true); // [B... x L] + // project it back to regular hidden dim + int inputDim = input1->shape()[-1]; + auto name = options_->get<std::string>("prefix"); + // note: if the lemmaEt[:,w] have unit length (var = 1/L), then lemmaWt @ lemmaEt is also + // length 1 + Expr lemmaWt + = inputDim == lemmaDimEmb + ? nullptr + : graph_->param(name + "_lemmaWt", + {inputDim, lemmaDimEmb}, + inits::glorotUniform()); // [D x L] D=hidden-vector dimension + auto f = lemmaWt ? dot(e, lemmaWt, false, true) : e; // [B... x D] + // augment the original hidden vector with this additional information + input1 = input1 + f; + } + } + return Logits(std::move(allLogits), factoredVocab_); + } else if(shortlist_) { + return Logits(affineOrLSH(input, + cachedShortWt_, + cachedShortb_, + false, + /*transB=*/isLegacyUntransposedW ? false : true)); + } else { + return Logits( + affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true)); + } +} + +} // namespace mlp +} // namespace marian
\ No newline at end of file diff --git a/src/layers/output.h b/src/layers/output.h new file mode 100644 index 00000000..2b6f4986 --- /dev/null +++ b/src/layers/output.h @@ -0,0 +1,75 @@ +#pragma once + +#include "data/shortlist.h" +#include "generic.h" +#include "layers/factory.h" +#include "logits.h" +#include "marian.h" + +namespace marian { +class LSH; + +namespace mlp { + +class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList { +private: + // parameters held by this layer + Expr Wt_; // weight matrix is stored transposed for efficiency + Expr b_; + Expr lemmaEt_; // re-embedding matrix for lemmas [lemmaDimEmb x lemmaVocabSize] + bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form + bool hasBias_{true}; + + Expr cachedShortWt_; // short-listed version, cached (cleared by clear()) + Expr cachedShortb_; // these match the current value of shortlist_ + Expr cachedShortLemmaEt_; + Ptr<FactoredVocab> factoredVocab_; + + // optional parameters set/updated after construction + Expr tiedParam_; + Ptr<data::Shortlist> shortlist_; + Ptr<LSH> lsh_; + + void lazyConstruct(int inputDim); + +public: + Output(Ptr<ExpressionGraph> graph, Ptr<Options> options) + : LayerBase(graph, options), hasBias_{!options->get<bool>("output-omit-bias", false)} { + clear(); + } + + void tieTransposed(Expr tied) { + if(Wt_) + ABORT_IF(tiedParam_.get() != tied.get(), + "Tied output projection cannot be changed once weights have been created"); + else + tiedParam_ = tied; + } + + void setShortlist(Ptr<data::Shortlist> shortlist) override final { + if(shortlist_) + ABORT_IF(shortlist.get() != shortlist_.get(), + "Output shortlist cannot be changed except after clear()"); + else { + ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_, + "No shortlist but cached parameters??"); + shortlist_ = shortlist; + } + // cachedShortWt_ and cachedShortb_ will be created lazily inside apply() + } + + // this is expected to be called in sync with graph->clear(), which invalidates + // cachedShortWt_ etc. in the graph's short-term cache + void clear() override final { + shortlist_ = nullptr; + cachedShortWt_ = nullptr; + cachedShortb_ = nullptr; + cachedShortLemmaEt_ = nullptr; + } + + Logits applyAsLogits(Expr input) override final; +}; + +} // namespace mlp + +} // namespace marian diff --git a/src/microsoft/quicksand.cpp b/src/microsoft/quicksand.cpp index 6476df8f..6476df8f 100755..100644 --- a/src/microsoft/quicksand.cpp +++ b/src/microsoft/quicksand.cpp diff --git a/src/microsoft/quicksand.h b/src/microsoft/quicksand.h index 87de1948..87de1948 100755..100644 --- a/src/microsoft/quicksand.h +++ b/src/microsoft/quicksand.h diff --git a/src/microsoft/shortlist/logging/LoggerMacros.h b/src/microsoft/shortlist/logging/LoggerMacros.h new file mode 100644 index 00000000..ca74e737 --- /dev/null +++ b/src/microsoft/shortlist/logging/LoggerMacros.h @@ -0,0 +1,25 @@ +#pragma once + +// Do NOT include this file directly except in special circumstances. +// (E.g., you want to define macros which call these but don't want to include Logger.h everywhere). +// Normally you should include logging/Logger.h + +#define LOG_WRITE(format, ...) do {\ + abort(); \ +} while (0) + +#define LOG_WRITE_STRING(str) do {\ + abort(); \ +} while (0) + +#define LOG_ERROR(format, ...) do {\ + abort(); \ +} while (0) + +#define LOG_ERROR_AND_THROW(format, ...) do {\ + abort(); \ +} while (0) + +#define DECODING_LOGIC_ERROR(format, ...) do {\ + abort(); \ +} while (0) diff --git a/src/microsoft/shortlist/utils/Converter.cpp b/src/microsoft/shortlist/utils/Converter.cpp new file mode 100644 index 00000000..c28178cd --- /dev/null +++ b/src/microsoft/shortlist/utils/Converter.cpp @@ -0,0 +1,59 @@ +#include "microsoft/shortlist/utils/Converter.h" + +namespace quicksand { + +#include "microsoft/shortlist/logging/LoggerMacros.h" + + +int64_t Converter::ToInt64(const std::string& str) { + return ConvertSingleInternal<int64_t>(str, "int64_t"); +} + +uint64_t Converter::ToUInt64(const std::string& str) { + return ConvertSingleInternal<uint64_t>(str, "int64_t"); +} + +int32_t Converter::ToInt32(const std::string& str) { + return ConvertSingleInternal<int32_t>(str, "int32_t"); +} + +float Converter::ToFloat(const std::string& str) { + // In case the value is out of range of a 32-bit float, but in range of a 64-bit double, + // it's better to convert as a double and then do the conersion. + return (float)ConvertSingleInternal<double>(str, "float"); +} + +double Converter::ToDouble(const std::string& str) { + return ConvertSingleInternal<double>(str, "double"); +} + +bool Converter::ToBool(const std::string& str) { + bool value = false; + if (!TryConvert(str, /* out */ value)) { + LOG_ERROR_AND_THROW("The string '%s' is not interpretable as the type 'bool'", str.c_str()); + } + return value; +} + +std::vector<int32_t> Converter::ToInt32Vector(const std::vector<std::string>& items) { + return ConvertVectorInternal<int32_t, std::vector<std::string>::const_iterator>(items.begin(), items.end(), "int32_t"); +} + +std::vector<int64_t> Converter::ToInt64Vector(const std::vector<std::string>& items) { + return ConvertVectorInternal<int64_t, std::vector<std::string>::const_iterator>(items.begin(), items.end(), "int64_t"); +} + +std::vector<float> Converter::ToFloatVector(const std::vector<std::string>& items) { + return ConvertVectorInternal<float, std::vector<std::string>::const_iterator>(items.begin(), items.end(), "float"); +} + +std::vector<double> Converter::ToDoubleVector(const std::vector<std::string>& items) { + return ConvertVectorInternal<double, std::vector<std::string>::const_iterator>(items.begin(), items.end(), "double"); +} + +void Converter::HandleConversionError(const std::string& str, const char * type_name) { + str; type_name; // make compiler happy + LOG_ERROR_AND_THROW("The string '%s' is not interpretable as the type '%s'", str.c_str(), type_name); +} + +} // namespace quicksand diff --git a/src/microsoft/shortlist/utils/Converter.h b/src/microsoft/shortlist/utils/Converter.h new file mode 100644 index 00000000..9d9dd96d --- /dev/null +++ b/src/microsoft/shortlist/utils/Converter.h @@ -0,0 +1,83 @@ +#pragma once + +#include <stdint.h> +#include <string> +#include <vector> +#include <sstream> + +namespace quicksand { + +class Converter { +public: + static int32_t ToInt32(const std::string& str); + + static int64_t ToInt64(const std::string& str); + + static uint64_t ToUInt64(const std::string& str); + + static float ToFloat(const std::string& str); + + static double ToDouble(const std::string& str); + + static bool ToBool(const std::string& str); + + static std::vector<int32_t> ToInt32Vector(const std::vector<std::string>& items); + + static std::vector<int64_t> ToInt64Vector(const std::vector<std::string>& items); + + static std::vector<float> ToFloatVector(const std::vector<std::string>& items); + + static std::vector<double> ToDoubleVector(const std::vector<std::string>& items); + + static bool TryConvert(const std::string& str, /* out*/ bool& obj) { + if (str == "True" || str == "true" || str == "TRUE" || str == "Yes" || str == "yes" || str == "1") { + obj = true; + return true; + } + else if (str == "False" || str == "false" || str == "FALSE" || str == "No" || str == "no" || str == "0") { + obj = false; + return true; + } + return false; + } + + template <typename T> + static bool TryConvert(const std::string& str, /* out*/ T& value) { + std::istringstream ss(str); + value = T(); + if (!(ss >> value)) { + return false; + } + return true; + } + +private: + template <typename T> + static T ConvertSingleInternal(const std::string& str, const char * type_name); + + template <typename T, typename I> + static std::vector<T> ConvertVectorInternal(I begin, I end, const char * type_name); + + static void HandleConversionError(const std::string& str, const char * type_name); +}; + +template <typename T> +T Converter::ConvertSingleInternal(const std::string& str, const char * type_name) { + std::istringstream ss(str); + T value = T(); + if (!(ss >> value)) { + HandleConversionError(str, type_name); + } + return value; +} + +template <typename T, typename I> +std::vector<T> Converter::ConvertVectorInternal(I begin, I end, const char * type_name) { + std::vector<T> items; + for (I it = begin; it != end; it++) { + items.push_back(ConvertSingleInternal<T>(*it, type_name)); + } + return items; +} + +} // namespace quicksand diff --git a/src/microsoft/shortlist/utils/ParameterTree.cpp b/src/microsoft/shortlist/utils/ParameterTree.cpp new file mode 100644 index 00000000..465d2e0d --- /dev/null +++ b/src/microsoft/shortlist/utils/ParameterTree.cpp @@ -0,0 +1,417 @@ +#include "microsoft/shortlist/utils/ParameterTree.h" + +#include <string> + +#include "microsoft/shortlist/utils/StringUtils.h" +#include "microsoft/shortlist/utils/Converter.h" + +namespace quicksand { + +#include "microsoft/shortlist/logging/LoggerMacros.h" + +std::shared_ptr<ParameterTree> ParameterTree::m_empty_tree = std::make_shared<ParameterTree>("params"); + +ParameterTree::ParameterTree() { + m_name = "root"; +} + +ParameterTree::ParameterTree(const std::string& name) { + m_name = name; +} + +ParameterTree::~ParameterTree() { +} + +void ParameterTree::Clear() { + +} + +void ParameterTree::ReplaceVariables( + const std::unordered_map<std::string, std::string>& vars, + bool error_on_unknown_vars) +{ + ReplaceVariablesInternal(vars, error_on_unknown_vars); +} + +void ParameterTree::RegisterInt32(const std::string& name, int32_t * param) { + RegisterItemInternal(name, PARAM_TYPE_INT32, (void *)param); +} + +void ParameterTree::RegisterInt64(const std::string& name, int64_t * param) { + RegisterItemInternal(name, PARAM_TYPE_INT64, (void *)param); +} + +void ParameterTree::RegisterFloat(const std::string& name, float * param) { + RegisterItemInternal(name, PARAM_TYPE_FLOAT, (void *)param); +} + +void ParameterTree::RegisterDouble(const std::string& name, double * param) { + RegisterItemInternal(name, PARAM_TYPE_DOUBLE, (void *)param); +} + +void ParameterTree::RegisterBool(const std::string& name, bool * param) { + RegisterItemInternal(name, PARAM_TYPE_BOOL, (void *)param); +} + +void ParameterTree::RegisterString(const std::string& name, std::string * param) { + RegisterItemInternal(name, PARAM_TYPE_STRING, (void *)param); +} + +std::shared_ptr<ParameterTree> ParameterTree::FromBinaryReader(const void*& current) { + std::shared_ptr<ParameterTree> root = std::make_shared<ParameterTree>(); + root->ReadBinary(current); + return root; +} + +void ParameterTree::SetRegisteredParams() { + for (std::size_t i = 0; i < m_registered_params.size(); i++) { + const RegisteredParam& rp = m_registered_params[i]; + switch (rp.Type()) { + case PARAM_TYPE_INT32: + (*(int32_t *)rp.Data()) = GetInt32Req(rp.Name()); + break; + case PARAM_TYPE_INT64: + (*(int64_t *)rp.Data()) = GetInt64Req(rp.Name()); + break; + default: + LOG_ERROR_AND_THROW("Unknown ParameterType: %d", (int)rp.Type()); + } + } +} + +int32_t ParameterTree::GetInt32Or(const std::string& name, int32_t defaultValue) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr) { + return defaultValue; + } + return Converter::ToInt32(*value); +} + +int64_t ParameterTree::GetInt64Or(const std::string& name, int64_t defaultValue) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr) { + return defaultValue; + } + return Converter::ToInt64(*value); +} + +uint64_t ParameterTree::GetUInt64Or(const std::string& name, uint64_t defaultValue) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr) { + return defaultValue; + } + return Converter::ToUInt64(*value); +} + +double ParameterTree::GetDoubleOr(const std::string& name, double defaultValue) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr) { + return defaultValue; + } + return Converter::ToDouble(*value); +} + +float ParameterTree::GetFloatOr(const std::string& name, float defaultValue) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr) { + return defaultValue; + } + return Converter::ToFloat(*value); +} + +std::string ParameterTree::GetStringOr(const std::string& name, const std::string& defaultValue) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr) { + return defaultValue; + } + return (*value); +} + +bool ParameterTree::GetBoolOr(const std::string& name, bool defaultValue) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr) { + return defaultValue; + } + return Converter::ToBool(*value); +} + +int32_t ParameterTree::GetInt32Req(const std::string& name) const { + std::string value = GetStringReq(name); + return Converter::ToInt32(value); +} + +uint64_t ParameterTree::GetUInt64Req(const std::string& name) const { + std::string value = GetStringReq(name); + return Converter::ToUInt64(value); +} + +int64_t ParameterTree::GetInt64Req(const std::string& name) const { + std::string value = GetStringReq(name); + return Converter::ToInt64(value); +} + +double ParameterTree::GetDoubleReq(const std::string& name) const { + std::string value = GetStringReq(name); + return Converter::ToDouble(value); +} + +float ParameterTree::GetFloatReq(const std::string& name) const { + std::string value = GetStringReq(name); + return Converter::ToFloat(value); +} + +bool ParameterTree::GetBoolReq(const std::string& name) const { + std::string value = GetStringReq(name); + return Converter::ToBool(value); +} + +std::string ParameterTree::GetStringReq(const std::string& name) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr) { + LOG_ERROR_AND_THROW("Required parameter <%s> not found in ParameterTree:\n%s", name.c_str(), ToString().c_str()); + } + return (*value); +} + +std::vector<std::string> ParameterTree::GetFileListReq(const std::string& name) const { + std::vector<std::string> output = GetFileListOptional(name); + if (output.size() == 0) { + LOG_ERROR_AND_THROW("No files were found for parameter: %s", name.c_str()); + } + return output; +} + +std::vector<std::string> ParameterTree::GetFileListOptional(const std::string& name) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr || (*value).size() == 0) { + return std::vector<std::string>(); + } + std::vector<std::string> all_files = StringUtils::Split(*value, ";"); + return all_files; +} + +std::vector<std::string> ParameterTree::GetStringListReq(const std::string& name, const std::string& sep) const { + std::string value = GetStringReq(name); + std::vector<std::string> output = StringUtils::Split(value, sep); + return output; +} + +std::vector<std::string> ParameterTree::GetStringListOptional(const std::string& name, const std::string& sep) const { + std::string value = GetStringOr(name, ""); + std::vector<std::string> output = StringUtils::Split(value, sep); + return output; +} + +std::shared_ptr<ParameterTree> ParameterTree::GetChildReq(const std::string& name) const { + for (const auto& child : m_children) { + if (child->Name() == name) { + return child; + } + } + LOG_ERROR_AND_THROW("Unable to find child ParameterTree with name '%s'", name.c_str()); + return nullptr; // never happens +} + + +std::shared_ptr<ParameterTree> ParameterTree::GetChildOrEmpty(const std::string& name) const { + for (const auto& child : m_children) { + if (child->Name() == name) { + return child; + } + } + return std::make_shared<ParameterTree>(); +} + +// cast current void pointer to T pointer and move forward by num elements +template <typename T> +const T* get(const void*& current, size_t num = 1) { + const T* ptr = (const T*)current; + current = (const T*)current + num; + return ptr; +} + +void ParameterTree::ReadBinary(const void*& current) { + auto nameLength = *get<int32_t>(current); + auto nameBytes = get<char>(current, nameLength); + m_name = std::string(nameBytes, nameBytes + nameLength); + + auto textLength = *get<int32_t>(current); + auto textBytes = get<char>(current, textLength); + m_text = std::string(textBytes, textBytes + textLength); + + int32_t num_children = *get<int32_t>(current); + m_children.resize(num_children); + for (int32_t i = 0; i < num_children; i++) { + m_children[i].reset(new ParameterTree()); + m_children[i]->ReadBinary(current); + } +} + +std::vector< std::shared_ptr<ParameterTree> > ParameterTree::GetChildren(const std::string& name) const { + std::vector< std::shared_ptr<ParameterTree> > children; + for (std::shared_ptr<ParameterTree> child : m_children) { + if (child->Name() == name) { + children.push_back(child); + } + } + return children; +} + +void ParameterTree::AddParam(const std::string& name, const std::string& text) { + std::shared_ptr<ParameterTree> child = std::make_shared<ParameterTree>(name); + child->SetText(text); + m_children.push_back(child); +} + +void ParameterTree::SetParam(const std::string& name, const std::string& text) { + for (const auto& child : m_children) { + if (child->Name() == name) { + child->SetText(text); + return; + } + } + std::shared_ptr<ParameterTree> child = std::make_shared<ParameterTree>(name); + child->SetText(text); + m_children.push_back(child); +} + +void ParameterTree::AddChild(std::shared_ptr<ParameterTree> child) { + m_children.push_back(child); +} + +bool ParameterTree::HasParam(const std::string& name) const { + const std::string * value = GetParamInternal(name); + if (value == nullptr) { + return false; + } + return true; +} + +bool ParameterTree::HasChild(const std::string& name) const { + for (const auto& child : m_children) { + if (child->Name() == name) { + return true; + } + } + return false; +} + +std::string ParameterTree::ToString() const { + std::ostringstream ss; + ToStringInternal(0, ss); + return ss.str(); +} + +const std::string * ParameterTree::GetParamInternal(const std::string& name) const { + for (const auto& child : m_children) { + if (child->Name() == name) { + return &(child->Text()); + } + } + return nullptr; +} + + +void ParameterTree::RegisterItemInternal(const std::string& name, ParameterType type, void * param) { + if (m_registered_param_names.find(name) != m_registered_param_names.end()) { + LOG_ERROR_AND_THROW("Unable to register duplicate parameter name: '%s'", name.c_str()); + } + m_registered_params.push_back(RegisteredParam(name, type, param)); + m_registered_param_names.insert(name); +} + +void ParameterTree::ToStringInternal(int32_t depth, std::ostream& ss) const { + for (int32_t i = 0; i < 2*depth; i++) { + ss << " "; + } + ss << "<" << m_name << ">"; + if (m_children.size() > 0) { + ss << "\n"; + for (const std::shared_ptr<ParameterTree>& child : m_children) { + child->ToStringInternal(depth+1, ss); + } + for (int32_t i = 0; i < 2 * depth; i++) { + ss << " "; + } + ss << "</" << m_name << ">\n"; + } + else { + ss << m_text << "</" << m_name << ">\n"; + } +} + +std::shared_ptr<ParameterTree> ParameterTree::Clone() const { + std::shared_ptr<ParameterTree> node = std::make_shared<ParameterTree>(m_name); + node->m_text = m_text; + for (auto& child : m_children) { + node->m_children.push_back(child->Clone()); + } + return node; +} + +void ParameterTree::Merge(const ParameterTree& other) { + m_name = other.m_name; + m_text = other.m_text; + for (auto& other_child : other.m_children) { + if (HasChild(other_child->Name())) { + auto my_child = GetChildReq(other_child->Name()); + if (other_child->Text() != "" && my_child->Text() != "") { + my_child->SetText(other_child->Text()); + } + else { + my_child->Merge(*other_child); + } + } + else { + m_children.push_back(other_child->Clone()); + } + } +} + +void ParameterTree::ReplaceVariablesInternal( + const std::unordered_map<std::string, std::string>& vars, + bool error_on_unknown_vars) +{ + std::size_t offset = 0; + std::ostringstream ss; + while (true) { + std::size_t s_pos = m_text.find("$$", offset); + if (s_pos == std::string::npos) { + break; + } + std::size_t e_pos = m_text.find("$$", s_pos + 2); + if (e_pos == std::string::npos) { + break; + } + + if (offset != s_pos) { + ss << m_text.substr(offset, s_pos-offset); + } + + std::string var_name = m_text.substr(s_pos+2, e_pos - (s_pos+2)); + auto it = vars.find(var_name); + if (it != vars.end()) { + std::string value = it->second; + ss << value; + } + else { + if (error_on_unknown_vars) { + LOG_ERROR_AND_THROW("The variable $$%s$$ was not found", var_name.c_str()); + } + else { + ss << "$$" << var_name << "$$"; + } + } + offset = e_pos + 2; + } + ss << m_text.substr(offset); + + m_text = ss.str(); + + for (auto& child : m_children) { + child->ReplaceVariablesInternal(vars, error_on_unknown_vars); + } +} + +} // namespace quicksand + diff --git a/src/microsoft/shortlist/utils/ParameterTree.h b/src/microsoft/shortlist/utils/ParameterTree.h new file mode 100644 index 00000000..1474ff64 --- /dev/null +++ b/src/microsoft/shortlist/utils/ParameterTree.h @@ -0,0 +1,185 @@ +#pragma once + +#include <string> +#include <vector> +#include <unordered_set> +#include <unordered_map> +#include <memory> + +#include "microsoft/shortlist/utils/StringUtils.h" + +namespace quicksand { + +class ParameterTree { +private: + enum ParameterType { + PARAM_TYPE_INT32, + PARAM_TYPE_INT64, + PARAM_TYPE_UINT64, + PARAM_TYPE_FLOAT, + PARAM_TYPE_DOUBLE, + PARAM_TYPE_BOOL, + PARAM_TYPE_STRING + }; + + class RegisteredParam { + private: + std::string m_name; + ParameterType m_type; + void * m_data; + + public: + RegisteredParam() {} + + RegisteredParam(const std::string& name, + ParameterType type, + void * data) + { + m_name = name; + m_type = type; + m_data = data; + } + + const std::string& Name() const {return m_name;} + const ParameterType& Type() const {return m_type;} + void * Data() const {return m_data;} + }; + + static std::shared_ptr<ParameterTree> m_empty_tree; + + std::string m_name; + + std::string m_text; + + std::vector< std::shared_ptr<ParameterTree> > m_children; + + std::unordered_set<std::string> m_registered_param_names; + + std::vector<RegisteredParam> m_registered_params; + +public: + ParameterTree(); + + ParameterTree(const std::string& name); + + ~ParameterTree(); + + inline const std::string& Text() const { return m_text; } + inline void SetText(const std::string& text) { m_text = text; } + + inline const std::string& Name() const { return m_name; } + inline void SetName(const std::string& name) { m_name = name; } + + void Clear(); + + void ReplaceVariables( + const std::unordered_map<std::string, std::string>& vars, + bool error_on_unknown_vars = true); + + void RegisterInt32(const std::string& name, int32_t * param); + + void RegisterInt64(const std::string& name, int64_t * param); + + void RegisterFloat(const std::string& name, float * param); + + void RegisterDouble(const std::string& name, double * param); + + void RegisterBool(const std::string& name, bool * param); + + void RegisterString(const std::string& name, std::string * param); + + static std::shared_ptr<ParameterTree> FromBinaryReader(const void*& current); + + void SetRegisteredParams(); + + int32_t GetInt32Req(const std::string& name) const; + + int64_t GetInt64Req(const std::string& name) const; + + uint64_t GetUInt64Req(const std::string& name) const; + + double GetDoubleReq(const std::string& name) const; + + float GetFloatReq(const std::string& name) const; + + std::string GetStringReq(const std::string& name) const; + + bool GetBoolReq(const std::string& name) const; + + int32_t GetInt32Or(const std::string& name, int32_t defaultValue) const; + + int64_t GetInt64Or(const std::string& name, int64_t defaultValue) const; + + uint64_t GetUInt64Or(const std::string& name, uint64_t defaultValue) const; + + std::string GetStringOr(const std::string& name, const std::string& defaultValue) const; + + double GetDoubleOr(const std::string& name, double defaultValue) const; + + float GetFloatOr(const std::string& name, float defaultValue) const; + + bool GetBoolOr(const std::string& name, bool defaultValue) const; + + std::vector<std::string> GetFileListReq(const std::string& name) const; + + std::vector<std::string> GetFileListOptional(const std::string& name) const; + + std::vector<std::string> GetStringListReq(const std::string& name, const std::string& sep = " ") const; + + std::vector<std::string> GetStringListOptional(const std::string& name, const std::string& sep = " ") const; + + std::shared_ptr<ParameterTree> GetChildReq(const std::string& name) const; + + std::shared_ptr<ParameterTree> GetChildOrEmpty(const std::string& name) const; + + std::vector< std::shared_ptr<ParameterTree> > GetChildren(const std::string& name) const; + + inline const std::vector< std::shared_ptr<ParameterTree> >& GetChildren() const { return m_children; } + + void ReadBinary(const void*& current); + + void AddParam(const std::string& name, const std::string& text); + + template <typename T> + void AddParam(const std::string& name, const T& obj); + + void SetParam(const std::string& name, const std::string& text); + + template <typename T> + void SetParam(const std::string& name, const T& obj); + + void AddChild(std::shared_ptr<ParameterTree> child); + + std::string ToString() const; + + bool HasChild(const std::string& name) const; + + bool HasParam(const std::string& name) const; + + std::shared_ptr<ParameterTree> Clone() const; + + void Merge(const ParameterTree& other); + +private: + void ReplaceVariablesInternal( + const std::unordered_map<std::string, std::string>& vars, + bool error_on_unknown_vars); + + void RegisterItemInternal(const std::string& name, ParameterType type, void * param); + + const std::string * GetParamInternal(const std::string& name) const; + + void ToStringInternal(int32_t depth, std::ostream& ss) const; +}; + +template <typename T> +void ParameterTree::AddParam(const std::string& name, const T& obj) { + AddParam(name, StringUtils::ToString(obj)); +} + +template <typename T> +void ParameterTree::SetParam(const std::string& name, const T& obj) { + SetParam(name, StringUtils::ToString(obj)); +} + +} // namespace quicksand diff --git a/src/microsoft/shortlist/utils/PrintTypes.h b/src/microsoft/shortlist/utils/PrintTypes.h new file mode 100644 index 00000000..6bc1363d --- /dev/null +++ b/src/microsoft/shortlist/utils/PrintTypes.h @@ -0,0 +1,16 @@ +#pragma once + +#include <inttypes.h> + +#ifdef QUICKSAND_WINDOWS_BUILD +#define PI32 "d" +#define PI64 "lld" +#define PU32 "u" +#define PU64 "llu" +#else +#define PI32 PRId32 +#define PI64 PRId64 +#define PU32 PRIu32 +#define PU64 PRIu64 +#endif + diff --git a/src/microsoft/shortlist/utils/StringUtils.cpp b/src/microsoft/shortlist/utils/StringUtils.cpp new file mode 100644 index 00000000..7870b542 --- /dev/null +++ b/src/microsoft/shortlist/utils/StringUtils.cpp @@ -0,0 +1,338 @@ +#include "microsoft/shortlist/utils/StringUtils.h" + +#include <stdio.h> +#include <algorithm> +#include <string> + +namespace quicksand { + +#include "microsoft/shortlist/logging/LoggerMacros.h" + +std::string StringUtils::VarArgsToString(const char * format, va_list args) { + if (format == nullptr) { + LOG_ERROR_AND_THROW("'format' cannot be null in StringUtils::VarArgsToString"); + } + + std::string output; + // Most of the time the stack buffer (5000 chars) will be sufficient. + // In cases where this is insufficient, dynamically allocate an appropriately sized buffer + char buffer[5000]; +#ifdef QUICKSAND_WINDOWS_BUILD + va_list copy; + va_copy(copy, args); + int ret = vsnprintf_s(buffer, sizeof(buffer), _TRUNCATE, format, copy); + va_end(copy); + if (ret >= 0) { + output = std::string(buffer, buffer + ret); + } + else { + va_list copy2; + va_copy(copy2, args); + int needed_size = _vscprintf(format, copy2); + va_end(copy2); + + if (needed_size < 0) { + LOG_ERROR_AND_THROW("A call to vsnprintf_s() failed. This should never happen"); + } + char * dynamic_buffer = new char[needed_size+1]; + int ret2 = vsnprintf_s(dynamic_buffer, needed_size+1, _TRUNCATE, format, args); + if (ret2 >= 0) { + output = std::string(dynamic_buffer, dynamic_buffer + ret2); + delete[] dynamic_buffer; + } + else { + output = ""; + delete[] dynamic_buffer; + LOG_ERROR_AND_THROW("A call to vsnprintf_s() failed. This should never happen, " + "since we made a call to _vscprintf() to check the dynamic buffer size. The call to _vscprintf() " + "returned %d bytes, but apparently that was not enough. This would imply a bug in MSVC's vsnprintf_s implementation.", needed_size); + } + } +#else + va_list copy; + va_copy(copy, args); + int needed_size = vsnprintf(buffer, sizeof(buffer), format, copy); + va_end(copy); + if (needed_size < (int)sizeof(buffer)) { + output = std::string(buffer, buffer + needed_size); + } + else { + char * dynamic_buffer = new char[needed_size+1]; + int ret = vsnprintf(dynamic_buffer, needed_size + 1, format, args); + if (ret >= 0 && ret < needed_size + 1) { + output = std::string(dynamic_buffer); + delete[] dynamic_buffer; + } + else { + output = ""; + delete[] dynamic_buffer; + LOG_ERROR_AND_THROW("A call to vsnprintf() failed. Return value: %d.", + ret); + } + } +#endif + return output; +} + +std::vector<std::string> StringUtils::SplitIntoLines(const std::string& input) { + std::vector<std::string> output; + if (input.size() == 0) { + return output; + } + std::size_t start = 0; + for (std::size_t i = 0; i < input.size(); i++) { + char c = input[i]; + if (c == '\r' || c == '\n') { + output.push_back(std::string(input.begin() + start, input.begin() + i)); + start = i+1; + } + if (c == '\r' && i + 1 < input.size() && input[i+1] == '\n') { + i++; + start = i+1; + } + } + // do NOT put an empty length trailing line (but empty length intermediate lines are fine) + if (input.begin() + start != input.end()) { + output.push_back(std::string(input.begin() + start, input.end())); + } + return output; +} + +bool StringUtils::StartsWith(const std::string& str, const std::string& prefix) { + if (str.length() < prefix.length()) + return false; + + return std::equal(prefix.begin(), prefix.end(), str.begin()); +} + +bool StringUtils::EndsWith(const std::string& str, const std::string& suffix) { + if (str.length() < suffix.length()) + return false; + + return std::equal(suffix.begin(), suffix.end(), str.end() - suffix.length()); +} + +std::vector<std::string> StringUtils::SplitFileList(const std::string& input) { + std::vector<std::string> output; + for (const std::string& s : SplitIntoLines(input)) { + for (const std::string& t : Split(s, ";")) { + std::string f = CleanupWhitespace(t); + output.push_back(f); + } + } + return output; +} + +std::vector<std::string> StringUtils::Split(const std::string& input, char splitter) { + std::vector<std::string> output; + if (input.size() == 0) { + return output; + } + std::size_t start = 0; + for (std::size_t i = 0; i < input.size(); i++) { + if (input[i] == splitter) { + output.push_back(std::string(input.begin() + start, input.begin() + i)); + start = i+1; + } + } + output.push_back(std::string(input.begin() + start, input.end())); + return output; +} + +std::vector<std::string> StringUtils::Split(const std::string& input, const std::string& splitter) { + std::vector<std::string> output; + if (input.size() == 0) { + return output; + } + std::size_t pos = 0; + while (true) { + std::size_t next_pos = input.find(splitter, pos); + if (next_pos == std::string::npos) { + output.push_back(std::string(input.begin() + pos, input.end())); + break; + } + else { + output.push_back(std::string(input.begin() + pos, input.begin() + next_pos)); + } + pos = next_pos + splitter.size(); + } + return output; +} + +std::string StringUtils::Join(const std::string& joiner, const uint8_t * items, int32_t length) { + std::ostringstream ss; + for (int32_t i = 0; i < length; i++) { + if (i != 0) { + ss << joiner; + } + ss << (int32_t)(items[i]); + } + return ss.str(); +} + +std::string StringUtils::Join(const std::string& joiner, const int8_t * items, int32_t length) { + std::ostringstream ss; + for (int32_t i = 0; i < length; i++) { + if (i != 0) { + ss << joiner; + } + ss << (int32_t)(items[i]); + } + return ss.str(); +} + +std::string StringUtils::PrintString(const char * format, ...) { + va_list args; + va_start(args, format); + std::string output = StringUtils::VarArgsToString(format, args); + va_end(args); + + return output; +} + +std::vector<std::string> StringUtils::WhitespaceTokenize(const std::string& input) { + std::vector<std::string> output; + if (input.size() == 0) { + return output; + } + std::size_t size = input.size(); + std::size_t start = 0; + std::size_t end = size; + for (std::size_t i = 0; i < size; i++) { + char c = input[i]; + if (IsWhitespace(c)) { + start++; + } + else { + break; + } + } + for (std::size_t i = 0; i < size; i++) { + char c = input[size-1-i]; + if (IsWhitespace(c)) { + end--; + } + else { + break; + } + } + if (end <= start) { + return output; + } + bool prev_is_whitespace = false; + std::size_t token_start = start; + for (std::size_t i = start; i < end; i++) { + char c = input[i]; + if (IsWhitespace(c)) { + if (!prev_is_whitespace) { + output.push_back(std::string(input.begin() + token_start, input.begin() + i)); + } + prev_is_whitespace = true; + token_start = i+1; + } + else { + prev_is_whitespace = false; + } + } + output.push_back(std::string(input.begin() + token_start, input.begin() + end)); + return output; +} + +std::string StringUtils::CleanupWhitespace(const std::string& input) { + if (input.size() == 0) { + return std::string(""); + } + std::size_t size = input.size(); + std::size_t start = 0; + std::size_t end = size; + for (std::size_t i = 0; i < size; i++) { + char c = input[i]; + if (IsWhitespace(c)) { + start++; + } + else { + break; + } + } + for (std::size_t i = 0; i < size; i++) { + char c = input[size-1-i]; + if (IsWhitespace(c)) { + end--; + } + else { + break; + } + } + if (end <= start) { + return std::string(""); + } + std::ostringstream ss; + bool prev_is_whitespace = false; + for (std::size_t i = start; i < end; i++) { + char c = input[i]; + if (IsWhitespace(c)) { + if (!prev_is_whitespace) { + ss << ' '; + } + prev_is_whitespace = true; + } + else { + ss << c; + prev_is_whitespace = false; + } + } + return ss.str(); +} + +std::string StringUtils::XmlEscape(const std::string& str) { + std::ostringstream ss; + for (std::size_t i = 0; i < str.size(); i++) { + char c = str[i]; + if (c == '&') { + ss << "&"; + } + else if (c == '"') { + ss << """; + } + else if (c == '\'') { + ss << "'"; + } + else if (c == '<') { + ss << "<"; + } + else if (c == '>') { + ss << ">"; + } + else { + ss << c; + } + } + return ss.str(); +} + +std::string StringUtils::ToString(const std::string& str) { + return str; +} + +std::string StringUtils::ToString(bool obj) { + return (obj)?"true":"false"; +} + +std::string StringUtils::ToUpper(const std::string& str) { + std::vector<char> output; + output.reserve(str.size()); + for (char c : str) { + output.push_back((char)toupper((int)c)); + } + return std::string(output.begin(), output.end()); +} + +std::string StringUtils::ToLower(const std::string& str) { + std::ostringstream ss; + for (char c : str) { + ss << c; + } + return ss.str(); +} + +} // namespace quicksand diff --git a/src/microsoft/shortlist/utils/StringUtils.h b/src/microsoft/shortlist/utils/StringUtils.h new file mode 100644 index 00000000..31bb1fcc --- /dev/null +++ b/src/microsoft/shortlist/utils/StringUtils.h @@ -0,0 +1,98 @@ +#pragma once + +#include <string> +#include <sstream> +#include <stdarg.h> +#include <vector> +#include <stdint.h> + +#include "microsoft/shortlist/utils/PrintTypes.h" + +namespace quicksand { + +class StringUtils { +public: + template <typename T> + static std::string Join(const std::string& joiner, const T& items); + + template <typename T> + static std::string Join(const std::string& joiner, const T * items, int32_t length); + + static std::string Join(const std::string& joiner, const uint8_t * items, int32_t length); + + static std::string Join(const std::string& joiner, const int8_t * items, int32_t length); + + static std::vector<std::string> Split(const std::string& input, char splitter); + + static std::vector<std::string> Split(const std::string& input, const std::string& splitter); + + static std::vector<std::string> SplitFileList(const std::string& input); + + static std::string PrintString(const char * format, ...); + + static std::string VarArgsToString(const char * format, va_list args); + + static std::vector<std::string> WhitespaceTokenize(const std::string& input); + + static std::string CleanupWhitespace(const std::string& input); + + static std::string ToString(const std::string& str); + + static std::string ToString(bool obj); + + template <typename T> + static std::string ToString(const T& obj); + + static std::string XmlEscape(const std::string& str); + + static std::vector<std::string> SplitIntoLines(const std::string& input); + + static bool StartsWith(const std::string& str, const std::string& prefix); + + static bool EndsWith(const std::string& str, const std::string& suffix); + + inline static bool IsWhitespace(char c) { + return (c == ' ' || c == '\t' || c == '\n' || c == '\r'); + } + + // This should only be used for ASCII, e.g., filenames, NOT for language data + static std::string ToLower(const std::string& str); + + // This should only be used for ASCII, e.g., filenames, NOT for language data + static std::string ToUpper(const std::string& str); +}; + +template <typename T> +std::string StringUtils::Join(const std::string& joiner, const T& items) { + std::ostringstream ss; + bool first = true; + for (auto it = items.begin(); it != items.end(); it++) { + if (!first) { + ss << joiner; + } + ss << (*it); + first = false; + } + return ss.str(); +} + +template <typename T> +std::string StringUtils::Join(const std::string& joiner, const T * items, int32_t length) { + std::ostringstream ss; + for (int32_t i = 0; i < length; i++) { + if (i != 0) { + ss << joiner; + } + ss << items[i]; + } + return ss.str(); +} + +template <typename T> +std::string StringUtils::ToString(const T& obj) { + std::ostringstream ss; + ss << obj; + return ss.str(); +} + +} // namespace quicksand diff --git a/src/models/amun.h b/src/models/amun.h index 1bfda269..1bfda269 100755..100644 --- a/src/models/amun.h +++ b/src/models/amun.h diff --git a/src/models/bert.h b/src/models/bert.h index 51427457..51427457 100755..100644 --- a/src/models/bert.h +++ b/src/models/bert.h diff --git a/src/models/char_s2s.h b/src/models/char_s2s.h index 3b9bb2fa..3b9bb2fa 100755..100644 --- a/src/models/char_s2s.h +++ b/src/models/char_s2s.h diff --git a/src/models/classifier.h b/src/models/classifier.h index 9faa907e..9faa907e 100755..100644 --- a/src/models/classifier.h +++ b/src/models/classifier.h diff --git a/src/models/costs.cpp b/src/models/costs.cpp new file mode 100644 index 00000000..c688b211 --- /dev/null +++ b/src/models/costs.cpp @@ -0,0 +1,14 @@ +#include "costs.h" + +namespace marian { +namespace models { + +Ptr<DecoderState> LogSoftmaxStep::apply(Ptr<DecoderState> state) { + // decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost) + state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax)); + // @TODO: This is becoming more and more opaque ^^. Can we simplify this? + return state; +} + +} // namespace models +} // namespace marian diff --git a/src/models/costs.h b/src/models/costs.h index 3d8f2c51..e5463bfd 100755..100644 --- a/src/models/costs.h +++ b/src/models/costs.h @@ -4,8 +4,8 @@ #include "layers/guided_alignment.h" #include "layers/loss.h" #include "layers/weight.h" -#include "models/encoder_decoder.h" #include "models/encoder_classifier.h" +#include "models/encoder_decoder.h" #include "models/encoder_pooler.h" namespace marian { @@ -22,10 +22,12 @@ namespace models { class ICost { public: - virtual Ptr<MultiRationalLoss> apply(Ptr<IModel> model, - Ptr<ExpressionGraph> graph, // @TODO: why needed? Can it be gotten from model? - Ptr<data::Batch> batch, - bool clearGraph = true) = 0; + virtual Ptr<MultiRationalLoss> apply( + Ptr<IModel> model, + Ptr<ExpressionGraph> graph, // @TODO: why needed? Can it be gotten from model? + Ptr<data::Batch> batch, + bool clearGraph = true) + = 0; virtual ~ICost() {} }; @@ -45,10 +47,9 @@ public: : options_(options), inference_(options->get<bool>("inference", false)) { loss_ = newLoss(options_, inference_); - toBeWeighted_ - = (options_->hasAndNotEmpty("data-weighting") && !inference_) - || (options_->has("dynamic-weighting") && options_->get<bool>("dynamic-weighting") - && !inference_); + toBeWeighted_ = (options_->hasAndNotEmpty("data-weighting") && !inference_) + || (options_->has("dynamic-weighting") + && options_->get<bool>("dynamic-weighting") && !inference_); if(toBeWeighted_) weighter_ = WeightingFactory(options_); } @@ -56,9 +57,9 @@ public: virtual ~EncoderDecoderCECost() {} Ptr<MultiRationalLoss> apply(Ptr<IModel> model, - Ptr<ExpressionGraph> graph, - Ptr<data::Batch> batch, - bool clearGraph = true) override { + Ptr<ExpressionGraph> graph, + Ptr<data::Batch> batch, + bool clearGraph = true) override { auto encdec = std::static_pointer_cast<EncoderDecoder>(model); auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch); @@ -72,17 +73,17 @@ public: Ptr<MultiRationalLoss> multiLoss = newMultiLoss(options_); // @TODO: adapt to multi-objective training with multiple decoders - auto partialLoss = loss_->apply(state->getLogProbs(), - state->getTargetWords(), - state->getTargetMask(), - weights); + auto partialLoss = loss_->apply( + state->getLogProbs(), state->getTargetWords(), state->getTargetMask(), weights); multiLoss->push_back(partialLoss); if(options_->get("guided-alignment", std::string("none")) != "none" && !inference_) { - auto attentionVectors = encdec->getDecoders()[0]->getAlignments(); // [tgt index][beam depth, max src length, batch size, 1] + auto attentionVectors + = encdec->getDecoders()[0] + ->getAlignments(); // [tgt index][beam depth, max src length, batch size, 1] ABORT_IF(attentionVectors.empty(), "Model does not seem to support alignments"); - auto attention = concatenate(attentionVectors, /*axis =*/ -1); + auto attention = concatenate(attentionVectors, /*axis =*/-1); auto alignmentLoss = guidedAlignmentCost(graph, corpusBatch, options_, attention); multiLoss->push_back(alignmentLoss); @@ -109,10 +110,9 @@ public: } Ptr<MultiRationalLoss> apply(Ptr<IModel> model, - Ptr<ExpressionGraph> graph, - Ptr<data::Batch> batch, - bool clearGraph = true) override { - + Ptr<ExpressionGraph> graph, + Ptr<data::Batch> batch, + bool clearGraph = true) override { auto enccls = std::static_pointer_cast<EncoderClassifier>(model); auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch); @@ -141,21 +141,20 @@ protected: public: EncoderPoolerRankCost(Ptr<Options> options) - : options_(options), - inference_(options->get<bool>("inference", false)) { - auto trainEmbedderRank = options->get<std::vector<std::string>>("train-embedder-rank", {}); - ABORT_IF(trainEmbedderRank.empty(), "EncoderPoolerRankCost expects train-embedder-rank to be set"); - - margin_ = std::stof(trainEmbedderRank[0]); - if(trainEmbedderRank.size() > 1) - normalizer_ = std::stof(trainEmbedderRank[1]); + : options_(options), inference_(options->get<bool>("inference", false)) { + auto trainEmbedderRank = options->get<std::vector<std::string>>("train-embedder-rank", {}); + ABORT_IF(trainEmbedderRank.empty(), + "EncoderPoolerRankCost expects train-embedder-rank to be set"); + + margin_ = std::stof(trainEmbedderRank[0]); + if(trainEmbedderRank.size() > 1) + normalizer_ = std::stof(trainEmbedderRank[1]); } Ptr<MultiRationalLoss> apply(Ptr<IModel> model, Ptr<ExpressionGraph> graph, Ptr<data::Batch> batch, bool clearGraph = true) override { - auto encpool = std::static_pointer_cast<EncoderPooler>(model); auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch); std::vector<Expr> dotProducts = encpool->apply(graph, corpusBatch, clearGraph); @@ -167,28 +166,41 @@ public: ABORT_IF(dotProducts.size() != 3, "Three dot products required for margin loss"); // multi-objective training - auto maxDot = max(concatenate(dotProducts, -1), -1); // compute maximum for numeric stability - auto exponent = dotProducts[0] - maxDot - margin_; // substract maximum and margin from dot product + auto maxDot = max(concatenate(dotProducts, -1), -1); // compute maximum for numeric stability + auto exponent + = dotProducts[0] - maxDot - margin_; // substract maximum and margin from dot product auto dp = exp(exponent); Expr dn1, dn2; - if(normalizer_ != 0.0f) { // the normalizer may be useful for fluctuating batch sizes since it limits the magnitude of the sum of negative examples in the denominator. - dn1 = normalizer_ * mean(exp(dotProducts[1] - maxDot), -1); // dot product of anchor and first negative example - dn2 = normalizer_ * mean(exp(dotProducts[2] - maxDot), -1); // dot product of positive examples and first negative example + if(normalizer_ + != 0.0f) { // the normalizer may be useful for fluctuating batch sizes since it limits the + // magnitude of the sum of negative examples in the denominator. + dn1 = normalizer_ + * mean(exp(dotProducts[1] - maxDot), + -1); // dot product of anchor and first negative example + dn2 = normalizer_ + * mean(exp(dotProducts[2] - maxDot), + -1); // dot product of positive examples and first negative example } else { - dn1 = sum(exp(dotProducts[1] - maxDot), -1); // dot product of anchor and first negative example - dn2 = sum(exp(dotProducts[2] - maxDot), -1); // dot product of positive examples and first negative example + dn1 = sum(exp(dotProducts[1] - maxDot), + -1); // dot product of anchor and first negative example + dn2 = sum(exp(dotProducts[2] - maxDot), + -1); // dot product of positive examples and first negative example } // We rewrite the loss so it looks more like a log-softmax, presumably more stable? - // Let dp = exp(phi - m) then -log(dp / (dp + sum(dn))) = -log(dp) + log(dp + sum(dn)) = log(dp + sum(dn)) - log(dp) = log(dp + sum(dn)) - (phi - m) - auto marginLoss1 = log(dp + dn1) - exponent; // softmax-margin loss for anchor vs negative examples - auto marginLoss2 = log(dp + dn2) - exponent; // symmetric version of the above with positive example vs negative examples - auto marginLoss = sum(marginLoss1 + marginLoss2, /*axis=*/-2); - + // Let dp = exp(phi - m) then -log(dp / (dp + sum(dn))) = -log(dp) + log(dp + sum(dn)) = log(dp + // + sum(dn)) - log(dp) = log(dp + sum(dn)) - (phi - m) + auto marginLoss1 + = log(dp + dn1) - exponent; // softmax-margin loss for anchor vs negative examples + auto marginLoss2 + = log(dp + dn2) + - exponent; // symmetric version of the above with positive example vs negative examples + auto marginLoss = sum(marginLoss1 + marginLoss2, /*axis=*/-2); + RationalLoss loss(marginLoss, (float)dimBatch); multiLoss->push_back(loss); - + return multiLoss; } }; @@ -199,8 +211,7 @@ protected: Ptr<ICost> cost_; public: - Trainer(Ptr<IModel> model, Ptr<ICost> cost) - : model_(model), cost_(cost) {} + Trainer(Ptr<IModel> model, Ptr<ICost> cost) : model_(model), cost_(cost) {} virtual ~Trainer() {} @@ -219,8 +230,8 @@ public: } virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph, - Ptr<data::Batch> batch, - bool clearGraph = true) override { + Ptr<data::Batch> batch, + bool clearGraph = true) override { return cost_->apply(model_, graph, batch, clearGraph); }; @@ -230,24 +241,25 @@ public: class ILogProb { public: virtual Logits apply(Ptr<IModel> model, - Ptr<ExpressionGraph> graph, - Ptr<data::Batch> batch, - bool clearGraph = true) = 0; + Ptr<ExpressionGraph> graph, + Ptr<data::Batch> batch, + bool clearGraph = true) + = 0; }; -// @TODO: Name 'scorer' is ambiguous: Does it compute scores for all classes, or the loss value for the ground truth? -// Beam search uses it for the former meaning, while 'marian score' and validation in the latter. -// This class is for the former use. The latter is done using Trainer. +// @TODO: Name 'scorer' is ambiguous: Does it compute scores for all classes, or the loss value for +// the ground truth? +// Beam search uses it for the former meaning, while 'marian score' and validation in the +// latter. This class is for the former use. The latter is done using Trainer. class Scorer : public IModel { protected: Ptr<IModel> model_; Ptr<ILogProb> logProb_; public: - Scorer(Ptr<IModel> model, Ptr<ILogProb> cost) - : model_(model), logProb_(cost) {} + Scorer(Ptr<IModel> model, Ptr<ILogProb> cost) : model_(model), logProb_(cost) {} - virtual ~Scorer(){} + virtual ~Scorer() {} Ptr<IModel> getModel() { return model_; } @@ -264,8 +276,8 @@ public: } virtual Logits build(Ptr<ExpressionGraph> graph, - Ptr<data::Batch> batch, - bool clearGraph = true) override { + Ptr<data::Batch> batch, + bool clearGraph = true) override { return logProb_->apply(model_, graph, batch, clearGraph); }; @@ -282,12 +294,7 @@ public: class LogSoftmaxStep : public ILogProbStep { public: virtual ~LogSoftmaxStep() {} - virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override { - // decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost) - state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax)); - // @TODO: This is becoming more and more opaque ^^. Can we simplify this? - return state; - } + virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override; }; // Gumbel-max noising for sampling during beam-search @@ -298,10 +305,10 @@ public: virtual ~GumbelSoftmaxStep() {} virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override { state->setLogProbs(state->getLogProbs().applyUnaryFunctions( - [](Expr logits){ // lemma gets gumbelled - return logsoftmax(logits + constant_like(logits, inits::gumbel())); - }, - logsoftmax)); // factors don't + [](Expr logits) { // lemma gets gumbelled + return logsoftmax(logits + constant_like(logits, inits::gumbel())); + }, + logsoftmax)); // factors don't return state; } }; @@ -316,8 +323,7 @@ protected: Ptr<ILogProbStep> cost_; public: - Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost) - : encdec_(encdec), cost_(cost) {} + Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost) : encdec_(encdec), cost_(cost) {} virtual void load(Ptr<ExpressionGraph> graph, const std::string& name, @@ -351,12 +357,13 @@ public: return encdec_->startState(graph, batch); } - virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph, - Ptr<DecoderState> state, - const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex] - const Words& words, // [beamIndex * activeBatchSize + batchIndex] - const std::vector<IndexType>& batchIndices, // [batchIndex] - int beamSize) override { + virtual Ptr<DecoderState> step( + Ptr<ExpressionGraph> graph, + Ptr<DecoderState> state, + const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex] + const Words& words, // [beamIndex * activeBatchSize + batchIndex] + const std::vector<IndexType>& batchIndices, // [batchIndex] + int beamSize) override { auto nextState = encdec_->step(graph, state, hypIndices, words, batchIndices, beamSize); return cost_->apply(nextState); } @@ -374,9 +381,7 @@ public: encdec_->setShortlistGenerator(shortlistGenerator); }; - virtual Ptr<data::Shortlist> getShortlist() override { - return encdec_->getShortlist(); - }; + virtual Ptr<data::Shortlist> getShortlist() override { return encdec_->getShortlist(); }; virtual data::SoftAlignment getAlignment() override { return encdec_->getAlignment(); } }; diff --git a/src/models/encoder_decoder.cpp b/src/models/encoder_decoder.cpp index 8fc9321a..8fc9321a 100755..100644 --- a/src/models/encoder_decoder.cpp +++ b/src/models/encoder_decoder.cpp diff --git a/src/models/encoder_decoder.h b/src/models/encoder_decoder.h index 92c1647f..92c1647f 100755..100644 --- a/src/models/encoder_decoder.h +++ b/src/models/encoder_decoder.h diff --git a/src/models/model_factory.cpp b/src/models/model_factory.cpp index e176e6a4..e176e6a4 100755..100644 --- a/src/models/model_factory.cpp +++ b/src/models/model_factory.cpp diff --git a/src/models/model_factory.h b/src/models/model_factory.h index 5403b966..5403b966 100755..100644 --- a/src/models/model_factory.h +++ b/src/models/model_factory.h diff --git a/src/models/nematus.h b/src/models/nematus.h index 730418e5..730418e5 100755..100644 --- a/src/models/nematus.h +++ b/src/models/nematus.h diff --git a/src/models/s2s.h b/src/models/s2s.h index 7009fad5..7009fad5 100755..100644 --- a/src/models/s2s.h +++ b/src/models/s2s.h diff --git a/src/models/states.h b/src/models/states.h index c2f9ee05..20dd59c9 100755..100644 --- a/src/models/states.h +++ b/src/models/states.h @@ -1,7 +1,7 @@ #pragma once +#include "layers/logits.h" // @HACK: for factored embeddings only so far #include "marian.h" -#include "layers/generic.h" // @HACK: for factored embeddings only so far #include "rnn/types.h" namespace marian { @@ -9,7 +9,7 @@ namespace marian { class EncoderState { private: Expr context_; - Expr mask_; // [beam depth=1, max length, batch size, vector dim=1] source mask + Expr mask_; // [beam depth=1, max length, batch size, vector dim=1] source mask Ptr<data::CorpusBatch> batch_; public: @@ -19,31 +19,34 @@ public: EncoderState() {} virtual ~EncoderState() {} - virtual Expr getContext() const { return context_; } - virtual Expr getAttended() const { return context_; } - virtual Expr getMask() const { return mask_; } // source batch mask; may have additional positions suppressed + virtual Expr getContext() const { return context_; } + virtual Expr getAttended() const { return context_; } + virtual Expr getMask() const { + return mask_; + } // source batch mask; may have additional positions suppressed - virtual const Words& getSourceWords() { - return batch_->front()->data(); - } + virtual const Words& getSourceWords() { return batch_->front()->data(); } // Sub-select active batch entries from encoder context and context mask - Ptr<EncoderState> select(const std::vector<IndexType>& batchIndices) { // [batchIndex] indices of active batch entries - // Dimension -2 is OK for both, RNN and Transformer models as the encoder context in Transformer gets transposed to the same dimension layout - return New<EncoderState>(index_select(context_, -2, batchIndices), index_select(mask_, -2, batchIndices), batch_); + Ptr<EncoderState> select( + const std::vector<IndexType>& batchIndices) { // [batchIndex] indices of active batch entries + // Dimension -2 is OK for both, RNN and Transformer models as the encoder context in Transformer + // gets transposed to the same dimension layout + return New<EncoderState>( + index_select(context_, -2, batchIndices), index_select(mask_, -2, batchIndices), batch_); } }; class DecoderState { protected: - rnn::States states_; // states of individual decoder layers + rnn::States states_; // states of individual decoder layers Logits logProbs_; std::vector<Ptr<EncoderState>> encStates_; Ptr<data::CorpusBatch> batch_; - Expr targetHistoryEmbeddings_; // decoder history (teacher-forced or from decoding), embedded + Expr targetHistoryEmbeddings_; // decoder history (teacher-forced or from decoding), embedded Expr targetMask_; - Words targetWords_; // target labels + Words targetWords_; // target labels // Keep track of current target token position during translation size_t position_{0}; @@ -57,26 +60,30 @@ public: virtual ~DecoderState() {} // @TODO: Do we need all these to be virtual? - virtual const std::vector<Ptr<EncoderState>>& getEncoderStates() const { - return encStates_; - } + virtual const std::vector<Ptr<EncoderState>>& getEncoderStates() const { return encStates_; } virtual Logits getLogProbs() const { return logProbs_; } virtual void setLogProbs(Logits logProbs) { logProbs_ = logProbs; } - // @TODO: should this be a constructor? Then derived classes can call this without the New<> in the loop - virtual Ptr<DecoderState> select(const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex] - const std::vector<IndexType>& batchIndices, // [batchIndex] - int beamSize) const { - + // @TODO: should this be a constructor? Then derived classes can call this without the New<> in + // the loop + virtual Ptr<DecoderState> select( + const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex] + const std::vector<IndexType>& batchIndices, // [batchIndex] + int beamSize) const { std::vector<Ptr<EncoderState>> newEncStates; for(auto& es : encStates_) - // If the size of the batch dimension of the encoder state context changed, subselect the correct batch entries - newEncStates.push_back(es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices)); + // If the size of the batch dimension of the encoder state context changed, subselect the + // correct batch entries + newEncStates.push_back( + es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices)); // hypindices matches batchIndices in terms of batch dimension, so we only need hypIndices - auto selectedState = New<DecoderState>( - states_.select(hypIndices, beamSize, /*isBatchMajor=*/false), logProbs_, newEncStates, batch_); + auto selectedState + = New<DecoderState>(states_.select(hypIndices, beamSize, /*isBatchMajor=*/false), + logProbs_, + newEncStates, + batch_); // Set positon of new state based on the target token position of current state selectedState->setPosition(getPosition()); @@ -86,7 +93,9 @@ public: virtual const rnn::States& getStates() const { return states_; } virtual Expr getTargetHistoryEmbeddings() const { return targetHistoryEmbeddings_; }; - virtual void setTargetHistoryEmbeddings(Expr targetHistoryEmbeddings) { targetHistoryEmbeddings_ = targetHistoryEmbeddings; } + virtual void setTargetHistoryEmbeddings(Expr targetHistoryEmbeddings) { + targetHistoryEmbeddings_ = targetHistoryEmbeddings; + } virtual const Words& getTargetWords() const { return targetWords_; }; virtual void setTargetWords(const Words& targetWords) { targetWords_ = targetWords; } @@ -94,9 +103,7 @@ public: virtual Expr getTargetMask() const { return targetMask_; }; virtual void setTargetMask(Expr targetMask) { targetMask_ = targetMask; } - virtual const Words& getSourceWords() const { - return getEncoderStates()[0]->getSourceWords(); - } + virtual const Words& getSourceWords() const { return getEncoderStates()[0]->getSourceWords(); } Ptr<data::CorpusBatch> getBatch() const { return batch_; } @@ -111,7 +118,8 @@ public: /** * Classifier output based on DecoderState - * @TODO: should be unified with DecoderState or not be used at all as Classifier do not really have stateful output. + * @TODO: should be unified with DecoderState or not be used at all as Classifier do not really have + * stateful output. */ class ClassifierState { private: diff --git a/src/models/transformer.h b/src/models/transformer.h index 6368cc6a..6368cc6a 100755..100644 --- a/src/models/transformer.h +++ b/src/models/transformer.h diff --git a/src/models/transformer_factory.h b/src/models/transformer_factory.h index b282d819..b282d819 100755..100644 --- a/src/models/transformer_factory.h +++ b/src/models/transformer_factory.h diff --git a/src/models/transformer_stub.cpp b/src/models/transformer_stub.cpp index 871ee009..871ee009 100755..100644 --- a/src/models/transformer_stub.cpp +++ b/src/models/transformer_stub.cpp diff --git a/src/optimizers/exponential_smoothing.cpp b/src/optimizers/exponential_smoothing.cpp index 1120e7e4..1120e7e4 100755..100644 --- a/src/optimizers/exponential_smoothing.cpp +++ b/src/optimizers/exponential_smoothing.cpp diff --git a/src/optimizers/exponential_smoothing.h b/src/optimizers/exponential_smoothing.h index 5ef12ca1..5ef12ca1 100755..100644 --- a/src/optimizers/exponential_smoothing.h +++ b/src/optimizers/exponential_smoothing.h diff --git a/src/rnn/attention.h b/src/rnn/attention.h index 6b30cb55..6b30cb55 100755..100644 --- a/src/rnn/attention.h +++ b/src/rnn/attention.h diff --git a/src/rnn/cells.h b/src/rnn/cells.h index cddfd26e..cddfd26e 100755..100644 --- a/src/rnn/cells.h +++ b/src/rnn/cells.h diff --git a/src/rnn/constructors.h b/src/rnn/constructors.h index beb1fce1..beb1fce1 100755..100644 --- a/src/rnn/constructors.h +++ b/src/rnn/constructors.h diff --git a/src/tensors/rand.cpp b/src/tensors/rand.cpp index e6dbc46e..e6dbc46e 100755..100644 --- a/src/tensors/rand.cpp +++ b/src/tensors/rand.cpp diff --git a/src/tensors/tensor.cpp b/src/tensors/tensor.cpp index 02de17bc..02de17bc 100755..100644 --- a/src/tensors/tensor.cpp +++ b/src/tensors/tensor.cpp diff --git a/src/tensors/tensor.h b/src/tensors/tensor.h index 10c3e7f1..10c3e7f1 100755..100644 --- a/src/tensors/tensor.h +++ b/src/tensors/tensor.h diff --git a/src/training/graph_group_sync.cpp b/src/training/graph_group_sync.cpp index 8c06761e..8c06761e 100755..100644 --- a/src/training/graph_group_sync.cpp +++ b/src/training/graph_group_sync.cpp diff --git a/src/training/graph_group_sync.h b/src/training/graph_group_sync.h index df7865a7..df7865a7 100755..100644 --- a/src/training/graph_group_sync.h +++ b/src/training/graph_group_sync.h diff --git a/src/training/scheduler.h b/src/training/scheduler.h index 9d2500f9..9d2500f9 100755..100644 --- a/src/training/scheduler.h +++ b/src/training/scheduler.h diff --git a/src/training/validator.h b/src/training/validator.h index d6e64d69..d6e64d69 100755..100644 --- a/src/training/validator.h +++ b/src/training/validator.h diff --git a/src/translator/beam_search.cpp b/src/translator/beam_search.cpp index 5c1989a6..5c1989a6 100755..100644 --- a/src/translator/beam_search.cpp +++ b/src/translator/beam_search.cpp diff --git a/src/translator/output_printer.h b/src/translator/output_printer.h index 603eedba..603eedba 100755..100644 --- a/src/translator/output_printer.h +++ b/src/translator/output_printer.h diff --git a/src/translator/scorers.h b/src/translator/scorers.h index a5a0be2c..a5a0be2c 100755..100644 --- a/src/translator/scorers.h +++ b/src/translator/scorers.h diff --git a/src/translator/translator.h b/src/translator/translator.h index 1ff19a4a..82d9343d 100755..100644 --- a/src/translator/translator.h +++ b/src/translator/translator.h @@ -60,8 +60,7 @@ public: auto srcVocab = corpus_->getVocabs()[0]; if(options_->hasAndNotEmpty("shortlist")) - shortlistGenerator_ = New<data::LexicalShortlistGenerator>( - options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back()); + shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back()); auto devices = Config::getDevices(options_); numDevices_ = devices.size(); |