Repository: github.com/marian-nmt/marian.git

author     Marcin Junczys-Dowmunt <marcinjd@microsoft.com>  2021-03-18 06:41:24 +0300
committer  Marcin Junczys-Dowmunt <marcinjd@microsoft.com>  2021-03-18 06:41:24 +0300
commit     272096c1d188dcd0ec33ba349bab5955c497876a (patch)
tree       93adf1d89b1900a3d017acf8b5fc48c6f518ccf5 /src
parent     77c3e356a47113f661dda794b815e84561ca93f5 (diff)
parent     8f73923d3134f4799497b7e880963336b8fe4d6b (diff)
sync public and internal master
Diffstat (limited to 'src')
-rw-r--r--  src/CMakeLists.txt | 12
-rw-r--r-- (was -rwxr-xr-x)  src/common/config_parser.cpp | 0
-rw-r--r-- (was -rwxr-xr-x)  src/common/definitions.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/common/file_stream.cpp | 0
-rw-r--r-- (was -rwxr-xr-x)  src/common/io_item.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/common/options.h | 0
-rw-r--r--  src/common/timer.cpp | 0
-rw-r--r-- (was -rwxr-xr-x)  src/common/utils.cpp | 0
-rw-r--r-- (was -rwxr-xr-x)  src/data/batch.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/data/corpus.cpp | 0
-rw-r--r-- (was -rwxr-xr-x)  src/data/corpus_base.cpp | 0
-rw-r--r-- (was -rwxr-xr-x)  src/data/factored_vocab.cpp | 1
-rw-r--r-- (was -rwxr-xr-x)  src/data/factored_vocab.h | 0
-rw-r--r--  src/data/shortlist.cpp | 153
-rw-r--r--  src/data/shortlist.h | 47
-rw-r--r-- (was -rwxr-xr-x)  src/data/vocab.cpp | 0
-rw-r--r-- (was -rwxr-xr-x)  src/data/vocab.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/data/vocab_base.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/functional/operators.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/functional/shape.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/functional/tensor.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/functional/tmp.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/graph/auto_tuner.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/graph/expression_operators.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/graph/node.cpp | 0
-rw-r--r-- (was -rwxr-xr-x)  src/graph/node_initializers.cpp | 0
-rw-r--r-- (was -rwxr-xr-x)  src/graph/node_initializers.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/layers/constructors.h | 70
-rw-r--r--  src/layers/embedding.cpp | 194
-rw-r--r--  src/layers/embedding.h | 157
-rw-r--r-- (was -rwxr-xr-x)  src/layers/factory.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/layers/generic.cpp | 607
-rw-r--r-- (was -rwxr-xr-x)  src/layers/generic.h | 364
-rw-r--r-- (was -rwxr-xr-x)  src/layers/guided_alignment.h | 0
-rw-r--r--  src/layers/logits.cpp | 245
-rw-r--r--  src/layers/logits.h | 106
-rw-r--r-- (was -rwxr-xr-x)  src/layers/loss.cpp | 32
-rw-r--r-- (was -rwxr-xr-x)  src/layers/loss.h | 181
-rw-r--r--  src/layers/output.cpp | 293
-rw-r--r--  src/layers/output.h | 75
-rw-r--r-- (was -rwxr-xr-x)  src/microsoft/quicksand.cpp | 0
-rw-r--r-- (was -rwxr-xr-x)  src/microsoft/quicksand.h | 0
-rw-r--r--  src/microsoft/shortlist/logging/LoggerMacros.h | 25
-rw-r--r--  src/microsoft/shortlist/utils/Converter.cpp | 59
-rw-r--r--  src/microsoft/shortlist/utils/Converter.h | 83
-rw-r--r--  src/microsoft/shortlist/utils/ParameterTree.cpp | 417
-rw-r--r--  src/microsoft/shortlist/utils/ParameterTree.h | 185
-rw-r--r--  src/microsoft/shortlist/utils/PrintTypes.h | 16
-rw-r--r--  src/microsoft/shortlist/utils/StringUtils.cpp | 338
-rw-r--r--  src/microsoft/shortlist/utils/StringUtils.h | 98
-rw-r--r-- (was -rwxr-xr-x)  src/models/amun.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/models/bert.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/models/char_s2s.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/models/classifier.h | 0
-rw-r--r--  src/models/costs.cpp | 14
-rw-r--r-- (was -rwxr-xr-x)  src/models/costs.h | 165
-rw-r--r-- (was -rwxr-xr-x)  src/models/encoder_decoder.cpp | 0
-rw-r--r-- (was -rwxr-xr-x)  src/models/encoder_decoder.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/models/model_factory.cpp | 0
-rw-r--r-- (was -rwxr-xr-x)  src/models/model_factory.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/models/nematus.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/models/s2s.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/models/states.h | 70
-rw-r--r-- (was -rwxr-xr-x)  src/models/transformer.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/models/transformer_factory.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/models/transformer_stub.cpp | 0
-rw-r--r-- (was -rwxr-xr-x)  src/optimizers/exponential_smoothing.cpp | 0
-rw-r--r-- (was -rwxr-xr-x)  src/optimizers/exponential_smoothing.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/rnn/attention.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/rnn/cells.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/rnn/constructors.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/tensors/rand.cpp | 0
-rw-r--r-- (was -rwxr-xr-x)  src/tensors/tensor.cpp | 0
-rw-r--r-- (was -rwxr-xr-x)  src/tensors/tensor.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/training/graph_group_sync.cpp | 0
-rw-r--r-- (was -rwxr-xr-x)  src/training/graph_group_sync.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/training/scheduler.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/training/validator.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/translator/beam_search.cpp | 0
-rw-r--r-- (was -rwxr-xr-x)  src/translator/output_printer.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/translator/scorers.h | 0
-rw-r--r-- (was -rwxr-xr-x)  src/translator/translator.h | 3
82 files changed, 2844 insertions, 1166 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index b47663b4..64b86a69 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -40,6 +40,7 @@ set(MARIAN_SOURCES
data/corpus_sqlite.cpp
data/corpus_nbest.cpp
data/text_input.cpp
+ data/shortlist.cpp
3rd_party/cnpy/cnpy.cpp
3rd_party/ExceptionWithCallStack.cpp
@@ -72,6 +73,9 @@ set(MARIAN_SOURCES
layers/loss.cpp
layers/weight.cpp
layers/lsh.cpp
+ layers/embedding.cpp
+ layers/output.cpp
+ layers/logits.cpp
rnn/cells.cpp
rnn/attention.cpp
@@ -84,6 +88,7 @@ set(MARIAN_SOURCES
models/model_factory.cpp
models/encoder_decoder.cpp
models/transformer_stub.cpp
+ models/costs.cpp
rescorer/score_collector.cpp
embedder/vector_collector.cpp
@@ -103,10 +108,15 @@ set(MARIAN_SOURCES
training/validator.cpp
training/communicator.cpp
- # this is only compiled to catch build errors, but not linked
+ # this is only compiled to catch build errors
microsoft/quicksand.cpp
microsoft/cosmos.cpp
+ # copied from quicksand to be able to read binary shortlist
+ microsoft/shortlist/utils/Converter.cpp
+ microsoft/shortlist/utils/StringUtils.cpp
+ microsoft/shortlist/utils/ParameterTree.cpp
+
$<TARGET_OBJECTS:libyaml-cpp>
$<TARGET_OBJECTS:SQLiteCpp>
$<TARGET_OBJECTS:pathie-cpp>
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index 3baa13ea..3baa13ea 100755..100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
diff --git a/src/common/definitions.h b/src/common/definitions.h
index d2cf8aa4..d2cf8aa4 100755..100644
--- a/src/common/definitions.h
+++ b/src/common/definitions.h
diff --git a/src/common/file_stream.cpp b/src/common/file_stream.cpp
index 78cbb12f..78cbb12f 100755..100644
--- a/src/common/file_stream.cpp
+++ b/src/common/file_stream.cpp
diff --git a/src/common/io_item.h b/src/common/io_item.h
index d86c01ac..d86c01ac 100755..100644
--- a/src/common/io_item.h
+++ b/src/common/io_item.h
diff --git a/src/common/options.h b/src/common/options.h
index 08c6a3ca..08c6a3ca 100755..100644
--- a/src/common/options.h
+++ b/src/common/options.h
diff --git a/src/common/timer.cpp b/src/common/timer.cpp
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/src/common/timer.cpp
diff --git a/src/common/utils.cpp b/src/common/utils.cpp
index 72624041..72624041 100755..100644
--- a/src/common/utils.cpp
+++ b/src/common/utils.cpp
diff --git a/src/data/batch.h b/src/data/batch.h
index 3c592b31..3c592b31 100755..100644
--- a/src/data/batch.h
+++ b/src/data/batch.h
diff --git a/src/data/corpus.cpp b/src/data/corpus.cpp
index e8ce850b..e8ce850b 100755..100644
--- a/src/data/corpus.cpp
+++ b/src/data/corpus.cpp
diff --git a/src/data/corpus_base.cpp b/src/data/corpus_base.cpp
index 5be4298b..5be4298b 100755..100644
--- a/src/data/corpus_base.cpp
+++ b/src/data/corpus_base.cpp
diff --git a/src/data/factored_vocab.cpp b/src/data/factored_vocab.cpp
index 818f3788..17a5bfb7 100755..100644
--- a/src/data/factored_vocab.cpp
+++ b/src/data/factored_vocab.cpp
@@ -546,7 +546,6 @@ void FactoredVocab::constructNormalizationInfoForVocab() {
/*virtual*/ void FactoredVocab::transcodeToShortlistInPlace(WordIndex* ptr, size_t num) const {
for (; num-- > 0; ptr++) {
auto word = Word::fromWordIndex(*ptr);
- auto wordString = word2string(word);
auto lemmaIndex = getFactor(word, 0) + groupRanges_[0].first;
*ptr = (WordIndex)lemmaIndex;
}
diff --git a/src/data/factored_vocab.h b/src/data/factored_vocab.h
index 215e92f0..215e92f0 100755..100644
--- a/src/data/factored_vocab.h
+++ b/src/data/factored_vocab.h
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
new file mode 100644
index 00000000..6f551262
--- /dev/null
+++ b/src/data/shortlist.cpp
@@ -0,0 +1,153 @@
+#include "data/shortlist.h"
+#include "microsoft/shortlist/utils/ParameterTree.h"
+
+namespace marian {
+namespace data {
+
+// cast current void pointer to T pointer and move forward by num elements
+template <typename T>
+const T* get(const void*& current, size_t num = 1) {
+ const T* ptr = (const T*)current;
+ current = (const T*)current + num;
+ return ptr;
+}
+
+QuicksandShortlistGenerator::QuicksandShortlistGenerator(Ptr<Options> options,
+ Ptr<const Vocab> srcVocab,
+ Ptr<const Vocab> trgVocab,
+ size_t srcIdx,
+ size_t /*trgIdx*/,
+ bool /*shared*/)
+ : options_(options),
+ srcVocab_(srcVocab),
+ trgVocab_(trgVocab),
+ srcIdx_(srcIdx) {
+ std::vector<std::string> vals = options_->get<std::vector<std::string>>("shortlist");
+
+ ABORT_IF(vals.empty(), "No path to filter path given");
+ std::string fname = vals[0];
+
+ auto firstNum = vals.size() > 1 ? std::stoi(vals[1]) : 0;
+ auto bestNum = vals.size() > 2 ? std::stoi(vals[2]) : 0;
+ float threshold = vals.size() > 3 ? std::stof(vals[3]) : 0;
+
+ if(firstNum != 0 || bestNum != 0 || threshold != 0) {
+ LOG(warn, "You have provided additional parameters for the Quicksand shortlist, but they are ignored.");
+ }
+
+ mmap_ = mio::mmap_source(fname); // memory-map the binary file once
+ const void* current = mmap_.data(); // pointer iterator over binary file
+
+ // compare magic number in binary file to make sure we are reading the right thing
+ const int32_t MAGIC_NUMBER = 1234567890;
+ int32_t header_magic_number = *get<int32_t>(current);
+ ABORT_IF(header_magic_number != MAGIC_NUMBER, "Trying to mmap Quicksand shortlist but encountered wrong magic number");
+
+ auto config = ::quicksand::ParameterTree::FromBinaryReader(current);
+ use16bit_ = config->GetBoolReq("use_16_bit");
+
+ LOG(info, "[data] Mapping Quicksand shortlist from {}", fname);
+
+ idSize_ = sizeof(int32_t);
+ if (use16bit_) {
+ idSize_ = sizeof(uint16_t);
+ }
+
+ // mmap the binary shortlist pieces
+ numDefaultIds_ = *get<int32_t>(current);
+ defaultIds_ = get<int32_t>(current, numDefaultIds_);
+ numSourceIds_ = *get<int32_t>(current);
+ sourceLengths_ = get<int32_t>(current, numSourceIds_);
+ sourceOffsets_ = get<int32_t>(current, numSourceIds_);
+ numShortlistIds_ = *get<int32_t>(current);
+ sourceToShortlistIds_ = get<uint8_t>(current, idSize_ * numShortlistIds_);
+
+ // display parameters
+ LOG(info,
+ "[data] Quicksand shortlist has {} source ids, {} default ids and {} shortlist ids",
+ numSourceIds_,
+ numDefaultIds_,
+ numShortlistIds_);
+}
+
+Ptr<Shortlist> QuicksandShortlistGenerator::generate(Ptr<data::CorpusBatch> batch) const {
+ auto srcBatch = (*batch)[srcIdx_];
+ auto maxShortlistSize = trgVocab_->size();
+
+ std::unordered_set<int32_t> indexSet;
+ for(int32_t i = 0; i < numDefaultIds_ && i < maxShortlistSize; ++i) {
+ int32_t id = defaultIds_[i];
+ indexSet.insert(id);
+ }
+
+ // State
+ std::vector<std::pair<const uint8_t*, int32_t>> curShortlists(maxShortlistSize);
+ auto curShortlistIt = curShortlists.begin();
+
+ // Because we might fill up our shortlist before reaching max_shortlist_size, we fill the shortlist in order of rank.
+ // E.g., first rank of word 0, first rank of word 1, ... second rank of word 0, ...
+ int32_t maxLength = 0;
+ for (Word word : srcBatch->data()) {
+ int32_t sourceId = (int32_t)word.toWordIndex();
+ srcVocab_->transcodeToShortlistInPlace((WordIndex*)&sourceId, 1);
+
+ if (sourceId < numSourceIds_) { // if it's a valid source id
+ const uint8_t* curShortlistIds = sourceToShortlistIds_ + idSize_ * sourceOffsets_[sourceId]; // start position for mapping
+ int32_t length = sourceLengths_[sourceId]; // how many mappings are there
+ curShortlistIt->first = curShortlistIds;
+ curShortlistIt->second = length;
+ curShortlistIt++;
+
+ if (length > maxLength)
+ maxLength = length;
+ }
+ }
+
+ // collect the actual shortlist mappings
+ for (int32_t i = 0; i < maxLength && indexSet.size() < maxShortlistSize; i++) {
+ for (int32_t j = 0; j < curShortlists.size() && indexSet.size() < maxShortlistSize; j++) {
+ int32_t length = curShortlists[j].second;
+ if (i < length) {
+ const uint8_t* source_shortlist_ids_bytes = curShortlists[j].first;
+ int32_t id = 0;
+ if (use16bit_) {
+ const uint16_t* source_shortlist_ids = reinterpret_cast<const uint16_t*>(source_shortlist_ids_bytes);
+ id = (int32_t)source_shortlist_ids[i];
+ }
+ else {
+ const int32_t* source_shortlist_ids = reinterpret_cast<const int32_t*>(source_shortlist_ids_bytes);
+ id = source_shortlist_ids[i];
+ }
+ indexSet.insert(id);
+ }
+ }
+ }
+
+ // turn into vector and sort (selected indices)
+ std::vector<WordIndex> indices;
+ indices.reserve(indexSet.size());
+ for(auto i : indexSet)
+ indices.push_back((WordIndex)i);
+
+ std::sort(indices.begin(), indices.end());
+ return New<Shortlist>(indices);
+}
+
+Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
+ Ptr<const Vocab> srcVocab,
+ Ptr<const Vocab> trgVocab,
+ size_t srcIdx,
+ size_t trgIdx,
+ bool shared) {
+ std::vector<std::string> vals = options->get<std::vector<std::string>>("shortlist");
+ ABORT_IF(vals.empty(), "No path to shortlist given");
+ std::string fname = vals[0];
+ if(filesystem::Path(fname).extension().string() == ".bin") {
+ return New<QuicksandShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
+ } else {
+ return New<LexicalShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
+ }
+}
+
+} // namespace data
+} // namespace marian
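
The constructor above recovers the entire binary shortlist from one mmap'd buffer: the pointer-bumping get<T>() helper reads the magic number and then each count/array in a fixed order, so the file layout is implicit in the read sequence. The following is a minimal, self-contained sketch of that read pattern against an in-memory buffer instead of mio::mmap_source; it omits the ParameterTree config section, assumes 32-bit shortlist ids (the use16bit_ == false case), and uses made-up buffer values purely for illustration.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// cast current void pointer to T pointer and move forward by num elements
template <typename T>
const T* get(const void*& current, size_t num = 1) {
  const T* ptr = (const T*)current;
  current = (const T*)current + num;
  return ptr;
}

int main() {
  // toy blob: 2 default ids, 1 source id that maps to 2 shortlist ids
  std::vector<int32_t> blob = {
      1234567890,  // magic number
      2, 7, 11,    // numDefaultIds, defaultIds[0..1]
      1,           // numSourceIds
      2,           // sourceLengths[0]
      0,           // sourceOffsets[0]
      2, 42, 43};  // numShortlistIds, shortlist ids (32-bit here)

  const void* current = blob.data();
  int32_t magic            = *get<int32_t>(current);
  int32_t numDefaultIds    = *get<int32_t>(current);
  const int32_t* defaults  = get<int32_t>(current, numDefaultIds);
  int32_t numSourceIds     = *get<int32_t>(current);
  const int32_t* lengths   = get<int32_t>(current, numSourceIds);
  const int32_t* offsets   = get<int32_t>(current, numSourceIds);
  int32_t numShortlistIds  = *get<int32_t>(current);
  const int32_t* ids       = get<int32_t>(current, numShortlistIds);

  std::cout << "magic ok: " << (magic == 1234567890) << "\n";
  std::cout << "defaults: " << defaults[0] << " " << defaults[1] << "\n";
  std::cout << "source 0 -> ";
  for(int32_t k = 0; k < lengths[0]; ++k)
    std::cout << ids[offsets[0] + k] << " ";
  std::cout << "\n";
}
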
diff --git a/src/data/shortlist.h b/src/data/shortlist.h
index 395bcfee..ab6a087b 100644
--- a/src/data/shortlist.h
+++ b/src/data/shortlist.h
@@ -5,6 +5,7 @@
#include "common/file_stream.h"
#include "data/corpus_base.h"
#include "data/types.h"
+#include "mio/mio.hpp"
#include <random>
#include <unordered_map>
@@ -292,5 +293,51 @@ public:
}
};
+/*
+Legacy binary shortlist for Microsoft-internal use.
+*/
+class QuicksandShortlistGenerator : public ShortlistGenerator {
+private:
+ Ptr<Options> options_;
+ Ptr<const Vocab> srcVocab_;
+ Ptr<const Vocab> trgVocab_;
+
+ size_t srcIdx_;
+
+ mio::mmap_source mmap_;
+
+ // all the quicksand bits go here
+ bool use16bit_{false};
+ int32_t numDefaultIds_;
+ int32_t idSize_;
+ const int32_t* defaultIds_{nullptr};
+ int32_t numSourceIds_{0};
+ const int32_t* sourceLengths_{nullptr};
+ const int32_t* sourceOffsets_{nullptr};
+ int32_t numShortlistIds_{0};
+ const uint8_t* sourceToShortlistIds_{nullptr};
+
+public:
+ QuicksandShortlistGenerator(Ptr<Options> options,
+ Ptr<const Vocab> srcVocab,
+ Ptr<const Vocab> trgVocab,
+ size_t srcIdx = 0,
+ size_t trgIdx = 1,
+ bool shared = false);
+
+ virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override;
+};
+
+/*
+Shortlist factory to create correct type of shortlist. Currently assumes everything is a text shortlist
+unless the extension is *.bin for which the Microsoft legacy binary shortlist is used.
+*/
+Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
+ Ptr<const Vocab> srcVocab,
+ Ptr<const Vocab> trgVocab,
+ size_t srcIdx = 0,
+ size_t trgIdx = 1,
+ bool shared = false);
+
} // namespace data
} // namespace marian
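
The createShortlistGenerator() factory declared above is the single entry point for callers; it looks at the first element of the shortlist option and selects the reader from the file extension. Below is a small sketch of just that dispatch rule, using std::filesystem in place of marian::filesystem; the enum and function names are illustrative stand-ins, not part of Marian.

#include <filesystem>
#include <iostream>
#include <string>

enum class ShortlistKind { QuicksandBinary, LexicalText };

// mirrors the factory's rule: ".bin" -> mmap'd Quicksand reader,
// anything else -> text-based lexical shortlist
ShortlistKind pickShortlistKind(const std::string& fname) {
  return std::filesystem::path(fname).extension() == ".bin" ? ShortlistKind::QuicksandBinary
                                                            : ShortlistKind::LexicalText;
}

int main() {
  std::cout << (pickShortlistKind("lex.s2t.bin") == ShortlistKind::QuicksandBinary) << "\n";  // 1
  std::cout << (pickShortlistKind("lex.s2t.gz")  == ShortlistKind::QuicksandBinary) << "\n";  // 0
}
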
diff --git a/src/data/vocab.cpp b/src/data/vocab.cpp
index 07ac479e..07ac479e 100755..100644
--- a/src/data/vocab.cpp
+++ b/src/data/vocab.cpp
diff --git a/src/data/vocab.h b/src/data/vocab.h
index 9a40ba16..9a40ba16 100755..100644
--- a/src/data/vocab.h
+++ b/src/data/vocab.h
diff --git a/src/data/vocab_base.h b/src/data/vocab_base.h
index 8c214c97..8c214c97 100755..100644
--- a/src/data/vocab_base.h
+++ b/src/data/vocab_base.h
diff --git a/src/functional/operators.h b/src/functional/operators.h
index a14f153f..a14f153f 100755..100644
--- a/src/functional/operators.h
+++ b/src/functional/operators.h
diff --git a/src/functional/shape.h b/src/functional/shape.h
index fd354e1e..fd354e1e 100755..100644
--- a/src/functional/shape.h
+++ b/src/functional/shape.h
diff --git a/src/functional/tensor.h b/src/functional/tensor.h
index f5549c60..f5549c60 100755..100644
--- a/src/functional/tensor.h
+++ b/src/functional/tensor.h
diff --git a/src/functional/tmp.h b/src/functional/tmp.h
index a83c0ff4..a83c0ff4 100755..100644
--- a/src/functional/tmp.h
+++ b/src/functional/tmp.h
diff --git a/src/graph/auto_tuner.h b/src/graph/auto_tuner.h
index 01f33085..01f33085 100755..100644
--- a/src/graph/auto_tuner.h
+++ b/src/graph/auto_tuner.h
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index ca0739e4..ca0739e4 100755..100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
diff --git a/src/graph/node.cpp b/src/graph/node.cpp
index 257a639f..257a639f 100755..100644
--- a/src/graph/node.cpp
+++ b/src/graph/node.cpp
diff --git a/src/graph/node_initializers.cpp b/src/graph/node_initializers.cpp
index 4e39d1bf..4e39d1bf 100755..100644
--- a/src/graph/node_initializers.cpp
+++ b/src/graph/node_initializers.cpp
diff --git a/src/graph/node_initializers.h b/src/graph/node_initializers.h
index 7cdb4183..7cdb4183 100755..100644
--- a/src/graph/node_initializers.h
+++ b/src/graph/node_initializers.h
diff --git a/src/layers/constructors.h b/src/layers/constructors.h
index a2c38197..9e9de207 100755..100644
--- a/src/layers/constructors.h
+++ b/src/layers/constructors.h
@@ -1,7 +1,9 @@
#pragma once
+#include "layers/embedding.h"
#include "layers/factory.h"
#include "layers/generic.h"
+#include "layers/output.h"
namespace marian {
namespace mlp {
@@ -43,6 +45,7 @@ struct LogitLayerFactory : public Factory {
// @TODO: In the long run, I hope we can get rid of the abstract factories altogether.
class OutputFactory : public LogitLayerFactory {
using LogitLayerFactory::LogitLayerFactory;
+
protected:
std::string tiedTransposedName_;
Ptr<data::Shortlist> shortlist_;
@@ -53,9 +56,7 @@ public:
return Accumulator<OutputFactory>(*this);
}
- void setShortlist(Ptr<data::Shortlist> shortlist) {
- shortlist_ = shortlist;
- }
+ void setShortlist(Ptr<data::Shortlist> shortlist) { shortlist_ = shortlist; }
Ptr<IUnaryLogitLayer> construct(Ptr<ExpressionGraph> graph) override {
auto output = New<Output>(graph, options_);
@@ -87,8 +88,7 @@ protected:
std::vector<Ptr<IUnaryLayer>> layers_;
public:
- MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options)
- : graph_(graph), options_(options) {}
+ MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {}
Expr apply(const std::vector<Expr>& av) override {
Expr output;
@@ -104,46 +104,53 @@ public:
}
Logits applyAsLogits(const std::vector<Expr>& av) override {
- // same as apply() except for the last layer, we invoke applyAsLogits(), which has a different return type
+ // same as apply() except for the last layer, we invoke applyAsLogits(), which has a different
+ // return type
auto lastLayer = std::dynamic_pointer_cast<IUnaryLogitLayer>(layers_.back());
- ABORT_IF(!lastLayer, "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer");
- if (layers_.size() == 1) {
- if (av.size() == 1)
+ ABORT_IF(
+ !lastLayer,
+ "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer");
+ if(layers_.size() == 1) {
+ if(av.size() == 1)
return lastLayer->applyAsLogits(av[0]);
else
return lastLayer->applyAsLogits(av);
- }
- else {
+ } else {
Expr output;
- if (av.size() == 1)
+ if(av.size() == 1)
output = layers_[0]->apply(av[0]);
else
output = layers_[0]->apply(av);
- for (size_t i = 1; i < layers_.size() - 1; ++i)
+ for(size_t i = 1; i < layers_.size() - 1; ++i)
output = layers_[i]->apply(output);
return lastLayer->applyAsLogits(output);
}
}
- Expr apply(Expr e) override { return apply(std::vector<Expr>{ e }); }
- Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{ e }); }
+ Expr apply(Expr e) override { return apply(std::vector<Expr>{e}); }
+ Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{e}); }
void push_back(Ptr<IUnaryLayer> layer) { layers_.push_back(layer); }
void push_back(Ptr<IUnaryLogitLayer> layer) { layers_.push_back(layer); }
void setShortlist(Ptr<data::Shortlist> shortlist) override final {
auto p = tryAsHasShortlist();
- ABORT_IF(!p, "setShortlist() called on an MLP with an output layer that does not support short lists");
+ ABORT_IF(
+ !p,
+ "setShortlist() called on an MLP with an output layer that does not support short lists");
p->setShortlist(shortlist);
}
void clear() override final {
auto p = tryAsHasShortlist();
- if (p)
+ if(p)
p->clear();
}
+
private:
- Ptr<IHasShortList> tryAsHasShortlist() const { return std::dynamic_pointer_cast<IHasShortList>(layers_.back()); }
+ Ptr<IHasShortList> tryAsHasShortlist() const {
+ return std::dynamic_pointer_cast<IHasShortList>(layers_.back());
+ }
};
/**
@@ -152,6 +159,7 @@ private:
*/
class MLPFactory : public Factory {
using Factory::Factory;
+
private:
std::vector<Ptr<LayerFactory>> layers_;
@@ -175,23 +183,27 @@ public:
// which will go away if we get rid of the abstract factories, and instead just construct
// all layers immediately, which is my long-term goal for Marian.
private:
- template<class WrappedFactory>
+ template <class WrappedFactory>
class AsLayerFactory : public LayerFactory {
- WrappedFactory us;
+ WrappedFactory us;
+
public:
- AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {}
- Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final {
- auto p = std::static_pointer_cast<IUnaryLayer>(us.construct(graph));
- ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one");
- return p;
- }
+ AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {}
+ Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final {
+ auto p = std::static_pointer_cast<IUnaryLayer>(us.construct(graph));
+ ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one");
+ return p;
+ }
};
- template<class WrappedFactory>
- static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) { return wrapped; }
+ template <class WrappedFactory>
+ static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) {
+ return wrapped;
+ }
+
public:
Accumulator<MLPFactory> push_back(const Accumulator<OutputFactory>& lf) {
push_back(AsLayerFactory<OutputFactory>(lf));
- //layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf)));
+ // layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf)));
return Accumulator<MLPFactory>(*this);
}
};
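
The shortlist plumbing in MLP above rests on a runtime probe: only output layers implement IHasShortList, so setShortlist() and clear() dynamic-cast the last layer and abort when it does not support shortlists. A minimal sketch of that pattern with stand-in types (Dense, Output and the plain int standing in for Ptr<data::Shortlist> are placeholders, not the Marian classes):

#include <iostream>
#include <memory>
#include <stdexcept>
#include <vector>

struct IUnaryLayer { virtual ~IUnaryLayer() = default; };
struct IHasShortList {
  virtual ~IHasShortList() = default;
  virtual void setShortlist(int shortlist) = 0;  // placeholder for Ptr<data::Shortlist>
};
struct Dense  : IUnaryLayer {};
struct Output : IUnaryLayer, IHasShortList {
  void setShortlist(int s) override { std::cout << "shortlist set: " << s << "\n"; }
};

struct MLP {
  std::vector<std::shared_ptr<IUnaryLayer>> layers_;
  void setShortlist(int s) {
    // probe the last layer; only output layers know about shortlists
    auto p = std::dynamic_pointer_cast<IHasShortList>(layers_.back());
    if(!p) throw std::runtime_error("last layer does not support shortlists");
    p->setShortlist(s);
  }
};

int main() {
  MLP mlp;
  mlp.layers_.push_back(std::make_shared<Dense>());
  mlp.layers_.push_back(std::make_shared<Output>());
  mlp.setShortlist(42);  // reaches the Output layer
}
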
diff --git a/src/layers/embedding.cpp b/src/layers/embedding.cpp
new file mode 100644
index 00000000..92c4ad6d
--- /dev/null
+++ b/src/layers/embedding.cpp
@@ -0,0 +1,194 @@
+#include "embedding.h"
+#include "data/factored_vocab.h"
+
+namespace marian {
+
+Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
+ : LayerBase(graph, options), inference_(opt<bool>("inference")) {
+ std::string name = opt<std::string>("prefix");
+ int dimVoc = opt<int>("dimVocab");
+ int dimEmb = opt<int>("dimEmb");
+
+ bool fixed = opt<bool>("fixed", false);
+
+ factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
+ if(factoredVocab_) {
+ dimVoc = (int)factoredVocab_->factorVocabSize();
+ LOG_ONCE(info, "[embedding] Factored embeddings enabled");
+ }
+
+ // Embedding layer initialization should depend only on embedding size, hence fanIn=false
+ auto initFunc = inits::glorotUniform(
+ /*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length
+
+ if(options_->has("embFile")) {
+ std::string file = opt<std::string>("embFile");
+ if(!file.empty()) {
+ bool norm = opt<bool>("normalization", false);
+ initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm);
+ }
+ }
+
+ E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed);
+}
+
+// helper to embed a sequence of words (given as indices) via factored embeddings
+Expr Embedding::multiRows(const Words& data, float dropProb) const {
+ auto graph = E_->graph();
+ auto factoredData = factoredVocab_->csr_rows(data);
+ // multi-hot factor vectors are represented as a sparse CSR matrix
+ // [row index = word position index] -> set of factor indices for word at this position
+ ABORT_IF(factoredData.shape
+ != Shape({(int)factoredData.offsets.size() - 1 /*=rows of CSR*/, E_->shape()[0]}),
+ "shape mismatch??");
+ // the CSR matrix is passed in pieces
+ auto weights = graph->constant({(int)factoredData.weights.size()},
+ inits::fromVector(factoredData.weights));
+ auto indices = graph->constant(
+ {(int)factoredData.indices.size()}, inits::fromVector(factoredData.indices), Type::uint32);
+ auto offsets = graph->constant(
+ {(int)factoredData.offsets.size()}, inits::fromVector(factoredData.offsets), Type::uint32);
+ // apply dropout
+ // We apply it to the weights, i.e. factors get dropped out separately, but always as entire
+ // vectors.
+ if(!inference_)
+ weights = dropout(weights, dropProb);
+ // perform the product
+ return csr_dot(factoredData.shape, weights, indices, offsets, E_);
+}
+
+std::tuple<Expr /*embeddings*/, Expr /*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const
+/*override final*/ {
+ auto graph = E_->graph();
+ int dimBatch = (int)subBatch->batchSize();
+ int dimEmb = E_->shape()[-1];
+ int dimWidth = (int)subBatch->batchWidth();
+
+ // factored embeddings:
+ // - regular:
+ // - y = x @ E x:[B x 1ofV] ; E:[V x D] ; y:[B x D]
+ // - factored:
+ // - u = x @ M one-hot to U-dimensional multi-hot (all factors in one concatenated space)
+ // - each row of M contains the set of factors for one word => we want a CSR matrix
+ // - y = (x @ M) @ E (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D]
+ // - first compute x @ M on the CPU
+ // - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()):
+ // - shape (U, specifically) not actually needed here
+ // - foreach input x[i]
+ // - locate row M[i,*]
+ // - copy through its index values (std::vector<push_back>)
+ // - create a matching ones vector (we can keep growing)
+ // - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x)
+ // - CSR matrix product with E
+ // - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU)
+ // - double-check if all dimensions are specified. Probably not for transpose (which would
+ // be like csc_dot()).
+ // - weighting:
+ // - core factors' gradients are sums over all words that use the factors;
+ // - core factors' embeddings move very fast
+ // - words will need to make up for the move; rare words cannot
+ // - so, we multiply each factor with 1/refCount
+ // - core factors get weighed down a lot
+ // - no impact on gradients, as Adam makes up for it; embeddings still move fast just as
+ // before
+ // - but forward pass weighs them down, so that all factors are in a similar numeric range
+ // - if it is required to be in a different range, the embeddings can still learn that, but
+ // more slowly
+
+ auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb});
+#if 1
+ auto batchMask = graph->constant({dimWidth, dimBatch, 1}, inits::fromVector(subBatch->mask()));
+#else // @TODO: this is dead code now, get rid of it
+ // experimental: hide inline-fix source tokens from cross attention
+ auto batchMask
+ = graph->constant({dimWidth, dimBatch, 1},
+ inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed()));
+#endif
+ // give the graph inputs readable names for debugging and ONNX
+ batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask");
+
+ return std::make_tuple(batchEmbeddings, batchMask);
+}
+
+Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ {
+ if(factoredVocab_) {
+ Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E]
+ selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
+ // selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), {
+ // selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout
+ return selectedEmbs;
+ } else
+ return applyIndices(toWordIndexVector(words), shape);
+}
+
+Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const
+/*override final*/ {
+ ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary");
+ auto embIdxExpr = E_->graph()->indices(embIdx);
+ embIdxExpr->set_name("data_"
+ + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index?
+ auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E]
+ selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
+ // @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape()
+ // (test that separately)
+ if(!inference_)
+ selectedEmbs = dropout(
+ selectedEmbs, options_->get<float>("dropout", 0.0f), {selectedEmbs->shape()[-3], 1, 1});
+ return selectedEmbs;
+}
+
+// standard encoder word embeddings
+/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createEmbeddingLayer() const {
+ // clang-format off
+ auto options = New<Options>(
+ "dimVocab", opt<std::vector<int>>("dim-vocabs")[batchIndex_],
+ "dimEmb", opt<int>("dim-emb"),
+ "dropout", dropoutEmbeddings_,
+ "inference", inference_,
+ "prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb"
+ : prefix_ + "_Wemb",
+ "fixed", embeddingFix_,
+ "vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings
+ // clang-format on
+ if(options_->hasAndNotEmpty("embedding-vectors")) {
+ auto embFiles = opt<std::vector<std::string>>("embedding-vectors");
+ options->set(
+ "embFile", embFiles[batchIndex_], "normalization", opt<bool>("embedding-normalization"));
+ }
+ return New<Embedding>(graph_, options);
+}
+
+// ULR word embeddings
+/*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createULREmbeddingLayer() const {
+ // clang-format off
+ return New<ULREmbedding>(graph_, New<Options>(
+ "dimSrcVoc", opt<std::vector<int>>("dim-vocabs")[0], // ULR multi-lingual src
+ "dimTgtVoc", opt<std::vector<int>>("dim-vocabs")[1], // ULR monon tgt
+ "dimUlrEmb", opt<int>("ulr-dim-emb"),
+ "dimEmb", opt<int>("dim-emb"),
+ "ulr-dropout", opt<float>("ulr-dropout"),
+ "dropout", dropoutEmbeddings_,
+ "inference", inference_,
+ "ulrTrainTransform", opt<bool>("ulr-trainable-transformation"),
+ "ulrQueryFile", opt<std::string>("ulr-query-vectors"),
+ "ulrKeysFile", opt<std::string>("ulr-keys-vectors")
+ ));
+ // clang-format on
+}
+
+// get embedding layer for this encoder or decoder
+// This is lazy mostly because the constructors of the consuming objects are not
+// guaranteed presently to have access to their graph.
+Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::getEmbeddingLayer(bool ulr) const {
+ if(embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy
+ if(embeddingLayers_.size() <= batchIndex_)
+ embeddingLayers_.resize(batchIndex_ + 1);
+ if(ulr)
+ embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR
+ else
+ embeddingLayers_[batchIndex_] = createEmbeddingLayer();
+ }
+ return embeddingLayers_[batchIndex_];
+}
+
+} // namespace marian
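
multiRows() above turns each word of the batch into one CSR row (the multi-hot set of factor indices for that word) and multiplies the sparse matrix with the factor embedding table E_. A tiny dense-loop sketch of what that csr_dot computes, with made-up numbers and without the dropout that is applied to the weights during training:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // factor embedding table E: 4 factors x 2 embedding dims
  std::vector<std::vector<float>> E = {{1, 0}, {0, 1}, {2, 2}, {3, -1}};

  // CSR for 2 words: word 0 uses factors {0, 2}, word 1 uses factors {1, 3}
  std::vector<float>  weights = {1, 1, 1, 1};  // multi-hot -> all ones
  std::vector<size_t> indices = {0, 2, 1, 3};  // factor ids per nonzero
  std::vector<size_t> offsets = {0, 2, 4};     // row boundaries

  // y = csr_dot(M, E): one embedding row per word, summed over its factors
  for(size_t row = 0; row + 1 < offsets.size(); ++row) {
    std::vector<float> y(E[0].size(), 0.f);
    for(size_t k = offsets[row]; k < offsets[row + 1]; ++k)
      for(size_t d = 0; d < y.size(); ++d)
        y[d] += weights[k] * E[indices[k]][d];
    std::cout << "word " << row << ": " << y[0] << " " << y[1] << "\n";  // (3,2) and (3,0)
  }
}
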
diff --git a/src/layers/embedding.h b/src/layers/embedding.h
new file mode 100644
index 00000000..2fa7b78d
--- /dev/null
+++ b/src/layers/embedding.h
@@ -0,0 +1,157 @@
+#pragma once
+#include "generic.h"
+#include "marian.h"
+
+namespace marian {
+
+class FactoredVocab;
+
+// A regular embedding layer.
+// Note that this also applies dropout if the option is passed (pass 0 when in inference mode).
+// It is best to not use Embedding directly, but rather via getEmbeddingLayer() in
+// EncoderDecoderLayerBase, which knows to pass on all required parameters from options.
+class Embedding : public LayerBase, public IEmbeddingLayer {
+ Expr E_;
+ Ptr<FactoredVocab> factoredVocab_;
+ Expr multiRows(const Words& data, float dropProb) const;
+ bool inference_{false};
+
+public:
+ Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options);
+
+ std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
+ Ptr<data::SubBatch> subBatch) const override final;
+
+ Expr apply(const Words& words, const Shape& shape) const override final;
+
+ Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final;
+};
+
+class ULREmbedding : public LayerBase, public IEmbeddingLayer {
+ std::vector<Expr> ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members
+ bool inference_{false};
+
+public:
+ ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
+ : LayerBase(graph, options), inference_(opt<bool>("inference")) {
+ std::string name = "url_embed"; // opt<std::string>("prefix");
+ int dimKeys = opt<int>("dimTgtVoc");
+ int dimQueries = opt<int>("dimSrcVoc");
+ int dimEmb = opt<int>("dimEmb");
+ int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size
+ bool fixed = opt<bool>("fixed", false);
+
+ // Embedding layer initialization should depend only on embedding size, hence fanIn=false
+ auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true);
+
+ std::string queryFile = opt<std::string>("ulrQueryFile");
+ std::string keyFile = opt<std::string>("ulrKeysFile");
+ bool trainTrans = opt<bool>("ulrTrainTransform", false);
+ if(!queryFile.empty() && !keyFile.empty()) {
+ initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false);
+ name = "ulr_query";
+ fixed = true;
+ auto query_embed = graph_->param(name, {dimQueries, dimUlrEmb}, initFunc, fixed);
+ ulrEmbeddings_.push_back(query_embed);
+ // keys embeds
+ initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false);
+ name = "ulr_keys";
+ fixed = true;
+ auto key_embed = graph_->param(name, {dimKeys, dimUlrEmb}, initFunc, fixed);
+ ulrEmbeddings_.push_back(key_embed);
+ // actual trainable embedding
+ initFunc = inits::glorotUniform();
+ name = "ulr_embed";
+ fixed = false;
+ auto ulr_embed = graph_->param(name, {dimKeys, dimEmb}, initFunc, fixed); // note the reverse dim
+ ulrEmbeddings_.push_back(ulr_embed);
+ // init trainable src embedding
+ name = "ulr_src_embed";
+ auto ulr_src_embed = graph_->param(name, {dimQueries, dimEmb}, initFunc, fixed);
+ ulrEmbeddings_.push_back(ulr_src_embed);
+ // ulr transformation matrix
+ // initFunc = inits::eye(1.f); // identity matrix - is it OK to init with identity, or should
+ // we do this only in the fixed case?
+ if(trainTrans) {
+ initFunc = inits::glorotUniform();
+ fixed = false;
+ } else {
+ initFunc = inits::eye(); // identity matrix
+ fixed = true;
+ }
+ name = "ulr_transform";
+ auto ulrTransform = graph_->param(name, {dimUlrEmb, dimUlrEmb}, initFunc, fixed);
+ ulrEmbeddings_.push_back(ulrTransform);
+
+ initFunc = inits::fromValue(
+ 1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no
+ // universal embeddings - should be zero for top freq only
+ fixed = true;
+ name = "ulr_shared";
+ auto share_embed = graph_->param(name, {dimQueries, 1}, initFunc, fixed);
+ ulrEmbeddings_.push_back(share_embed);
+ }
+ }
+
+ std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
+ Ptr<data::SubBatch> subBatch) const override final {
+ auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb
+ auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb
+ auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb
+ auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb
+ auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb
+ auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1
+ int dimBatch = (int)subBatch->batchSize();
+ int dimEmb = uniEmbed->shape()[-1];
+ int dimWords = (int)subBatch->batchWidth();
+ // D = K.A.QT
+ // dimm(K) = univ_tok_vocab*uni_embed_size
+ // dim A = uni_embed_size*uni_embed_size
+ // dim Q: uni_embed_size * total_merged_vocab_size
+ // dim D = univ_tok_vocab * total_merged_vocab_size
+ // note: all of the above can be precomputed and serialized if A is not trainable (TBD);
+ // during decoding we need to handle the mini-batch here: extract the rows corresponding
+ // to the Xs in this minibatch from Q
+ auto embIdx = toWordIndexVector(subBatch->data());
+ auto queryEmbeddings = rows(queryEmbed, embIdx);
+ auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings
+ auto alpha = rows(ulrSharable, embIdx); // extract sharable flags
+ auto qt = dot(queryEmbeddings, ulrTransform, false, false); // A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb
+ auto sqrtDim = std::sqrt((float)queryEmbeddings->shape()[-1]);
+ qt = qt / sqrtDim; // normalize according to embedding size to avoid the dot product growing
+ // large in magnitude with larger embedding sizes
+ auto z = dot(qt, keyEmbed, false, true); // query-key similarity
+ float dropProb = this->options_->get<float>("ulr-dropout", 0.0f); // default no dropout
+ if(!inference_)
+ z = dropout(z, dropProb);
+
+ float tau
+ = this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default no temperature
+ // the softmax temperature controls the randomness of predictions:
+ // with a high temperature, softmax outputs are closer to each other;
+ // with low temperatures, the softmax becomes more similar to a "hardmax"
+ auto weights = softmax(z / tau); // assume default is dim=-1; what about temperature? - scaler ??
+ auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE
+ auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast
+ auto batchEmbeddings = reshape(chosenEmbeddings_mix, {dimWords, dimBatch, dimEmb});
+ auto graph = ulrEmbeddings_.front()->graph();
+ auto batchMask = graph->constant({dimWords, dimBatch, 1}, inits::fromVector(subBatch->mask()));
+ if(!inference_)
+ batchEmbeddings = dropout(batchEmbeddings,
+ options_->get<float>("dropout-embeddings", 0.0f),
+ {batchEmbeddings->shape()[-3], 1, 1});
+ return std::make_tuple(batchEmbeddings, batchMask);
+ }
+
+ Expr apply(const Words& words, const Shape& shape) const override final {
+ return applyIndices(toWordIndexVector(words), shape);
+ }
+
+ Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final {
+ embIdx;
+ shape;
+ ABORT("not implemented"); // @TODO: implement me
+ }
+};
+
+} // namespace marian
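
For reference, the per-token arithmetic inside ULREmbedding::apply() is: transform the query embedding with ulr_transform, score it against all key embeddings, apply a temperature softmax, take the expectation over the universal embeddings, and mix the result with the trainable source embedding through the per-word sharable gate alpha. A compact numeric sketch of exactly that step (batching, dropout and masking omitted; all values are toy numbers):

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

using Vec = std::vector<float>;
using Mat = std::vector<Vec>;  // row-major: Mat[i] is row i

float dotv(const Vec& a, const Vec& b) {
  float s = 0.f;
  for(size_t i = 0; i < a.size(); ++i)
    s += a[i] * b[i];
  return s;
}

int main() {
  // toy sizes: dimUlrEmb = 2, dimEmb = 2, three universal (key-side) tokens
  Vec q      = {1.f, 0.f};                             // ULR query embedding of one source token
  Mat A      = {{1.f, 0.f}, {0.f, 1.f}};               // ulr_transform (identity when not trained)
  Mat K      = {{1.f, 0.f}, {0.f, 1.f}, {0.7f, 0.7f}}; // key embeddings, one per universal token
  Mat Euni   = {{1.f, 1.f}, {2.f, 0.f}, {0.f, 2.f}};   // trainable universal embeddings
  Vec srcEmb = {0.5f, 0.5f};                           // trainable source embedding of the token
  float alpha = 0.8f;                                  // "sharable" gate for this token
  float tau   = 1.0f;                                  // softmax temperature

  // qt = q * A, then z[k] = qt . K[k], scaled by sqrt(dimUlrEmb)
  Vec qt(A[0].size(), 0.f);
  for(size_t i = 0; i < A.size(); ++i)
    for(size_t j = 0; j < qt.size(); ++j)
      qt[j] += q[i] * A[i][j];

  float scale = std::sqrt((float)qt.size());
  Vec w(K.size());
  float denom = 0.f;
  for(size_t k = 0; k < K.size(); ++k) {
    w[k] = std::exp(dotv(qt, K[k]) / scale / tau);     // softmax(z / tau), unnormalized
    denom += w[k];
  }

  // mixed = srcEmb + alpha * sum_k softmax_k * Euni[k]
  Vec mixed = srcEmb;
  for(size_t k = 0; k < K.size(); ++k)
    for(size_t d = 0; d < mixed.size(); ++d)
      mixed[d] += alpha * (w[k] / denom) * Euni[k][d];

  std::cout << "mixed embedding: " << mixed[0] << " " << mixed[1] << "\n";
}
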
diff --git a/src/layers/factory.h b/src/layers/factory.h
index f9e4ddf9..f9e4ddf9 100755..100644
--- a/src/layers/factory.h
+++ b/src/layers/factory.h
diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp
index d44f4020..8e2ecfd7 100755..100644
--- a/src/layers/generic.cpp
+++ b/src/layers/generic.cpp
@@ -1,609 +1,10 @@
#include "marian.h"
-#include "layers/generic.h"
+#include "data/factored_vocab.h"
#include "layers/constructors.h"
+#include "layers/generic.h"
#include "layers/loss.h"
-#include "data/factored_vocab.h"
-#include "rnn/types.h" // for State::select()
-#include "models/states.h" // for EncoderState
#include "layers/lsh.h"
+#include "models/states.h" // for EncoderState
-namespace marian {
- Logits::Logits(Expr logits) : Logits(New<RationalLoss>(logits, nullptr)) {} // single-output constructor from Expr only (RationalLoss has no count)
-
- Ptr<ExpressionGraph> Logits::graph() const {
- ABORT_IF(logits_.empty(), "Empty logits object??");
- return logits_.front()->loss()->graph();
- }
-
- // This function assumes that the object holds one or more factor logits.
- // It applies the supplied loss function to each, and then returns the aggregate loss over all factors.
- Expr Logits::applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/, Expr/*indices*/)>& lossFn) const {
- LOG_ONCE(info, "[logits] Applying loss function for {} factor(s)", logits_.size());
- ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
-
- auto firstLogits = logits_.front()->loss();
- ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(),
- "Labels not matching logits shape ({} != {}, {})??",
- labels.size() * firstLogits->shape()[-1],
- firstLogits->shape().elements(),
- firstLogits->shape());
-
- // base case (no factors)
- if (!factoredVocab_) {
- ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
- return lossFn(firstLogits, indices(toWordIndexVector(labels)));
- }
-
- auto numGroups = factoredVocab_->getNumGroups();
-
- // split labels into individual factor labels
- auto allMaskedFactoredLabels = factorizeWords(labels); // [numGroups][labels.size()] = [numGroups][B... flattened]
-
- //Expr indices = this->indices(toWordIndexVector(labels));
- // accumulate all CEs for all words that have the factor
- // Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors.
- Expr loss;
- for (size_t g = 0; g < numGroups; g++) {
- if (!logits_[g])
- continue; // empty factor --@TODO: use an array of indices of non-empty logits_[]
- const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask)
- auto factorIndices = indices (maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply
- auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor
- auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet)
- // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next.
- auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1]
- if(loss)
- factorLoss = cast(factorLoss, loss->value_type());
- factorLoss = factorLoss * cast(reshape(factorMask, factorLoss->shape()), factorLoss->value_type()); // mask out factor for words that do not have that factor
- loss = loss ? (loss + factorLoss) : factorLoss; // [B... x 1]
- }
- return loss;
- }
-
- // This function assumes this object holds a single factor that represents a rational loss (with count).
- //Ptr<RationalLoss> Logits::getRationalLoss() const {
- // ABORT_IF(logits_.size() != 1 || factoredVocab_, "getRationalLoss() cannot be used on multi-factor outputs");
- // ABORT_IF(!logits_.front()->count(), "getRationalLoss() used on rational loss without count");
- // return logits_.front();
- //}
-
- // get logits for one factor group
- // For groupIndex == 0, the function also requires the shortlist if there is one.
- Expr Logits::getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist /*= nullptr*/, const std::vector<IndexType>& hypIndices /*= {}*/, size_t beamSize /*= 0*/) const {
- ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
-
- auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab]
-
- // normalize for decoding:
- // - all secondary factors: subtract their max
- // - lemma: add all maxes of applicable factors
- if (groupIndex > 0) {
- sel = sel - max(sel, -1);
- }
- else {
- auto numGroups = getNumFactorGroups();
- for (size_t g = 1; g < numGroups; g++) {
- auto factorMaxima = max(logits_[g]->loss(), -1); // we cast since loss is likely ce-loss which has type float32
- auto factorMasks = constant(getFactorMasks(g, shortlist ? shortlist->indices() : std::vector<WordIndex>()));
- sel = sel + cast(factorMaxima, sel->value_type()) * cast(factorMasks, sel->value_type()); // those lemmas that don't have a factor get multiplied with 0
- }
- }
-
- // if selIdx are given, then we must reshuffle accordingly
- if (!hypIndices.empty()) // use the same function that shuffles decoder state
- sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false);
-
- return sel;
- }
-
- // used for breakDown() only
- // Index is flattened
- Tensor Logits::getFactoredLogitsTensor(size_t groupIndex) const {
- ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
- return logits_[groupIndex]->loss()->val();
- }
-
- // This function assumes that the object holds one or more factor logits, which are summed up
- // into output-vocab logits according to the factored model (with correct normalization of factors).
- // This is infeasible for realistic factor sets, and therefore only implemented for 1 factor.
- // @TODO: remove altogether
- Expr Logits::getLogits() const {
- ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
- if (!factoredVocab_) {
- ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
- return getFactoredLogits(0);
- }
-
-#ifdef FACTOR_FULL_EXPANSION
- // compute normalized factor log probs
- std::vector<Expr> logProbs(logits_.size());
- for (size_t g = 0; g < logits_.size(); g++)
- logProbs[g] = logsoftmax(logits_[g]->loss());
- auto y = concatenate(logProbs, /*axis=*/ -1);
-
- // sum up the unit logits across factors for each target word
- auto graph = y->graph();
- auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U]
- y = dot_csr(
- y, // [B x U]
- factorMatrix.shape,
- graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)),
- graph->constant({(int)factorMatrix.indices.size()}, inits::fromVector(factorMatrix.indices), Type::uint32),
- graph->constant({(int)factorMatrix.offsets.size()}, inits::fromVector(factorMatrix.offsets), Type::uint32),
- /*transB=*/ true); // -> [B x V]
-
- // mask out gaps
- auto gapLogMask = factoredVocab_->getGapLogMask(); // [V]
- y = y + graph->constant({ (int)gapLogMask.size() }, inits::fromVector(gapLogMask));
-
- return y;
-#else
- ABORT("getLogits() no longer supported for actual factored vocab"); // because it is infeasible
-#endif
- }
-
- void Logits::MaskedFactorIndices::push_back(size_t factorIndex) {
- bool isValid = FactoredVocab::isFactorValid(factorIndex);
- indices.push_back(isValid ? (WordIndex)factorIndex : 0);
- masks.push_back((float)isValid);
- }
-
- std::vector<Logits::MaskedFactorIndices> Logits::factorizeWords(const Words& words) const { // [numGroups][words.size()] -> breaks encoded Word into individual factor indices
- if (!factoredVocab_) {
- ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
- return {MaskedFactorIndices(words)};
- }
- auto numGroups = factoredVocab_->getNumGroups();
- std::vector<MaskedFactorIndices> res(numGroups);
- for (size_t g = 0; g < numGroups; g++) {
- auto& resg = res[g];
- resg.reserve(words.size());
- for (const auto& word : words)
- resg.push_back(factoredVocab_->getFactor(word, g));
- }
- return res;
- }
-
- //// use first factor of each word to determine whether it has a specific factor
- //std::vector<float> Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0 for words that do have this factor; else 0
- // std::vector<float> res;
- // res.reserve(words.size());
- // for (const auto& word : words) {
- // auto lemma = factoredVocab_->getFactor(word, 0);
- // res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
- // }
- // return res;
- //}
-
- // return a vector of 1 or 0 indicating for each lemma whether it has a specific factor
- // If 'indices' is given, then return the masks for the indices; otherwise for all lemmas
- std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
- size_t n = indices.empty() ? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first) : indices.size();
- std::vector<float> res;
- res.reserve(n);
- // @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this into FactoredVocab
- for (size_t i = 0; i < n; i++) {
- auto lemma = indices.empty() ? i : (indices[i] - factoredVocab_->getGroupRange(0).first);
- res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
- }
- return res;
- }
-
- Logits Logits::applyUnaryFunction(const std::function<Expr(Expr)>& f) const { // clone this but apply f to all loss values
- std::vector<Ptr<RationalLoss>> newLogits;
- for (const auto& l : logits_)
- newLogits.emplace_back(New<RationalLoss>(f(l->loss()), l->count()));
- return Logits(std::move(newLogits), factoredVocab_);
- }
-
- Logits Logits::applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const {
- std::vector<Ptr<RationalLoss>> newLogits;
- bool first = true;
- for (const auto& l : logits_) {
- newLogits.emplace_back(New<RationalLoss>((first?f1:fother)(l->loss()), l->count())); // f1 for first, fother for all others
- first = false;
- }
- return Logits(std::move(newLogits), factoredVocab_);
- }
-
- // @TODO: code dup with above; we can merge it into applyToRationalLoss()
- Logits Logits::withCounts(const Expr& count) const { // create new Logits with 'count' implanted into all logits_
- std::vector<Ptr<RationalLoss>> newLogits;
- for (const auto& l : logits_)
- newLogits.emplace_back(New<RationalLoss>(l->loss(), count));
- return Logits(std::move(newLogits), factoredVocab_);
- }
-
- namespace mlp {
- /*private*/ void Output::lazyConstruct(int inputDim) {
- // We must construct lazily since we won't know tying nor input dim in constructor.
- if (Wt_)
- return;
-
- // this option is only set in the decoder
- if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) {
- auto k = opt<std::vector<int>>("output-approx-knn")[0];
- auto nbits = opt<std::vector<int>>("output-approx-knn")[1];
- lsh_ = New<LSH>(k, nbits);
- }
-
- auto name = options_->get<std::string>("prefix");
- auto numOutputClasses = options_->get<int>("dim");
-
- factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
- if (factoredVocab_) {
- numOutputClasses = (int)factoredVocab_->factorVocabSize();
- LOG_ONCE(info, "[embedding] Factored outputs enabled");
- }
-
- if(tiedParam_) {
- Wt_ = tiedParam_;
- } else {
- if (graph_->get(name + "_W")) { // support of legacy models that did not transpose
- Wt_ = graph_->param(name + "_W", {inputDim, numOutputClasses}, inits::glorotUniform(true, false));
- isLegacyUntransposedW = true;
- }
- else // this is the regular case:
- Wt_ = graph_->param(name + "_Wt", {numOutputClasses, inputDim}, inits::glorotUniform(false, true));
- }
-
- if(hasBias_)
- b_ = graph_->param(name + "_b", {1, numOutputClasses}, inits::zeros());
-
- /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
- ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary");
- if (lemmaDimEmb > 0) { // > 0 means to embed the (expected) word with a different embedding matrix
-#define HARDMAX_HACK
-#ifdef HARDMAX_HACK
- lemmaDimEmb = lemmaDimEmb & 0xfffffffe; // hack to select hard-max: use an odd number
-#endif
- auto range = factoredVocab_->getGroupRange(0);
- auto lemmaVocabDim = (int)(range.second - range.first);
- auto initFunc = inits::glorotUniform(/*fanIn=*/true, /*fanOut=*/false); // -> embedding vectors have roughly unit length
- lemmaEt_ = graph_->param(name + "_lemmaEt", {lemmaDimEmb, lemmaVocabDim}, initFunc); // [L x U] L=lemmaDimEmb; transposed for speed
- }
- }
-
- Logits Output::applyAsLogits(Expr input) /*override final*/ {
- lazyConstruct(input->shape()[-1]);
-
- auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
- if(b)
- return affine(x, W, b, transA, transB);
- else
- return dot(x, W, transA, transB);
- };
-
- auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) {
- if(lsh_) {
- ABORT_IF( transA, "Transposed query not supported for LSH");
- ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH");
- return lsh_->apply(x, W, b); // knows how to deal with undefined bias
- } else {
- return affineOrDot(x, W, b, transA, transB);
- }
- };
-
- if (shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one batch, then clear()ed
- cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices());
- if(hasBias_)
- cachedShortb_ = index_select(b_ , -1, shortlist_->indices());
- }
-
- if (factoredVocab_) {
- auto graph = input->graph();
-
- // project each factor separately
- auto numGroups = factoredVocab_->getNumGroups();
- std::vector<Ptr<RationalLoss>> allLogits(numGroups, nullptr); // (note: null entries for absent factors)
- Expr input1 = input; // [B... x D]
- Expr Plemma = nullptr; // used for lemmaDimEmb=-1
- Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3
- for (size_t g = 0; g < numGroups; g++) {
- auto range = factoredVocab_->getGroupRange(g);
- if (g > 0 && range.first == range.second) // empty entry
- continue;
- ABORT_IF(g > 0 && range.first != factoredVocab_->getGroupRange(g-1).second, "Factor groups must be consecutive (group {} vs predecessor)", g);
- // slice this group's section out of W_
- Expr factorWt, factorB;
- if (g == 0 && shortlist_) {
- factorWt = cachedShortWt_;
- factorB = cachedShortb_;
- }
- else {
- factorWt = slice(Wt_, isLegacyUntransposedW ? -1 : 0, Slice((int)range.first, (int)range.second));
- if(hasBias_)
- factorB = slice(b_, -1, Slice((int)range.first, (int)range.second));
- }
- /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
- if ((lemmaDimEmb == -2 || lemmaDimEmb == -3) && g > 0) { // -2/-3 means a gated transformer-like structure (-3 = hard-max)
- LOG_ONCE(info, "[embedding] using lemma conditioning with gate");
- // this mimics one transformer layer
- // - attention over two inputs:
- // - e = current lemma. We use the original embedding vector; specifically, expectation over all lemmas.
- // - input = hidden state FF(h_enc+h_dec)
- // - dot-prod attention to allow both sides to influence (unlike our recurrent self-attention)
- // - multi-head to allow for multiple conditions to be modeled
- // - add & norm, for gradient flow and scaling
- // - FF layer --this is expensive; it is per-factor
- // multi-head attention
- int inputDim = input->shape()[-1];
- int heads = 8;
- auto name = options_->get<std::string>("prefix") + "_factor" + std::to_string(g);
- auto Wq = graph_->param(name + "_Wq", { inputDim, inputDim }, inits::glorotUniform());
- auto Wk = graph_->param(name + "_Wk", { inputDim, inputDim }, inits::glorotUniform());
- auto Wv = graph_->param(name + "_Wv", { inputDim, inputDim }, inits::glorotUniform());
- auto toMultiHead = [&](Expr x, int heads) {
- const auto& shape = x->shape();
- int inputDim = shape[-1];
- int otherDim = shape.elements() / inputDim;
- ABORT_IF(inputDim / heads * heads != inputDim, "inputDim ({}) must be multiple of number of heads ({})", inputDim, heads);
- return reshape(x, { otherDim, heads, 1, inputDim / heads });
- };
- input1 = inputLemma;
- auto qm = toMultiHead(dot(input1, Wq), heads); // [B... x H x D/H] projected query
- auto kdm = toMultiHead(dot(input1 - input, Wk), heads); // [B... x H x D/H] the two data vectors projected as keys. Use diff and sigmoid, instead of softmax.
- auto vem = toMultiHead(dot(input1, Wv), heads); // [B... x H x D/H] one of the two data vectors projected as values
- auto vim = toMultiHead(dot( input, Wv), heads); // [B... x H x D/H] the other
- auto zm = bdot(qm, kdm, false, true); // [B... x H x 1]
- auto sm = sigmoid(zm); // [B... x H x 1]
- auto rm = sm * (vem - vim) + vim; // [B... x H x D/H]
- auto r = reshape(rm, input->shape()); // [B... x D]
- // add & norm
- input1 = r + input1;
- input1 = layerNorm(input1, name + "_att");
- // FF layer
- auto ffnDropProb = 0.1f; // @TODO: get as a parameter
- auto ffnDim = inputDim * 2; // @TODO: get as a parameter
- auto f = denseInline(input1, name + "_ffn", /*suffix=*/"1", ffnDim, inits::glorotUniform(), (ActivationFunction*)relu, ffnDropProb);
- f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim);
- // add & norm
- input1 = f + input1;
- input1 = layerNorm(input1, name + "_ffn");
- }
- // @TODO: b_ should be a vector, not a matrix; but shotlists use cols() in, which requires a matrix
- Expr factorLogits;
- if(g == 0)
- factorLogits = affineOrLSH(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
- else
- factorLogits = affineOrDot(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
-
- // optionally add lemma-dependent bias
- if (Plemma) { // [B... x U0]
- int lemmaVocabDim = Plemma->shape()[-1];
- int factorVocabDim = factorLogits->shape()[-1];
- auto name = options_->get<std::string>("prefix");
- Expr lemmaBt = graph_->param(name + "_lemmaBt_" + std::to_string(g), {factorVocabDim, lemmaVocabDim}, inits::zeros()); // [U x U0] U0=#lemmas one bias per class per lemma
- auto b = dot(Plemma, lemmaBt, false, true); // [B... x U]
- factorLogits = factorLogits + b;
- }
- allLogits[g] = New<RationalLoss>(factorLogits, nullptr);
- // optionally add a soft embedding of lemma back to create some lemma dependency
- // @TODO: if this works, move it into lazyConstruct
- if (lemmaDimEmb == -2 && g == 0) { // -2 means a gated transformer-like structure
- LOG_ONCE(info, "[embedding] using lemma conditioning with gate, soft-max version");
- // get expected lemma embedding vector
- auto factorLogSoftmax = logsoftmax(factorLogits); // [B... x U] note: with shortlist, this is not the full lemma set
- auto factorSoftmax = exp(factorLogSoftmax);
- inputLemma = dot(factorSoftmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
- }
- else if (lemmaDimEmb == -3 && g == 0) { // same as -2 except with hard max
- LOG_ONCE(info, "[embedding] using lemma conditioning with gate, hard-max version");
- // get max-lemma embedding vector
- auto maxVal = max(factorLogits, -1); // [B... x U] note: with shortlist, this is not the full lemma set
- auto factorHardmax = eq(factorLogits, maxVal);
- inputLemma = dot(factorHardmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
- }
- else if (lemmaDimEmb == -1 && g == 0) { // -1 means learn a lemma-dependent bias
- ABORT_IF(shortlist_, "Lemma-dependent bias with short list is not yet implemented");
- LOG_ONCE(info, "[embedding] using lemma-dependent bias");
- auto factorLogSoftmax = logsoftmax(factorLogits); // (we do that again later, CSE will kick in)
- auto z = /*stopGradient*/(factorLogSoftmax);
- Plemma = exp(z); // [B... x U]
- }
- else if (lemmaDimEmb > 0 && g == 0) { // > 0 means learn a re-embedding matrix
- LOG_ONCE(info, "[embedding] enabled re-embedding of lemma, at dim {}", lemmaDimEmb);
- // compute softmax. We compute logsoftmax() separately because this way, computation will be reused later via CSE
- auto factorLogSoftmax = logsoftmax(factorLogits);
- auto factorSoftmax = exp(factorLogSoftmax);
-#ifdef HARDMAX_HACK
- bool hardmax = (lemmaDimEmb & 1) != 0; // odd value triggers hardmax for now (for quick experimentation)
- if (hardmax) {
- lemmaDimEmb = lemmaDimEmb & 0xfffffffe;
- LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb);
- auto maxVal = max(factorSoftmax, -1);
- factorSoftmax = eq(factorSoftmax, maxVal);
- }
-#endif
- // re-embedding lookup, soft-indexed by softmax
- if (shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix
- cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices());
- auto e = dot(factorSoftmax, cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_, false, true); // [B... x L]
- // project it back to regular hidden dim
- int inputDim = input1->shape()[-1];
- auto name = options_->get<std::string>("prefix");
- // note: if the lemmaEt[:,w] have unit length (var = 1/L), then lemmaWt @ lemmaEt is also length 1
- Expr lemmaWt = inputDim == lemmaDimEmb ? nullptr : graph_->param(name + "_lemmaWt", { inputDim, lemmaDimEmb }, inits::glorotUniform()); // [D x L] D=hidden-vector dimension
- auto f = lemmaWt ? dot(e, lemmaWt, false, true) : e; // [B... x D]
- // augment the original hidden vector with this additional information
- input1 = input1 + f;
- }
- }
- return Logits(std::move(allLogits), factoredVocab_);
- } else if (shortlist_) {
- return Logits(affineOrLSH(input, cachedShortWt_, cachedShortb_, false, /*transB=*/isLegacyUntransposedW ? false : true));
- } else {
- return Logits(affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true));
- }
- }
- }
-
- Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
- : LayerBase(graph, options), inference_(opt<bool>("inference")) {
- std::string name = opt<std::string>("prefix");
- int dimVoc = opt<int>("dimVocab");
- int dimEmb = opt<int>("dimEmb");
-
- bool fixed = opt<bool>("fixed", false);
-
- factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
- if (factoredVocab_) {
- dimVoc = (int)factoredVocab_->factorVocabSize();
- LOG_ONCE(info, "[embedding] Factored embeddings enabled");
- }
-
- // Embedding layer initialization should depend only on embedding size, hence fanIn=false
- auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length
-
- if (options_->has("embFile")) {
- std::string file = opt<std::string>("embFile");
- if (!file.empty()) {
- bool norm = opt<bool>("normalization", false);
- initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm);
- }
- }
-
- E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed);
- }
-
- // helper to embed a sequence of words (given as indices) via factored embeddings
- Expr Embedding::multiRows(const Words& data, float dropProb) const {
- auto graph = E_->graph();
- auto factoredData = factoredVocab_->csr_rows(data);
- // multi-hot factor vectors are represented as a sparse CSR matrix
- // [row index = word position index] -> set of factor indices for word at this position
- ABORT_IF(factoredData.shape != Shape({(int)factoredData.offsets.size()-1/*=rows of CSR*/, E_->shape()[0]}), "shape mismatch??");
- // the CSR matrix is passed in pieces
- auto weights = graph->constant({ (int)factoredData.weights.size() }, inits::fromVector(factoredData.weights));
- auto indices = graph->constant({ (int)factoredData.indices.size() }, inits::fromVector(factoredData.indices), Type::uint32);
- auto offsets = graph->constant({ (int)factoredData.offsets.size() }, inits::fromVector(factoredData.offsets), Type::uint32);
- // apply dropout
- // We apply it to the weights, i.e. factors get dropped out separately, but always as entire vectors.
- if(!inference_)
- weights = dropout(weights, dropProb);
- // perform the product
- return csr_dot(factoredData.shape, weights, indices, offsets, E_);
- }
-
- std::tuple<Expr/*embeddings*/, Expr/*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const /*override final*/ {
- auto graph = E_->graph();
- int dimBatch = (int)subBatch->batchSize();
- int dimEmb = E_->shape()[-1];
- int dimWidth = (int)subBatch->batchWidth();
-
- // factored embeddings:
- // - regular:
- // - y = x @ E x:[B x 1ofV] ; E:[V x D] ; y:[B x D]
- // - factored:
- // - u = x @ M one-hot to U-dimensional multi-hot (all factors in one concatenated space)
- // - each row of M contains the set of factors for one word => we want a CSR matrix
- // - y = (x @ M) @ E (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D]
- // - first compute x @ M on the CPU
- // - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()):
- // - shape (U, specifically) not actually needed here
- // - foreach input x[i]
- // - locate row M[i,*]
- // - copy through its index values (std::vector<push_back>)
- // - create a matching ones vector (we can keep growing)
- // - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x)
- // - CSR matrix product with E
- // - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU)
- // - double-check if all dimensions are specified. Probably not for transpose (which would be like csc_dot()).
- // - weighting:
- // - core factors' gradients are sums over all words that use the factors;
- // - core factors' embeddings move very fast
- // - words will need to make up for the move; rare words cannot
- // - so, we multiply each factor with 1/refCount
- // - core factors get weighed down a lot
- // - no impact on gradients, as Adam makes up for it; embeddings still move fast just as before
- // - but forward pass weighs them down, so that all factors are in a similar numeric range
- // - if it is required to be in a different range, the embeddings can still learn that, but more slowly
-
- auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb});
-#if 1
- auto batchMask = graph->constant({dimWidth, dimBatch, 1},
- inits::fromVector(subBatch->mask()));
-#else // @TODO: this is dead code now, get rid of it
- // experimental: hide inline-fix source tokens from cross attention
- auto batchMask = graph->constant({dimWidth, dimBatch, 1},
- inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed()));
-#endif
- // give the graph inputs readable names for debugging and ONNX
- batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask");
-
- return std::make_tuple(batchEmbeddings, batchMask);
- }
-
- Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ {
- if (factoredVocab_) {
- Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E]
- selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
- //selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout
- return selectedEmbs;
- }
- else
- return applyIndices(toWordIndexVector(words), shape);
- }
-
- Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const /*override final*/ {
- ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary");
- auto embIdxExpr = E_->graph()->indices(embIdx);
- embIdxExpr->set_name("data_" + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index?
- auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E]
- selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
- // @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() (test that separately)
- if(!inference_)
- selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 });
- return selectedEmbs;
- }
-
- // standard encoder word embeddings
- /*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createEmbeddingLayer() const {
- auto options = New<Options>(
- "dimVocab", opt<std::vector<int>>("dim-vocabs")[batchIndex_],
- "dimEmb", opt<int>("dim-emb"),
- "dropout", dropoutEmbeddings_,
- "inference", inference_,
- "prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb" : prefix_ + "_Wemb",
- "fixed", embeddingFix_,
- "vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings
- if(options_->hasAndNotEmpty("embedding-vectors")) {
- auto embFiles = opt<std::vector<std::string>>("embedding-vectors");
- options->set(
- "embFile", embFiles[batchIndex_],
- "normalization", opt<bool>("embedding-normalization"));
- }
- return New<Embedding>(graph_, options);
- }
-
- // ULR word embeddings
- /*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createULREmbeddingLayer() const {
- return New<ULREmbedding>(graph_, New<Options>(
- "dimSrcVoc", opt<std::vector<int>>("dim-vocabs")[0], // ULR multi-lingual src
- "dimTgtVoc", opt<std::vector<int>>("dim-vocabs")[1], // ULR monon tgt
- "dimUlrEmb", opt<int>("ulr-dim-emb"),
- "dimEmb", opt<int>("dim-emb"),
- "ulr-dropout", opt<float>("ulr-dropout"),
- "dropout", dropoutEmbeddings_,
- "inference", inference_,
- "ulrTrainTransform", opt<bool>("ulr-trainable-transformation"),
- "ulrQueryFile", opt<std::string>("ulr-query-vectors"),
- "ulrKeysFile", opt<std::string>("ulr-keys-vectors")));
- }
-
- // get embedding layer for this encoder or decoder
- // This is lazy mostly because the constructors of the consuming objects are not
- // presently guaranteed to have access to their graph.
- Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::getEmbeddingLayer(bool ulr) const {
- if (embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy
- if (embeddingLayers_.size() <= batchIndex_)
- embeddingLayers_.resize(batchIndex_ + 1);
- if (ulr)
- embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR
- else
- embeddingLayers_[batchIndex_] = createEmbeddingLayer();
- }
- return embeddingLayers_[batchIndex_];
- }
-} // namespace marian
+namespace marian {} // namespace marian
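
The factored-embedding path removed above (and moved into src/layers/embedding.cpp by this commit) computes y = (x @ M) @ E, with the word-to-factor matrix M kept in CSR form. As a rough standalone illustration of the gather-and-sum that csr_rows()/csr_dot() perform, here is a minimal sketch in plain C++; it is not Marian API, and the function name and toy memory layout are invented for the example:

#include <cstddef>
#include <vector>

// Conceptual stand-in for csr_dot(shape, weights, indices, offsets, E):
// row r of the CSR matrix lists the factor ids of word r; the output row is the
// weighted sum of the corresponding rows of the factor embedding matrix E [U x D].
std::vector<float> csrFactorEmbed(const std::vector<float>& weights,     // one weight per stored entry
                                  const std::vector<unsigned>& indices,  // factor ids, same length as weights
                                  const std::vector<unsigned>& offsets,  // numWords + 1 row offsets
                                  const std::vector<float>& E,           // [U x D], row-major
                                  std::size_t D) {
  std::size_t numWords = offsets.size() - 1;
  std::vector<float> Y(numWords * D, 0.f);                               // [numWords x D]
  for(std::size_t r = 0; r < numWords; ++r)
    for(unsigned k = offsets[r]; k < offsets[r + 1]; ++k)
      for(std::size_t d = 0; d < D; ++d)
        Y[r * D + d] += weights[k] * E[indices[k] * D + d];
  return Y;
}

Dropout on the factors (as done in Embedding::multiRows above) amounts to zeroing some entries of weights[] before this product, so whole factor vectors are dropped rather than individual embedding dimensions.
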
diff --git a/src/layers/generic.h b/src/layers/generic.h
index f47bb45e..89f5c1e9 100755..100644
--- a/src/layers/generic.h
+++ b/src/layers/generic.h
@@ -5,12 +5,14 @@
#include "data/shortlist.h"
#include "layers/factory.h"
-namespace marian { namespace mlp {
- /**
- * @brief Activation functions
- */
- enum struct act : int { linear, tanh, sigmoid, ReLU, LeakyReLU, PReLU, swish };
-}}
+namespace marian {
+namespace mlp {
+/**
+ * @brief Activation functions
+ */
+enum struct act : int { linear, tanh, sigmoid, ReLU, LeakyReLU, PReLU, swish };
+} // namespace mlp
+} // namespace marian
namespace marian {
@@ -23,8 +25,7 @@ protected:
Ptr<Options> options_;
public:
- LayerBase(Ptr<ExpressionGraph> graph, Ptr<Options> options)
- : graph_(graph), options_(options) {}
+ LayerBase(Ptr<ExpressionGraph> graph, Ptr<Options> options) : graph_(graph), options_(options) {}
template <typename T>
T opt(const std::string key) const {
@@ -42,7 +43,7 @@ struct IUnaryLayer {
virtual ~IUnaryLayer() {}
virtual Expr apply(Expr) = 0;
virtual Expr apply(const std::vector<Expr>& es) {
- ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
+ ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
return apply(es.front());
}
};
@@ -54,7 +55,8 @@ struct IHasShortList {
// Embedding from corpus sub-batch to (emb, mask)
struct IEmbeddingLayer {
- virtual std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const = 0;
+ virtual std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
+ Ptr<data::SubBatch> subBatch) const = 0;
virtual Expr apply(const Words& embIdx, const Shape& shape) const = 0;
@@ -63,28 +65,29 @@ struct IEmbeddingLayer {
virtual ~IEmbeddingLayer() {}
};
-// base class for Encoder and Decoder classes, which have embeddings and a batch index (=stream index)
+// base class for Encoder and Decoder classes, which have embeddings and a batch index (=stream
+// index)
class EncoderDecoderLayerBase : public LayerBase {
protected:
const std::string prefix_;
const bool embeddingFix_;
- const float dropoutEmbeddings_; // this drops out full embedding vectors
+ const float dropoutEmbeddings_; // this drops out full embedding vectors
const bool inference_;
const size_t batchIndex_;
- mutable std::vector<Ptr<IEmbeddingLayer>> embeddingLayers_; // (lazily created)
+ mutable std::vector<Ptr<IEmbeddingLayer>> embeddingLayers_; // (lazily created)
- EncoderDecoderLayerBase(Ptr<ExpressionGraph> graph,
- Ptr<Options> options,
- const std::string& prefix,
+ EncoderDecoderLayerBase(Ptr<ExpressionGraph> graph,
+ Ptr<Options> options,
+ const std::string& prefix,
size_t batchIndex,
float dropoutEmbeddings,
- bool embeddingFix) :
- LayerBase(graph, options),
- prefix_(options->get<std::string>("prefix", prefix)),
- embeddingFix_(embeddingFix),
- dropoutEmbeddings_(dropoutEmbeddings),
- inference_(options->get<bool>("inference", false)),
- batchIndex_(options->get<size_t>("index", batchIndex)) {}
+ bool embeddingFix)
+ : LayerBase(graph, options),
+ prefix_(options->get<std::string>("prefix", prefix)),
+ embeddingFix_(embeddingFix),
+ dropoutEmbeddings_(dropoutEmbeddings),
+ inference_(options->get<bool>("inference", false)),
+ batchIndex_(options->get<size_t>("index", batchIndex)) {}
virtual ~EncoderDecoderLayerBase() {}
@@ -97,78 +100,11 @@ public:
Ptr<IEmbeddingLayer> getEmbeddingLayer(bool ulr = false) const;
};
-class FactoredVocab;
-
-// To support factors, any output projection (that is followed by a softmax) must
-// retain multiple outputs, one for each factor. Such layer returns not a single Expr,
-// but a Logits object that contains multiple.
-// This allows to compute softmax values in a factored manner, where we never create
-// a fully expanded list of all factor combinations.
-class RationalLoss;
-class Logits {
-public:
- Logits() {}
- explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor
- logits_.push_back(logits);
- }
- explicit Logits(Expr logits); // single-output constructor from Expr only (RationalLoss has no count)
- Logits(std::vector<Ptr<RationalLoss>>&& logits, Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor
- : logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {}
- Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors
- Expr getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist = nullptr, const std::vector<IndexType>& hypIndices = {}, size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle
- //Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that
- Expr applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/,Expr/*indices*/)>& lossFn) const;
- Logits applyUnaryFunction(const std::function<Expr(Expr)>& f) const; // clone this but apply f to all loss values
- Logits applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const; // clone this but apply f1 to first and fother to all other values
-
- struct MaskedFactorIndices {
- std::vector<WordIndex> indices; // factor index, or 0 if masked
- std::vector<float> masks;
- void reserve(size_t n) { indices.reserve(n); masks.reserve(n); }
- void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0 for invalid entries
- MaskedFactorIndices() {}
- MaskedFactorIndices(const Words& words) { indices = toWordIndexVector(words); } // we can leave masks uninitialized for this special use case
- };
- std::vector<MaskedFactorIndices> factorizeWords(const Words& words) const; // breaks encoded Word into individual factor indices
- Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only
- size_t getNumFactorGroups() const { return logits_.size(); }
- bool empty() const { return logits_.empty(); }
- Logits withCounts(const Expr& count) const; // create new Logits with 'count' implanted into all logits_
-private:
- // helper functions
- Ptr<ExpressionGraph> graph() const;
- Expr constant(const Shape& shape, const std::vector<float>& data) const { return graph()->constant(shape, inits::fromVector(data)); }
- Expr constant(const Shape& shape, const std::vector<uint32_t>& data) const { return graph()->constant(shape, inits::fromVector(data)); }
- template<typename T> Expr constant(const std::vector<T>& data) const { return constant(Shape{(int)data.size()}, data); } // same as constant() but assuming vector
- Expr indices(const std::vector<uint32_t>& data) const { return graph()->indices(data); } // actually the same as constant(data) for this data type
- std::vector<float> getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const;
-private:
- // members
- // @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just by the Expr
- std::vector<Ptr<RationalLoss>> logits_; // [group id][B..., num factors in group]
- Ptr<FactoredVocab> factoredVocab_;
-};
-
-// Unary function that returns a Logits object
-// Also implements IUnaryLayer, since Logits can be cast to Expr.
-// This interface is implemented by all layers that are of the form of a unary function
-// that returns multiple logits, to support factors.
-struct IUnaryLogitLayer : public IUnaryLayer {
- virtual Logits applyAsLogits(Expr) = 0;
- virtual Logits applyAsLogits(const std::vector<Expr>& es) {
- ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
- return applyAsLogits(es.front());
- }
- virtual Expr apply(Expr e) override { return applyAsLogits(e).getLogits(); }
- virtual Expr apply(const std::vector<Expr>& es) override { return applyAsLogits(es).getLogits(); }
-};
-
namespace mlp {
class Dense : public LayerBase, public IUnaryLayer {
public:
- Dense(Ptr<ExpressionGraph> graph, Ptr<Options> options)
- : LayerBase(graph, options) {}
+ Dense(Ptr<ExpressionGraph> graph, Ptr<Options> options) : LayerBase(graph, options) {}
Expr apply(const std::vector<Expr>& inputs) override {
ABORT_IF(inputs.empty(), "No inputs");
@@ -190,21 +126,17 @@ public:
if(inputs.size() > 1)
num = std::to_string(i);
- Expr W = g->param(
- name + "_W" + num, {in->shape()[-1], dim}, inits::glorotUniform());
+ Expr W = g->param(name + "_W" + num, {in->shape()[-1], dim}, inits::glorotUniform());
Expr b = g->param(name + "_b" + num, {1, dim}, inits::zeros());
if(useLayerNorm) {
if(useNematusNorm) {
- auto ln_s = g->param(
- name + "_ln_s" + num, {1, dim}, inits::fromValue(1.f));
+ auto ln_s = g->param(name + "_ln_s" + num, {1, dim}, inits::fromValue(1.f));
auto ln_b = g->param(name + "_ln_b" + num, {1, dim}, inits::zeros());
- outputs.push_back(
- layerNorm(affine(in, W, b), ln_s, ln_b, NEMATUS_LN_EPS));
+ outputs.push_back(layerNorm(affine(in, W, b), ln_s, ln_b, NEMATUS_LN_EPS));
} else {
- auto gamma = g->param(
- name + "_gamma" + num, {1, dim}, inits::fromValue(1.0));
+ auto gamma = g->param(name + "_gamma" + num, {1, dim}, inits::fromValue(1.0));
outputs.push_back(layerNorm(dot(in, W), gamma, b));
}
@@ -231,241 +163,35 @@ public:
Expr apply(Expr input) override { return apply(std::vector<Expr>({input})); }
};
-} // namespace mlp
-
-class LSH;
-
-namespace mlp {
-
-class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList {
-private:
- // parameters held by this layer
- Expr Wt_; // weight matrix is stored transposed for efficiency
- Expr b_;
- Expr lemmaEt_; // re-embedding matrix for lemmas [lemmaDimEmb x lemmaVocabSize]
- bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form
- bool hasBias_{true};
-
- Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
- Expr cachedShortb_; // these match the current value of shortlist_
- Expr cachedShortLemmaEt_;
- Ptr<FactoredVocab> factoredVocab_;
-
- // optional parameters set/updated after construction
- Expr tiedParam_;
- Ptr<data::Shortlist> shortlist_;
- Ptr<LSH> lsh_;
-
- void lazyConstruct(int inputDim);
-public:
- Output(Ptr<ExpressionGraph> graph, Ptr<Options> options)
- : LayerBase(graph, options),
- hasBias_{!options->get<bool>("output-omit-bias", false)} {
- clear();
- }
-
- void tieTransposed(Expr tied) {
- if (Wt_)
- ABORT_IF(tiedParam_.get() != tied.get(), "Tied output projection cannot be changed once weights have been created");
- else
- tiedParam_ = tied;
- }
-
- void setShortlist(Ptr<data::Shortlist> shortlist) override final {
- if (shortlist_)
- ABORT_IF(shortlist.get() != shortlist_.get(), "Output shortlist cannot be changed except after clear()");
- else {
- ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_, "No shortlist but cached parameters??");
- shortlist_ = shortlist;
- }
- // cachedShortWt_ and cachedShortb_ will be created lazily inside apply()
- }
-
- // this is expected to be called in sync with graph->clear(), which invalidates
- // cachedShortWt_ etc. in the graph's short-term cache
- void clear() override final {
- shortlist_ = nullptr;
- cachedShortWt_ = nullptr;
- cachedShortb_ = nullptr;
- cachedShortLemmaEt_ = nullptr;
- }
-
- Logits applyAsLogits(Expr input) override final;
-};
-
} // namespace mlp
-// A regular embedding layer.
-// Note that this also applies dropout if the option is passed (pass 0 when in inference mode).
-// It is best to not use Embedding directly, but rather via getEmbeddingLayer() in
-// EncoderDecoderLayerBase, which knows to pass on all required parameters from options.
-class Embedding : public LayerBase, public IEmbeddingLayer {
- Expr E_;
- Ptr<FactoredVocab> factoredVocab_;
- Expr multiRows(const Words& data, float dropProb) const;
- bool inference_{false};
-
-public:
- Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options);
-
- std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final;
-
- Expr apply(const Words& words, const Shape& shape) const override final;
-
- Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final;
-};
-
-class ULREmbedding : public LayerBase, public IEmbeddingLayer {
- std::vector<Expr> ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members
- bool inference_{false};
-
-public:
- ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
- : LayerBase(graph, options), inference_(opt<bool>("inference")) {
- std::string name = "url_embed"; //opt<std::string>("prefix");
- int dimKeys = opt<int>("dimTgtVoc");
- int dimQueries = opt<int>("dimSrcVoc");
- int dimEmb = opt<int>("dimEmb");
- int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size
- bool fixed = opt<bool>("fixed", false);
-
- // Embedding layer initialization should depend only on embedding size, hence fanIn=false
- auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true);
-
- std::string queryFile = opt<std::string>("ulrQueryFile");
- std::string keyFile = opt<std::string>("ulrKeysFile");
- bool trainTrans = opt<bool>("ulrTrainTransform", false);
- if (!queryFile.empty() && !keyFile.empty()) {
- initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false);
- name = "ulr_query";
- fixed = true;
- auto query_embed = graph_->param(name, { dimQueries, dimUlrEmb }, initFunc, fixed);
- ulrEmbeddings_.push_back(query_embed);
- // keys embeds
- initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false);
- name = "ulr_keys";
- fixed = true;
- auto key_embed = graph_->param(name, { dimKeys, dimUlrEmb }, initFunc, fixed);
- ulrEmbeddings_.push_back(key_embed);
- // actual trainable embedding
- initFunc = inits::glorotUniform();
- name = "ulr_embed";
- fixed = false;
- auto ulr_embed = graph_->param(name, {dimKeys , dimEmb }, initFunc, fixed); // note the reverse dim
- ulrEmbeddings_.push_back(ulr_embed);
- // init trainable src embedding
- name = "ulr_src_embed";
- auto ulr_src_embed = graph_->param(name, { dimQueries, dimEmb }, initFunc, fixed);
- ulrEmbeddings_.push_back(ulr_src_embed);
- // ulr transformation matrix
- //initFunc = inits::eye(1.f); // identity matrix - is it ok to init with identity, or shall we restrict this to the fixed case only?
- if (trainTrans) {
- initFunc = inits::glorotUniform();
- fixed = false;
- }
- else
- {
- initFunc = inits::eye(); // identity matrix
- fixed = true;
- }
- name = "ulr_transform";
- auto ulrTransform = graph_->param(name, { dimUlrEmb, dimUlrEmb }, initFunc, fixed);
- ulrEmbeddings_.push_back(ulrTransform);
-
- initFunc = inits::fromValue(1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no universal embeddings - should be zero for top freq only
- fixed = true;
- name = "ulr_shared";
- auto share_embed = graph_->param(name, { dimQueries, 1 }, initFunc, fixed);
- ulrEmbeddings_.push_back(share_embed);
- }
- }
-
- std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final {
- auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb
- auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb
- auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb
- auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb
- auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb
- auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1
- int dimBatch = (int)subBatch->batchSize();
- int dimEmb = uniEmbed->shape()[-1];
- int dimWords = (int)subBatch->batchWidth();
- // D = K.A.QT
- // dim(K) = univ_tok_vocab*uni_embed_size
- // dim A = uni_embed_size*uni_embed_size
- // dim Q: uni_embed_size * total_merged_vocab_size
- // dim D = univ_tok_vocab * total_merged_vocab_size
- // note: all of the above can be precomputed and serialized if A is not trainable and during decoding (TBD)
- // here we need to handle the mini-batch
- // extract rows corresponding to Xs in this minibatch from Q
- auto embIdx = toWordIndexVector(subBatch->data());
- auto queryEmbeddings = rows(queryEmbed, embIdx);
- auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings
- auto alpha = rows(ulrSharable, embIdx); // extract sharable flags
- auto qt = dot(queryEmbeddings, ulrTransform, false, false); //A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb
- auto sqrtDim=std::sqrt((float)queryEmbeddings->shape()[-1]);
- qt = qt/sqrtDim; // normalize according to embedding size to avoid the dot product growing large in magnitude with larger embedding sizes
- auto z = dot(qt, keyEmbed, false, true); // query-key similarity
- float dropProb = this->options_->get<float>("ulr-dropout", 0.0f); // default no dropout
- if(!inference_)
- z = dropout(z, dropProb);
-
- float tau = this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default no temperature
- // temperature in softmax controls the randomness of predictions:
- // high temperatures make softmax outputs closer to each other,
- // low temperatures make the softmax more similar to "hardmax"
- auto weights = softmax(z / tau); // assume default is dim=-1, what about temperature? - scalar ??
- auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE
- auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast
- auto batchEmbeddings = reshape(chosenEmbeddings_mix, { dimWords, dimBatch, dimEmb });
- auto graph = ulrEmbeddings_.front()->graph();
- auto batchMask = graph->constant({ dimWords, dimBatch, 1 },
- inits::fromVector(subBatch->mask()));
- if(!inference_)
- batchEmbeddings = dropout(batchEmbeddings, options_->get<float>("dropout-embeddings", 0.0f), {batchEmbeddings->shape()[-3], 1, 1});
- return std::make_tuple(batchEmbeddings, batchMask);
- }
-
- Expr apply(const Words& words, const Shape& shape) const override final {
- return applyIndices(toWordIndexVector(words), shape);
- }
-
- Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final {
- embIdx; shape;
- ABORT("not implemented"); // @TODO: implement me
- }
-};
-
// --- a few layers with built-in parameters created on the fly, without proper object
// @TODO: change to a proper layer object
// like affine() but with built-in parameters, activation, and dropout
-static inline
-Expr denseInline(Expr x,
- std::string prefix,
- std::string suffix,
- int outDim,
- Ptr<inits::NodeInitializer> initFn = inits::glorotUniform(),
- const std::function<Expr(Expr)>& actFn = nullptr,
- float dropProb = 0.0f)
-{
+static inline Expr denseInline(Expr x,
+ std::string prefix,
+ std::string suffix,
+ int outDim,
+ Ptr<inits::NodeInitializer> initFn = inits::glorotUniform(),
+ const std::function<Expr(Expr)>& actFn = nullptr,
+ float dropProb = 0.0f) {
auto graph = x->graph();
- auto W = graph->param(prefix + "_W" + suffix, { x->shape()[-1], outDim }, inits::glorotUniform());
- auto b = graph->param(prefix + "_b" + suffix, { 1, outDim }, inits::zeros());
+ auto W = graph->param(prefix + "_W" + suffix, {x->shape()[-1], outDim}, inits::glorotUniform());
+ auto b = graph->param(prefix + "_b" + suffix, {1, outDim}, inits::zeros());
x = affine(x, W, b);
- if (actFn)
+ if(actFn)
x = actFn(x);
- x = dropout(x, dropProb); // @TODO: check for inference?
+ x = dropout(x, dropProb);  // @TODO: check for inference?
return x;
}
-static inline
-Expr layerNorm(Expr x, std::string prefix, std::string suffix = std::string()) {
+static inline Expr layerNorm(Expr x, std::string prefix, std::string suffix = std::string()) {
int dimModel = x->shape()[-1];
- auto scale = x->graph()->param(prefix + "_ln_scale" + suffix, { 1, dimModel }, inits::ones());
- auto bias = x->graph()->param(prefix + "_ln_bias" + suffix, { 1, dimModel }, inits::zeros());
+ auto scale = x->graph()->param(prefix + "_ln_scale" + suffix, {1, dimModel}, inits::ones());
+ auto bias = x->graph()->param(prefix + "_ln_bias" + suffix, {1, dimModel}, inits::zeros());
return marian::layerNorm(x, scale, bias, 1e-6f);
}
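
For orientation, here is a hedged usage sketch of the two helpers above, composing the same kind of residual FFN sub-block that the (now moved) lemma-gate code builds with them; the wrapper name and the prefix handling are illustrative only:

// Illustrative only: an FFN block with residual connection and layer norm,
// mirroring the "_ffn" pattern used by the factored-output code.
static inline Expr ffnBlock(Expr x, const std::string& prefix) {
  int dimModel = x->shape()[-1];
  auto h = denseInline(x, prefix + "_ffn", /*suffix=*/"1", dimModel * 2,
                       inits::glorotUniform(), (ActivationFunction*)relu, /*dropProb=*/0.1f);
  h = denseInline(h, prefix + "_ffn", /*suffix=*/"2", dimModel);
  return layerNorm(x + h, prefix + "_ffn");  // add & norm
}
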
diff --git a/src/layers/guided_alignment.h b/src/layers/guided_alignment.h
index f08d3f09..f08d3f09 100755..100644
--- a/src/layers/guided_alignment.h
+++ b/src/layers/guided_alignment.h
diff --git a/src/layers/logits.cpp b/src/layers/logits.cpp
new file mode 100644
index 00000000..8c4d69bd
--- /dev/null
+++ b/src/layers/logits.cpp
@@ -0,0 +1,245 @@
+#include "logits.h"
+#include "data/factored_vocab.h"
+#include "loss.h"
+#include "rnn/types.h" // for State::select()
+
+namespace marian {
+Logits::Logits(Expr logits)
+ : Logits(New<RationalLoss>(logits, nullptr)) {
+} // single-output constructor from Expr only (RationalLoss has no count)
+
+Ptr<ExpressionGraph> Logits::graph() const {
+ ABORT_IF(logits_.empty(), "Empty logits object??");
+ return logits_.front()->loss()->graph();
+}
+
+// This function assumes that the object holds one or more factor logits.
+// It applies the supplied loss function to each, and then returns the aggregate loss over all
+// factors.
+Expr Logits::applyLossFunction(
+ const Words& labels,
+ const std::function<Expr(Expr /*logits*/, Expr /*indices*/)>& lossFn) const {
+ LOG_ONCE(info, "[logits] Applying loss function for {} factor(s)", logits_.size());
+ ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
+
+ auto firstLogits = logits_.front()->loss();
+ ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(),
+ "Labels not matching logits shape ({} != {}, {})??",
+ labels.size() * firstLogits->shape()[-1],
+ firstLogits->shape().elements(),
+ firstLogits->shape());
+
+ // base case (no factors)
+ if(!factoredVocab_) {
+ ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
+ return lossFn(firstLogits, indices(toWordIndexVector(labels)));
+ }
+
+ auto numGroups = factoredVocab_->getNumGroups();
+
+ // split labels into individual factor labels
+ auto allMaskedFactoredLabels
+ = factorizeWords(labels); // [numGroups][labels.size()] = [numGroups][B... flattened]
+
+ // Expr indices = this->indices(toWordIndexVector(labels));
+ // accumulate all CEs for all words that have the factor
+ // Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors.
+ Expr loss;
+ for(size_t g = 0; g < numGroups; g++) {
+ if(!logits_[g])
+ continue; // empty factor --@TODO: use an array of indices of non-empty logits_[]
+ // clang-format off
+ const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask)
+ auto factorIndices = indices(maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply
+ auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor
+ auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet)
+ // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next.
+ auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1]
+ // clang-format on
+ if(loss)
+ factorLoss = cast(factorLoss, loss->value_type());
+ factorLoss
+ = factorLoss
+ * cast(
+ reshape(factorMask, factorLoss->shape()),
+ factorLoss->value_type()); // mask out factor for words that do not have that factor
+ loss = loss ? (loss + factorLoss) : factorLoss; // [B... x 1]
+ }
+ return loss;
+}
+
+// This function assumes this object holds a single factor that represents a rational loss (with
+// count).
+// Ptr<RationalLoss> Logits::getRationalLoss() const {
+// ABORT_IF(logits_.size() != 1 || factoredVocab_, "getRationalLoss() cannot be used on
+// multi-factor outputs"); ABORT_IF(!logits_.front()->count(), "getRationalLoss() used on rational
+// loss without count"); return logits_.front();
+//}
+
+// get logits for one factor group
+// For groupIndex == 0, the function also requires the shortlist if there is one.
+Expr Logits::getFactoredLogits(size_t groupIndex,
+ Ptr<data::Shortlist> shortlist /*= nullptr*/,
+ const std::vector<IndexType>& hypIndices /*= {}*/,
+ size_t beamSize /*= 0*/) const {
+ ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
+
+ auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab]
+
+ // normalize for decoding:
+ // - all secondary factors: subtract their max
+ // - lemma: add all maxes of applicable factors
+ if(groupIndex > 0) {
+ sel = sel - max(sel, -1);
+ } else {
+ auto numGroups = getNumFactorGroups();
+ for(size_t g = 1; g < numGroups; g++) {
+ auto factorMaxima = max(logits_[g]->loss(),
+ -1); // we cast since loss is likely ce-loss which has type float32
+ auto factorMasks = constant(
+ getFactorMasks(g, shortlist ? shortlist->indices() : std::vector<WordIndex>()));
+ sel = sel
+ + cast(factorMaxima, sel->value_type())
+ * cast(factorMasks, sel->value_type()); // those lemmas that don't have a factor
+ // get multiplied with 0
+ }
+ }
+
+ // if selIdx are given, then we must reshuffle accordingly
+ if(!hypIndices.empty()) // use the same function that shuffles decoder state
+ sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false);
+
+ return sel;
+}
+
+// used for breakDown() only
+// Index is flattened
+Tensor Logits::getFactoredLogitsTensor(size_t groupIndex) const {
+ ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
+ return logits_[groupIndex]->loss()->val();
+}
+
+// This function assumes that the object holds one or more factor logits, which are summed up
+// into output-vocab logits according to the factored model (with correct normalization of factors).
+// This is infeasible for realistic factor sets, and therefore only implemented for 1 factor.
+// @TODO: remove altogether
+Expr Logits::getLogits() const {
+ ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
+ if(!factoredVocab_) {
+ ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
+ return getFactoredLogits(0);
+ }
+
+#ifdef FACTOR_FULL_EXPANSION
+ // compute normalized factor log probs
+ std::vector<Expr> logProbs(logits_.size());
+ for(size_t g = 0; g < logits_.size(); g++)
+ logProbs[g] = logsoftmax(logits_[g]->loss());
+ auto y = concatenate(logProbs, /*axis=*/-1);
+
+ // clang-format off
+ // sum up the unit logits across factors for each target word
+ auto graph = y->graph();
+ auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U]
+ y = dot_csr(
+ y, // [B x U]
+ factorMatrix.shape,
+ graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)),
+ graph->constant({(int)factorMatrix.indices.size()}, inits::fromVector(factorMatrix.indices), Type::uint32),
+ graph->constant({(int)factorMatrix.offsets.size()}, inits::fromVector(factorMatrix.offsets), Type::uint32),
+ /*transB=*/true); // -> [B x V]
+ // clang-format on
+
+ // mask out gaps
+ auto gapLogMask = factoredVocab_->getGapLogMask(); // [V]
+ y = y + graph->constant({(int)gapLogMask.size()}, inits::fromVector(gapLogMask));
+
+ return y;
+#else
+ ABORT("getLogits() no longer supported for actual factored vocab"); // because it is infeasible
+#endif
+}
+
+void Logits::MaskedFactorIndices::push_back(size_t factorIndex) {
+ bool isValid = FactoredVocab::isFactorValid(factorIndex);
+ indices.push_back(isValid ? (WordIndex)factorIndex : 0);
+ masks.push_back((float)isValid);
+}
+
+std::vector<Logits::MaskedFactorIndices> Logits::factorizeWords(const Words& words)
+ const { // [numGroups][words.size()] -> breaks encoded Word into individual factor indices
+ if(!factoredVocab_) {
+ ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
+ return {MaskedFactorIndices(words)};
+ }
+ auto numGroups = factoredVocab_->getNumGroups();
+ std::vector<MaskedFactorIndices> res(numGroups);
+ for(size_t g = 0; g < numGroups; g++) {
+ auto& resg = res[g];
+ resg.reserve(words.size());
+ for(const auto& word : words)
+ resg.push_back(factoredVocab_->getFactor(word, g));
+ }
+ return res;
+}
+
+//// use first factor of each word to determine whether it has a specific factor
+// std::vector<float> Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0
+// for words that do have this factor; else 0
+// std::vector<float> res;
+// res.reserve(words.size());
+// for (const auto& word : words) {
+// auto lemma = factoredVocab_->getFactor(word, 0);
+// res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
+// }
+// return res;
+//}
+
+// return a vector of 1 or 0 indicating for each lemma whether it has a specific factor
+// If 'indices' is given, then return the masks for the indices; otherwise for all lemmas
+std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices)
+ const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
+ size_t n
+ = indices.empty()
+ ? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first)
+ : indices.size();
+ std::vector<float> res;
+ res.reserve(n);
+ // @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this
+ // into FactoredVocab
+ for(size_t i = 0; i < n; i++) {
+ auto lemma = indices.empty() ? i : (indices[i] - factoredVocab_->getGroupRange(0).first);
+ res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
+ }
+ return res;
+}
+
+Logits Logits::applyUnaryFunction(
+ const std::function<Expr(Expr)>& f) const { // clone this but apply f to all loss values
+ std::vector<Ptr<RationalLoss>> newLogits;
+ for(const auto& l : logits_)
+ newLogits.emplace_back(New<RationalLoss>(f(l->loss()), l->count()));
+ return Logits(std::move(newLogits), factoredVocab_);
+}
+
+Logits Logits::applyUnaryFunctions(const std::function<Expr(Expr)>& f1,
+ const std::function<Expr(Expr)>& fother) const {
+ std::vector<Ptr<RationalLoss>> newLogits;
+ bool first = true;
+ for(const auto& l : logits_) {
+ newLogits.emplace_back(New<RationalLoss>((first ? f1 : fother)(l->loss()),
+ l->count())); // f1 for first, fother for all others
+ first = false;
+ }
+ return Logits(std::move(newLogits), factoredVocab_);
+}
+
+// @TODO: code dup with above; we can merge it into applyToRationalLoss()
+Logits Logits::withCounts(
+ const Expr& count) const { // create new Logits with 'count' implanted into all logits_
+ std::vector<Ptr<RationalLoss>> newLogits;
+ for(const auto& l : logits_)
+ newLogits.emplace_back(New<RationalLoss>(l->loss(), count));
+ return Logits(std::move(newLogits), factoredVocab_);
+}
+} // namespace marian
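
For context, the main consumer of applyLossFunction() is CrossEntropyLoss::compute() in src/layers/loss.h (see the hunk further below). A reduced sketch of that call pattern, with label smoothing and factor weighting left out, looks roughly like this; the free-function wrapper is invented for illustration:

// Reduced sketch only; the real logic lives in CrossEntropyLoss::compute().
Expr ceOverFactors(const Logits& logits, const Words& labels) {
  return logits.applyLossFunction(labels, [&](Expr groupLogits, Expr indices) {
    groupLogits = atleast_3d(groupLogits);  // guard against 2-d classifier output
    return cross_entropy(groupLogits, indices, /*labelSmoothing=*/0.f, Type::float32);
  });
}

Per factor group, the lambda is evaluated on that group's logits, the result is multiplied by 0 for labels that do not carry the factor, and the masked per-label losses are summed over groups.
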
diff --git a/src/layers/logits.h b/src/layers/logits.h
new file mode 100644
index 00000000..c61a9e74
--- /dev/null
+++ b/src/layers/logits.h
@@ -0,0 +1,106 @@
+#pragma once
+
+#include "data/shortlist.h"
+#include "generic.h"
+#include "marian.h"
+
+namespace marian {
+
+class FactoredVocab;
+
+// To support factors, any output projection (that is followed by a softmax) must
+// retain multiple outputs, one for each factor. Such layer returns not a single Expr,
+// but a Logits object that contains multiple.
+// This allows to compute softmax values in a factored manner, where we never create
+// a fully expanded list of all factor combinations.
+class RationalLoss;
+class Logits {
+public:
+ Logits() {}
+ explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor
+ logits_.push_back(logits);
+ }
+ explicit Logits(
+ Expr logits); // single-output constructor from Expr only (RationalLoss has no count)
+ Logits(std::vector<Ptr<RationalLoss>>&& logits,
+ Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor
+ : logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {}
+ Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors
+ Expr getFactoredLogits(
+ size_t groupIndex,
+ Ptr<data::Shortlist> shortlist = nullptr,
+ const std::vector<IndexType>& hypIndices = {},
+ size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle
+ // Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that
+ Expr applyLossFunction(
+ const Words& labels,
+ const std::function<Expr(Expr /*logits*/, Expr /*indices*/)>& lossFn) const;
+ Logits applyUnaryFunction(
+ const std::function<Expr(Expr)>& f) const; // clone this but apply f to all loss values
+ Logits applyUnaryFunctions(const std::function<Expr(Expr)>& f1,
+ const std::function<Expr(Expr)>& fother)
+ const; // clone this but apply f1 to first and fother to all other values
+
+ struct MaskedFactorIndices {
+ std::vector<WordIndex> indices; // factor index, or 0 if masked
+ std::vector<float> masks;
+ void reserve(size_t n) {
+ indices.reserve(n);
+ masks.reserve(n);
+ }
+ void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0
+ // for invalid entries
+ MaskedFactorIndices() {}
+ MaskedFactorIndices(const Words& words) {
+ indices = toWordIndexVector(words);
+ } // we can leave masks uninitialized for this special use case
+ };
+ std::vector<MaskedFactorIndices> factorizeWords(
+ const Words& words) const; // breaks encoded Word into individual factor indices
+ Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only
+ size_t getNumFactorGroups() const { return logits_.size(); }
+ bool empty() const { return logits_.empty(); }
+ Logits withCounts(
+ const Expr& count) const; // create new Logits with 'count' implanted into all logits_
+private:
+ // helper functions
+ Ptr<ExpressionGraph> graph() const;
+ Expr constant(const Shape& shape, const std::vector<float>& data) const {
+ return graph()->constant(shape, inits::fromVector(data));
+ }
+ Expr constant(const Shape& shape, const std::vector<uint32_t>& data) const {
+ return graph()->constant(shape, inits::fromVector(data));
+ }
+ template <typename T>
+ Expr constant(const std::vector<T>& data) const {
+ return constant(Shape{(int)data.size()}, data);
+ } // same as constant() but assuming vector
+ Expr indices(const std::vector<uint32_t>& data) const {
+ return graph()->indices(data);
+ } // actually the same as constant(data) for this data type
+ std::vector<float> getFactorMasks(size_t factorGroup,
+ const std::vector<WordIndex>& indices) const;
+
+private:
+ // members
+ // @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just
+ // by the Expr
+ std::vector<Ptr<RationalLoss>> logits_; // [group id][B..., num factors in group]
+ Ptr<FactoredVocab> factoredVocab_;
+};
+
+// Unary function that returns a Logits object
+// Also implements IUnaryLayer, since Logits can be cast to Expr.
+// This interface is implemented by all layers that are of the form of a unary function
+// that returns multiple logits, to support factors.
+struct IUnaryLogitLayer : public IUnaryLayer {
+ virtual Logits applyAsLogits(Expr) = 0;
+ virtual Logits applyAsLogits(const std::vector<Expr>& es) {
+ ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
+ return applyAsLogits(es.front());
+ }
+ virtual Expr apply(Expr e) override { return applyAsLogits(e).getLogits(); }
+ virtual Expr apply(const std::vector<Expr>& es) override { return applyAsLogits(es).getLogits(); }
+};
+
+} // namespace marian
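
To illustrate the non-factored end of this interface, a hypothetical minimal layer can satisfy IUnaryLogitLayer by wrapping a single affine projection in a one-element Logits; the class name, option keys, and lazy parameter creation below are invented for the example and are not part of Marian:

// Hypothetical example: the smallest possible IUnaryLogitLayer.
class PlainLogitLayer : public LayerBase, public IUnaryLogitLayer {
  Expr W_, b_;  // created lazily, once the input dimension is known

public:
  PlainLogitLayer(Ptr<ExpressionGraph> graph, Ptr<Options> options) : LayerBase(graph, options) {}

  Logits applyAsLogits(Expr input) override {
    int dimOut = opt<int>("dim");             // assumed option key
    auto name  = opt<std::string>("prefix");  // assumed option key
    if(!W_) {
      W_ = graph_->param(name + "_W", {input->shape()[-1], dimOut}, inits::glorotUniform());
      b_ = graph_->param(name + "_b", {1, dimOut}, inits::zeros());
    }
    return Logits(affine(input, W_, b_));     // single, non-factored output group
  }
};

Factored layers like mlp::Output instead fill the Logits object with one RationalLoss per factor group, which is what the member functions in logits.cpp above iterate over.
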
diff --git a/src/layers/loss.cpp b/src/layers/loss.cpp
index 67d38832..695276af 100755..100644
--- a/src/layers/loss.cpp
+++ b/src/layers/loss.cpp
@@ -13,26 +13,30 @@ Ptr<LabelwiseLoss> newLoss(Ptr<Options> options, bool inference) {
bool wordScores = options->get<bool>("word-scores", false);
return New<RescorerLoss>(wordScores);
} else if(unlikelihood) {
- ABORT_IF(!options->hasAndNotEmpty("data-weighting")
- && options->get<std::string>("data-weighting-type") != "word",
- "Unlikelihood loss training requires error annotation in form of per-target-label scores");
- return New<SequenceUnlikelihoodLoss>(smoothing, factorWeight); // this is a mix of CE-loss and unlikelihood loss depending on values given for data-weighting
- } else { // same as ce-mean --@TODO: better check all allowed values, and fail for invalid ones. E.g. what about ce-sum?
+ ABORT_IF(
+ !options->hasAndNotEmpty("data-weighting")
+ && options->get<std::string>("data-weighting-type") != "word",
+ "Unlikelihood loss training requires error annotation in form of per-target-label scores");
+ return New<SequenceUnlikelihoodLoss>(
+ smoothing, factorWeight); // this is a mix of CE-loss and unlikelihood loss depending on
+ // values given for data-weighting
+ } else { // same as ce-mean --@TODO: better check all allowed values, and fail for invalid ones.
+ // E.g. what about ce-sum?
return New<CrossEntropyLoss>(smoothing, factorWeight);
}
}
// see loss.h for detailed explanations of each class
Ptr<MultiRationalLoss> newMultiLoss(Ptr<Options> options) {
- std::string multiLossType = options->get<std::string>("multi-loss-type", "sum");
- if(multiLossType == "sum") // sum of sums
- return New<SumMultiRationalLoss>();
- else if(multiLossType == "scaled") // sum of scaled sums, first element is reference scale
- return New<ScaledMultiRationalLoss>();
- else if(multiLossType == "mean") // sum of means
- return New<MeanMultiRationalLoss>();
- else
- ABORT("Unknown multi-loss-type {}", multiLossType);
+ std::string multiLossType = options->get<std::string>("multi-loss-type", "sum");
+ if(multiLossType == "sum") // sum of sums
+ return New<SumMultiRationalLoss>();
+ else if(multiLossType == "scaled") // sum of scaled sums, first element is reference scale
+ return New<ScaledMultiRationalLoss>();
+ else if(multiLossType == "mean") // sum of means
+ return New<MeanMultiRationalLoss>();
+ else
+ ABORT("Unknown multi-loss-type {}", multiLossType);
}
} // namespace marian
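
A brief usage sketch of the two factories above; the expressions and counts are placeholders, and the real call sites live in the model/trainer code:

// Illustrative call pattern only; logits, labels, mask, alignLoss, alignCount are placeholders.
auto labelLoss = newLoss(options, /*inference=*/false);        // e.g. CrossEntropyLoss
auto multiLoss = newMultiLoss(options);                        // e.g. "multi-loss-type: sum"
multiLoss->push_back(labelLoss->apply(logits, labels, mask));  // main translation loss
multiLoss->push_back(RationalLoss(alignLoss, alignCount));     // e.g. an extra guided-alignment term
Expr total = multiLoss->loss();                                // aggregated according to the chosen multi-loss type
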
diff --git a/src/layers/loss.h b/src/layers/loss.h
index d7bc19e4..c662f991 100755..100644
--- a/src/layers/loss.h
+++ b/src/layers/loss.h
@@ -1,8 +1,8 @@
#pragma once
-#include "graph/expression_operators.h"
-#include "layers/generic.h" // for Logits (Frank's factor hack)
#include "data/types.h"
+#include "graph/expression_operators.h"
+#include "layers/logits.h" // for Logits (Frank's factor hack)
namespace marian {
@@ -22,21 +22,18 @@ namespace marian {
*/
class RationalLoss {
protected:
- Expr loss_; // numerator
- Expr count_; // denominator
+ Expr loss_; // numerator
+ Expr count_; // denominator
- RationalLoss() = default; // protected
+ RationalLoss() = default; // protected
public:
- RationalLoss(Expr loss, Expr count)
- : loss_(loss), count_(count) {}
+ RationalLoss(Expr loss, Expr count) : loss_(loss), count_(count) {}
RationalLoss(Expr loss, float count)
- : loss_(loss),
- count_(constant_like(loss, inits::fromValue(count))) {}
+ : loss_(loss), count_(constant_like(loss, inits::fromValue(count))) {}
- RationalLoss(const RationalLoss& other)
- : loss_(other.loss_), count_(other.count_) {}
+ RationalLoss(const RationalLoss& other) : loss_(other.loss_), count_(other.count_) {}
virtual ~RationalLoss() = default;
@@ -50,7 +47,7 @@ public:
}
template <typename T>
- T loss() const { // this will fail if loss is not a single value
+ T loss() const { // this will fail if loss is not a single value
ABORT_IF(!loss_, "Loss has not been defined");
return loss_->val()->scalar<T>();
}
@@ -65,7 +62,7 @@ public:
}
template <typename T>
- T count() const { // this will fail if loss is not a single value
+ T count() const { // this will fail if loss is not a single value
ABORT_IF(!count_, "Labels have not been defined");
return count_->val()->scalar<T>();
}
@@ -85,21 +82,21 @@ public:
* RationalLoss object.
*/
struct StaticLoss {
- float loss; // numerator
- float count; // denominator
+ float loss; // numerator
+ float count; // denominator
StaticLoss() : loss(0.f), count(0.f) {}
StaticLoss(const RationalLoss& dynamic)
- : loss(dynamic.loss<float>()), count(dynamic.count<float>()) {}
+ : loss(dynamic.loss<float>()), count(dynamic.count<float>()) {}
- StaticLoss operator +(const StaticLoss& other) const {
+ StaticLoss operator+(const StaticLoss& other) const {
StaticLoss res(*this);
res += other;
return res;
}
- StaticLoss& operator +=(const StaticLoss& other) {
+ StaticLoss& operator+=(const StaticLoss& other) {
loss = loss + other.loss;
count = count + other.count;
return *this;
@@ -139,32 +136,21 @@ protected:
public:
MultiRationalLoss() : RationalLoss() {}
- MultiRationalLoss(const RationalLoss& rl) : RationalLoss() {
- push_back(rl);
- }
+ MultiRationalLoss(const RationalLoss& rl) : RationalLoss() { push_back(rl); }
virtual void push_back(const RationalLoss& current) {
- loss_ = accumulateLoss(current);
- count_ = accumulateCount(current);
+ loss_ = accumulateLoss(current);
+ count_ = accumulateCount(current);
partialLosses_.push_back(current);
}
- const RationalLoss& operator[](size_t i) {
- return partialLosses_[i];
- }
+ const RationalLoss& operator[](size_t i) { return partialLosses_[i]; }
- auto begin() -> decltype(partialLosses_.begin()) const {
- return partialLosses_.begin();
- }
+ auto begin() -> decltype(partialLosses_.begin()) const { return partialLosses_.begin(); }
- auto end() -> decltype(partialLosses_.end()) const {
- return partialLosses_.end();
- }
-
- size_t size() const {
- return partialLosses_.size();
- }
+ auto end() -> decltype(partialLosses_.end()) const { return partialLosses_.end(); }
+ size_t size() const { return partialLosses_.size(); }
};
/**
@@ -212,17 +198,19 @@ private:
virtual Expr accumulateLoss(const RationalLoss& current) override {
if(loss_) {
const auto& first = partialLosses_.front();
- return loss_ + current.loss() * first.count() / current.count(); // scale up/down to match scale of first loss
+ return loss_
+ + current.loss() * first.count()
+ / current.count(); // scale up/down to match scale of first loss
} else {
- return current.loss(); // first reference loss, keeps to scale with this one
+ return current.loss(); // first reference loss, keeps to scale with this one
}
}
virtual Expr accumulateCount(const RationalLoss& current) override {
if(count_) {
- return count_; // Keep first label count // or: count_ + first.count() / current.count();
+ return count_; // Keep first label count // or: count_ + first.count() / current.count();
} else {
- return current.count(); // This is the first loss
+ return current.count(); // This is the first loss
}
}
@@ -253,9 +241,10 @@ private:
virtual Expr accumulateCount(const RationalLoss& current) override {
if(count_)
- return count_; // keep the existing '1'
+ return count_; // keep the existing '1'
else
- return current.count()->graph()->ones({1}, current.loss()->value_type()); // just '1' as labels are factored into loss_
+ return current.count()->graph()->ones(
+ {1}, current.loss()->value_type()); // just '1' as labels are factored into loss_
}
public:
@@ -279,18 +268,21 @@ class LabelwiseLoss {
protected:
std::vector<int> axes_;
- virtual Expr compute(Logits logits, const Words& labels,
- Expr mask = nullptr, Expr labelWeights = nullptr) = 0;
+ virtual Expr compute(Logits logits,
+ const Words& labels,
+ Expr mask = nullptr,
+ Expr labelWeights = nullptr)
+ = 0;
// label counts are available, reduce together with loss to obtain counts
RationalLoss reduce(Expr loss, Expr labels) {
ABORT_IF(!loss, "Loss has not been computed");
ABORT_IF(!labels, "Labels have not been computed");
- Expr lossSum = cast(loss, Type::float32); // accumulate in float32
- Expr labelsSum = cast(labels, Type::float32); // accumulate in float32
+ Expr lossSum = cast(loss, Type::float32); // accumulate in float32
+ Expr labelsSum = cast(labels, Type::float32); // accumulate in float32
for(int i = 0; i < axes_.size(); ++i) {
- lossSum = sum(lossSum, axes_[i]);
+ lossSum = sum(lossSum, axes_[i]);
labelsSum = sum(labelsSum, axes_[i]);
}
@@ -301,7 +293,7 @@ protected:
RationalLoss reduce(Expr loss) {
ABORT_IF(!loss, "Loss has not been computed");
- Expr lossSum = cast(loss, Type::float32);
+ Expr lossSum = cast(loss, Type::float32);
for(int i = 0; i < axes_.size(); ++i)
lossSum = sum(lossSum, axes_[i]);
@@ -311,17 +303,18 @@ protected:
}
public:
- LabelwiseLoss(const std::vector<int>& axes)
- : axes_(axes) { }
+ LabelwiseLoss(const std::vector<int>& axes) : axes_(axes) {}
- virtual RationalLoss apply(Logits logits, const Words& labels,
- Expr mask = nullptr, Expr labelWeights = nullptr) {
+ virtual RationalLoss apply(Logits logits,
+ const Words& labels,
+ Expr mask = nullptr,
+ Expr labelWeights = nullptr) {
Expr loss = compute(logits, labels, mask, labelWeights);
if(mask)
- return reduce(loss, mask); // mask can be used as element-wise label count with broadcasting
+ return reduce(loss, mask); // mask can be used as element-wise label count with broadcasting
else
- return reduce(loss); // we have no mask, assume all items are labels
+ return reduce(loss); // we have no mask, assume all items are labels
}
};
@@ -331,28 +324,34 @@ public:
class CrossEntropyLoss : public LabelwiseLoss {
public:
CrossEntropyLoss(float labelSmoothing, float factorWeight)
- : CrossEntropyLoss(/*axes=*/{-2, -3}, labelSmoothing, factorWeight) {} // cross-entropy already reduces over axis -1
+ : CrossEntropyLoss(/*axes=*/{-2, -3}, labelSmoothing, factorWeight) {
+ } // cross-entropy already reduces over axis -1
CrossEntropyLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight)
- : LabelwiseLoss(axes), // cross-entropy already reduces over axis -1
- labelSmoothing_(labelSmoothing), factorWeight_(factorWeight) {}
+ : LabelwiseLoss(axes), // cross-entropy already reduces over axis -1
+ labelSmoothing_(labelSmoothing),
+ factorWeight_(factorWeight) {}
virtual ~CrossEntropyLoss() {}
-protected:
- float labelSmoothing_; // interpolation factor for label smoothing, see below
- float factorWeight_; // give extra weight to factors
- virtual Expr compute(Logits logits, const Words& labels,
- Expr mask = nullptr, Expr labelWeights = nullptr) override {
- // logits may be factored; in that case, the getLoss() function computes one loss for each, and sums them up
+protected:
+ float labelSmoothing_; // interpolation factor for label smoothing, see below
+ float factorWeight_; // give extra weight to factors
+
+ virtual Expr compute(Logits logits,
+ const Words& labels,
+ Expr mask = nullptr,
+ Expr labelWeights = nullptr) override {
+ // logits may be factored; in that case, the getLoss() function computes one loss for each, and
+ // sums them up
int inFactor = false;
auto ce = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) {
- logits = atleast_3d(logits); // we always assume a time and batch dimension exists.
+ logits = atleast_3d(logits); // we always assume a time and batch dimension exists.
// for bert training or classification the time dimension is lost.
// Here safeguard against 2d classifier output, adds 1 on the left, non-op.
-
+
Expr ce = cross_entropy(logits, indices, inFactor ? 0.f : labelSmoothing_, Type::float32);
- if (inFactor && factorWeight_ != 1.0f) {
+ if(inFactor && factorWeight_ != 1.0f) {
LOG_ONCE(info, "scaling factor losses with weight {}", factorWeight_);
ce = ce * factorWeight_;
}
@@ -365,8 +364,10 @@ protected:
if(labelWeights) {
// We currently do not know how to use target factors and word-level label weights together
- bool wordlevel = labelWeights->shape()[-3] > 1; // Time-dimension is not trivially 1, hence we have word-level weights.
- ABORT_IF(wordlevel && logits.getNumFactorGroups() > 1, "CE loss with word-level label weights is not implemented for factors");
+ bool wordlevel = labelWeights->shape()[-3]
+ > 1; // Time-dimension is not trivially 1, hence we have word-level weights.
+ ABORT_IF(wordlevel && logits.getNumFactorGroups() > 1,
+ "CE loss with word-level label weights is not implemented for factors");
ce = ce * cast(labelWeights, Type::float32);
}
@@ -374,13 +375,12 @@ protected:
}
};
-
/**
* @brief Unlikelihood loss across last axis, summed up over batch and time dimensions. This is an
* implementation of sequence-level unlikelihood loss from https://arxiv.org/abs/1908.04319.
- * We rely on word-level label weights where 1 is correct and 0 is marking an error. If there are not
- * zeros for a sentence it going to be trained with normal CE loss if there is at least one 0 it is going
- * to flip over to use SUL for that sentence to penalize the selected word.
+ * We rely on word-level label weights where 1 marks a correct word and 0 marks an error. If a
+ * sentence contains no zeros, it is trained with normal CE loss; if it contains at least one 0,
+ * training flips over to SUL for that sentence to penalize the selected word.
*
* SUL is implemented as:
* -log(gather(1 - softmax(logits), -1, indices))
@@ -390,35 +390,45 @@ protected:
class SequenceUnlikelihoodLoss : public CrossEntropyLoss {
public:
SequenceUnlikelihoodLoss(float labelSmoothing, float factorWeight)
- : CrossEntropyLoss(labelSmoothing, factorWeight) {} // cross-entropy already reduces over axis -1
+ : CrossEntropyLoss(labelSmoothing, factorWeight) {
+ } // cross-entropy already reduces over axis -1
SequenceUnlikelihoodLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight)
- : CrossEntropyLoss(axes, labelSmoothing, factorWeight) {}
+ : CrossEntropyLoss(axes, labelSmoothing, factorWeight) {}
protected:
- virtual Expr compute(Logits logits, const Words& labels,
- Expr mask = nullptr, Expr labelWeights = nullptr) override {
- auto ce = CrossEntropyLoss::compute(logits, labels, mask, /*labelWeights=*/nullptr); // don't pass label-weights to CE
+ virtual Expr compute(Logits logits,
+ const Words& labels,
+ Expr mask = nullptr,
+ Expr labelWeights = nullptr) override {
+ auto ce = CrossEntropyLoss::compute(
+ logits, labels, mask, /*labelWeights=*/nullptr); // don't pass label-weights to CE
if(!labelWeights)
- return ce; // for validation, @TODO: maybe put rather abort or LOG_ONCE(warn, ...)?
+ return ce; // for validation, @TODO: maybe put rather abort or LOG_ONCE(warn, ...)?
// We currently do not know how to use target factors and word-level label weights together
ABORT_IF(logits.getNumFactorGroups() > 1, "Unlikelihood loss is not implemented for factors");
- ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by default 1, which would make this obsolete.
- // use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD, mask again to eliminate padding (might be obsolete)
+ ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by
+ // default 1, which would make this obsolete.
+ // use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks BAD, mask
+ // again to eliminate padding (might be obsolete)
auto errorMask = (1.f - cast(labelWeights, Type::float32)) * cast(mask, Type::float32);
auto ceUl = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) {
return cast(unlikelihood(logits, indices), Type::float32);
});
- // compute if want to use CE or UL. If there are no errors train with CE, otherwise train _only on_ the errors with UL. This is the "mixed" training
- // schedule from https://arxiv.org/abs/1908.04319. Providing labels with or without error scores we can easily switch between CE and UL.
- auto onlyCe = eq(sum(errorMask, /*axis=*/-3), 0.f); // [1, 1, dimBatch, 1] - equal 1 if no errors are present
- ceUl = errorMask * ceUl; // don't use for correct label or padding
+ // compute whether to use CE or UL. If there are no errors, train with CE; otherwise train _only
+ // on_ the errors with UL. This is the "mixed" training schedule from
+ // https://arxiv.org/abs/1908.04319. By providing labels with or without error scores we can
+ // easily switch between CE and UL.
+ auto onlyCe = eq(sum(errorMask, /*axis=*/-3),
+ 0.f); // [1, 1, dimBatch, 1] - equal 1 if no errors are present
+ ceUl = errorMask * ceUl; // don't use for correct label or padding
- auto cost = onlyCe * ce + (1.f - onlyCe) * ceUl; // ce or unlikelihood part are never simultanously used as cost per batch entry
+ auto cost = onlyCe * ce + (1.f - onlyCe) * ceUl; // the CE and unlikelihood parts are never
+ // simultaneously used as cost per batch entry
return cost;
}
@@ -463,7 +473,6 @@ public:
}
};
-
/**
* @brief Factory for label-wise loss functions
*/
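
The sequence-unlikelihood term used by SequenceUnlikelihoodLoss above is, per penalized token, -log(1 - softmax(logits)[label]). The following standalone sketch illustrates that arithmetic on plain vectors; the function and variable names are the editor's illustration, not marian's graph-based unlikelihood().

#include <algorithm>
#include <cmath>
#include <vector>

// Illustrative per-token sequence-unlikelihood loss on raw logits:
// prob = softmax(logits)[label]; loss = -log(1 - prob), clamped for stability.
static float sulToken(const std::vector<float>& logits, int label) {
  float maxLogit = *std::max_element(logits.begin(), logits.end());
  double z = 0.0;
  for (float l : logits)
    z += std::exp((double)l - maxLogit);
  double prob = std::exp((double)logits[label] - maxLogit) / z;
  return (float)-std::log(std::max(1.0 - prob, 1e-9));
}
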
diff --git a/src/layers/output.cpp b/src/layers/output.cpp
new file mode 100644
index 00000000..1d9c7b4b
--- /dev/null
+++ b/src/layers/output.cpp
@@ -0,0 +1,293 @@
+#include "output.h"
+#include "common/timer.h"
+#include "data/factored_vocab.h"
+#include "layers/loss.h"
+#include "layers/lsh.h"
+
+namespace marian {
+namespace mlp {
+
+/*private*/ void Output::lazyConstruct(int inputDim) {
+ // We must construct lazily since we won't know tying nor input dim in constructor.
+ if(Wt_)
+ return;
+
+ // this option is only set in the decoder
+ if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) {
+ auto k = opt<std::vector<int>>("output-approx-knn")[0];
+ auto nbits = opt<std::vector<int>>("output-approx-knn")[1];
+ lsh_ = New<LSH>(k, nbits);
+ }
+
+ auto name = options_->get<std::string>("prefix");
+ auto numOutputClasses = options_->get<int>("dim");
+
+ factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
+ if(factoredVocab_) {
+ numOutputClasses = (int)factoredVocab_->factorVocabSize();
+ LOG_ONCE(info, "[embedding] Factored outputs enabled");
+ }
+
+ if(tiedParam_) {
+ Wt_ = tiedParam_;
+ } else {
+ if(graph_->get(name + "_W")) { // support of legacy models that did not transpose
+ Wt_ = graph_->param(
+ name + "_W", {inputDim, numOutputClasses}, inits::glorotUniform(true, false));
+ isLegacyUntransposedW = true;
+ } else // this is the regular case:
+ Wt_ = graph_->param(
+ name + "_Wt", {numOutputClasses, inputDim}, inits::glorotUniform(false, true));
+ }
+
+ if(hasBias_)
+ b_ = graph_->param(name + "_b", {1, numOutputClasses}, inits::zeros());
+
+ /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
+ ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary");
+ if(lemmaDimEmb > 0) { // > 0 means to embed the (expected) word with a different embedding matrix
+#define HARDMAX_HACK
+#ifdef HARDMAX_HACK
+ lemmaDimEmb = lemmaDimEmb & 0xfffffffe; // hack to select hard-max: use an odd number
+#endif
+ auto range = factoredVocab_->getGroupRange(0);
+ auto lemmaVocabDim = (int)(range.second - range.first);
+ auto initFunc = inits::glorotUniform(
+ /*fanIn=*/true, /*fanOut=*/false); // -> embedding vectors have roughly unit length
+ lemmaEt_ = graph_->param(name + "_lemmaEt",
+ {lemmaDimEmb, lemmaVocabDim},
+ initFunc); // [L x U] L=lemmaDimEmb; transposed for speed
+ }
+}
+
+Logits Output::applyAsLogits(Expr input) /*override final*/ {
+ lazyConstruct(input->shape()[-1]);
+
+ auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
+ if(b)
+ return affine(x, W, b, transA, transB);
+ else
+ return dot(x, W, transA, transB);
+ };
+
+ auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) {
+ if(lsh_) {
+ ABORT_IF(transA, "Transposed query not supported for LSH");
+ ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH");
+ return lsh_->apply(x, W, b); // knows how to deal with undefined bias
+ } else {
+ return affineOrDot(x, W, b, transA, transB);
+ }
+ };
+
+ if(shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one
+ // batch, then clear()ed
+ cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices());
+ if(hasBias_)
+ cachedShortb_ = index_select(b_, -1, shortlist_->indices());
+ }
+
+ if(factoredVocab_) {
+ auto graph = input->graph();
+
+ // project each factor separately
+ auto numGroups = factoredVocab_->getNumGroups();
+ std::vector<Ptr<RationalLoss>> allLogits(numGroups,
+ nullptr); // (note: null entries for absent factors)
+ Expr input1 = input; // [B... x D]
+ Expr Plemma = nullptr; // used for lemmaDimEmb=-1
+ Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3
+ for(size_t g = 0; g < numGroups; g++) {
+ auto range = factoredVocab_->getGroupRange(g);
+ if(g > 0 && range.first == range.second) // empty entry
+ continue;
+ ABORT_IF(g > 0 && range.first != factoredVocab_->getGroupRange(g - 1).second,
+ "Factor groups must be consecutive (group {} vs predecessor)",
+ g);
+ // slice this group's section out of W_
+ Expr factorWt, factorB;
+ if(g == 0 && shortlist_) {
+ factorWt = cachedShortWt_;
+ factorB = cachedShortb_;
+ } else {
+ factorWt = slice(
+ Wt_, isLegacyUntransposedW ? -1 : 0, Slice((int)range.first, (int)range.second));
+ if(hasBias_)
+ factorB = slice(b_, -1, Slice((int)range.first, (int)range.second));
+ }
+ /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
+ if((lemmaDimEmb == -2 || lemmaDimEmb == -3)
+ && g > 0) { // -2/-3 means a gated transformer-like structure (-3 = hard-max)
+ LOG_ONCE(info, "[embedding] using lemma conditioning with gate");
+ // this mimics one transformer layer
+ // - attention over two inputs:
+ // - e = current lemma. We use the original embedding vector; specifically, expectation
+ // over all lemmas.
+ // - input = hidden state FF(h_enc+h_dec)
+ // - dot-prod attention to allow both sides to influence (unlike our recurrent
+ // self-attention)
+ // - multi-head to allow for multiple conditions to be modeled
+ // - add & norm, for gradient flow and scaling
+ // - FF layer --this is expensive; it is per-factor
+ // multi-head attention
+ int inputDim = input->shape()[-1];
+ int heads = 8;
+ auto name = options_->get<std::string>("prefix") + "_factor" + std::to_string(g);
+ auto Wq = graph_->param(name + "_Wq", {inputDim, inputDim}, inits::glorotUniform());
+ auto Wk = graph_->param(name + "_Wk", {inputDim, inputDim}, inits::glorotUniform());
+ auto Wv = graph_->param(name + "_Wv", {inputDim, inputDim}, inits::glorotUniform());
+ auto toMultiHead = [&](Expr x, int heads) {
+ const auto& shape = x->shape();
+ int inputDim = shape[-1];
+ int otherDim = shape.elements() / inputDim;
+ ABORT_IF(inputDim / heads * heads != inputDim,
+ "inputDim ({}) must be multiple of number of heads ({})",
+ inputDim,
+ heads);
+ return reshape(x, {otherDim, heads, 1, inputDim / heads});
+ };
+ input1 = inputLemma;
+ auto qm = toMultiHead(dot(input1, Wq), heads); // [B... x H x D/H] projected query
+ auto kdm = toMultiHead(dot(input1 - input, Wk),
+ heads); // [B... x H x D/H] the two data vectors projected as keys.
+ // Use diff and sigmoid, instead of softmax.
+ auto vem = toMultiHead(
+ dot(input1, Wv),
+ heads); // [B... x H x D/H] one of the two data vectors projected as values
+ auto vim = toMultiHead(dot(input, Wv), heads); // [B... x H x D/H] the other
+ auto zm = bdot(qm, kdm, false, true); // [B... x H x 1]
+ auto sm = sigmoid(zm); // [B... x H x 1]
+ auto rm = sm * (vem - vim) + vim; // [B... x H x D/H]
+ auto r = reshape(rm, input->shape()); // [B... x D]
+ // add & norm
+ input1 = r + input1;
+ input1 = layerNorm(input1, name + "_att");
+ // FF layer
+ auto ffnDropProb = 0.1f; // @TODO: get as a parameter
+ auto ffnDim = inputDim * 2; // @TODO: get as a parameter
+ auto f = denseInline(input1,
+ name + "_ffn",
+ /*suffix=*/"1",
+ ffnDim,
+ inits::glorotUniform(),
+ (ActivationFunction*)relu,
+ ffnDropProb);
+ f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim);
+ // add & norm
+ input1 = f + input1;
+ input1 = layerNorm(input1, name + "_ffn");
+ }
+ // @TODO: b_ should be a vector, not a matrix; but shortlists use cols() on it, which requires a
+ // matrix
+ Expr factorLogits;
+ if(g == 0)
+ factorLogits = affineOrLSH(
+ input1,
+ factorWt,
+ factorB,
+ false,
+ /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
+ else
+ factorLogits = affineOrDot(
+ input1,
+ factorWt,
+ factorB,
+ false,
+ /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
+
+ // optionally add lemma-dependent bias
+ if(Plemma) { // [B... x U0]
+ int lemmaVocabDim = Plemma->shape()[-1];
+ int factorVocabDim = factorLogits->shape()[-1];
+ auto name = options_->get<std::string>("prefix");
+ Expr lemmaBt
+ = graph_->param(name + "_lemmaBt_" + std::to_string(g),
+ {factorVocabDim, lemmaVocabDim},
+ inits::zeros()); // [U x U0] U0=#lemmas one bias per class per lemma
+ auto b = dot(Plemma, lemmaBt, false, true); // [B... x U]
+ factorLogits = factorLogits + b;
+ }
+ allLogits[g] = New<RationalLoss>(factorLogits, nullptr);
+ // optionally add a soft embedding of lemma back to create some lemma dependency
+ // @TODO: if this works, move it into lazyConstruct
+ if(lemmaDimEmb == -2 && g == 0) { // -2 means a gated transformer-like structure
+ LOG_ONCE(info, "[embedding] using lemma conditioning with gate, soft-max version");
+ // get expected lemma embedding vector
+ auto factorLogSoftmax = logsoftmax(
+ factorLogits); // [B... x U] note: with shortlist, this is not the full lemma set
+ auto factorSoftmax = exp(factorLogSoftmax);
+ inputLemma = dot(factorSoftmax,
+ factorWt,
+ false,
+ /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
+ } else if(lemmaDimEmb == -3 && g == 0) { // same as -2 except with hard max
+ LOG_ONCE(info, "[embedding] using lemma conditioning with gate, hard-max version");
+ // get max-lemma embedding vector
+ auto maxVal = max(factorLogits,
+ -1); // [B... x U] note: with shortlist, this is not the full lemma set
+ auto factorHardmax = eq(factorLogits, maxVal);
+ inputLemma = dot(factorHardmax,
+ factorWt,
+ false,
+ /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
+ } else if(lemmaDimEmb == -1 && g == 0) { // -1 means learn a lemma-dependent bias
+ ABORT_IF(shortlist_, "Lemma-dependent bias with short list is not yet implemented");
+ LOG_ONCE(info, "[embedding] using lemma-dependent bias");
+ auto factorLogSoftmax
+ = logsoftmax(factorLogits); // (we do that again later, CSE will kick in)
+ auto z = /*stopGradient*/ (factorLogSoftmax);
+ Plemma = exp(z); // [B... x U]
+ } else if(lemmaDimEmb > 0 && g == 0) { // > 0 means learn a re-embedding matrix
+ LOG_ONCE(info, "[embedding] enabled re-embedding of lemma, at dim {}", lemmaDimEmb);
+ // compute softmax. We compute logsoftmax() separately because this way, computation will be
+ // reused later via CSE
+ auto factorLogSoftmax = logsoftmax(factorLogits);
+ auto factorSoftmax = exp(factorLogSoftmax);
+#ifdef HARDMAX_HACK
+ bool hardmax = (lemmaDimEmb & 1)
+ != 0; // odd value triggers hardmax for now (for quick experimentation)
+ if(hardmax) {
+ lemmaDimEmb = lemmaDimEmb & 0xfffffffe;
+ LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb);
+ auto maxVal = max(factorSoftmax, -1);
+ factorSoftmax = eq(factorSoftmax, maxVal);
+ }
+#endif
+ // re-embedding lookup, soft-indexed by softmax
+ if(shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix
+ cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices());
+ auto e = dot(factorSoftmax,
+ cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_,
+ false,
+ true); // [B... x L]
+ // project it back to regular hidden dim
+ int inputDim = input1->shape()[-1];
+ auto name = options_->get<std::string>("prefix");
+ // note: if the lemmaEt[:,w] have unit length (var = 1/L), then lemmaWt @ lemmaEt is also
+ // length 1
+ Expr lemmaWt
+ = inputDim == lemmaDimEmb
+ ? nullptr
+ : graph_->param(name + "_lemmaWt",
+ {inputDim, lemmaDimEmb},
+ inits::glorotUniform()); // [D x L] D=hidden-vector dimension
+ auto f = lemmaWt ? dot(e, lemmaWt, false, true) : e; // [B... x D]
+ // augment the original hidden vector with this additional information
+ input1 = input1 + f;
+ }
+ }
+ return Logits(std::move(allLogits), factoredVocab_);
+ } else if(shortlist_) {
+ return Logits(affineOrLSH(input,
+ cachedShortWt_,
+ cachedShortb_,
+ false,
+ /*transB=*/isLegacyUntransposedW ? false : true));
+ } else {
+ return Logits(
+ affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true));
+ }
+}
+
+} // namespace mlp
+} // namespace marian
\ No newline at end of file
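
The HARDMAX_HACK in lazyConstruct() and applyAsLogits() above packs a flag into the low bit of --lemma-dim-emb: an odd value requests hard-max selection, and masking with 0xfffffffe recovers the even embedding dimension. A small decoding sketch (struct and function names are the editor's):

// Decodes the --lemma-dim-emb hard-max convention used above:
// the low bit requests hard-max, the remaining bits are the embedding dimension.
struct LemmaDimEmb {
  int dim;
  bool hardmax;
};

static LemmaDimEmb decodeLemmaDimEmb(int raw) {
  LemmaDimEmb out;
  out.hardmax = (raw & 1) != 0;   // odd value triggers hard-max
  out.dim     = raw & 0xfffffffe; // clear the low bit to get the actual dimension
  return out;
}
// e.g. raw = 257 -> dim = 256, hardmax = true; raw = 256 -> dim = 256, hardmax = false
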
diff --git a/src/layers/output.h b/src/layers/output.h
new file mode 100644
index 00000000..2b6f4986
--- /dev/null
+++ b/src/layers/output.h
@@ -0,0 +1,75 @@
+#pragma once
+
+#include "data/shortlist.h"
+#include "generic.h"
+#include "layers/factory.h"
+#include "logits.h"
+#include "marian.h"
+
+namespace marian {
+class LSH;
+
+namespace mlp {
+
+class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList {
+private:
+ // parameters held by this layer
+ Expr Wt_; // weight matrix is stored transposed for efficiency
+ Expr b_;
+ Expr lemmaEt_; // re-embedding matrix for lemmas [lemmaDimEmb x lemmaVocabSize]
+ bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form
+ bool hasBias_{true};
+
+ Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
+ Expr cachedShortb_; // these match the current value of shortlist_
+ Expr cachedShortLemmaEt_;
+ Ptr<FactoredVocab> factoredVocab_;
+
+ // optional parameters set/updated after construction
+ Expr tiedParam_;
+ Ptr<data::Shortlist> shortlist_;
+ Ptr<LSH> lsh_;
+
+ void lazyConstruct(int inputDim);
+
+public:
+ Output(Ptr<ExpressionGraph> graph, Ptr<Options> options)
+ : LayerBase(graph, options), hasBias_{!options->get<bool>("output-omit-bias", false)} {
+ clear();
+ }
+
+ void tieTransposed(Expr tied) {
+ if(Wt_)
+ ABORT_IF(tiedParam_.get() != tied.get(),
+ "Tied output projection cannot be changed once weights have been created");
+ else
+ tiedParam_ = tied;
+ }
+
+ void setShortlist(Ptr<data::Shortlist> shortlist) override final {
+ if(shortlist_)
+ ABORT_IF(shortlist.get() != shortlist_.get(),
+ "Output shortlist cannot be changed except after clear()");
+ else {
+ ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_,
+ "No shortlist but cached parameters??");
+ shortlist_ = shortlist;
+ }
+ // cachedShortWt_ and cachedShortb_ will be created lazily inside apply()
+ }
+
+ // this is expected to be called in sync with graph->clear(), which invalidates
+ // cachedShortWt_ etc. in the graph's short-term cache
+ void clear() override final {
+ shortlist_ = nullptr;
+ cachedShortWt_ = nullptr;
+ cachedShortb_ = nullptr;
+ cachedShortLemmaEt_ = nullptr;
+ }
+
+ Logits applyAsLogits(Expr input) override final;
+};
+
+} // namespace mlp
+
+} // namespace marian
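
The shortlist members above are cached per batch; a rough sketch of the intended call order from a decoding loop follows. Only the Output methods are from this header; the surrounding driver and variable names are hypothetical.

// Hypothetical per-batch driver around mlp::Output (sketch only):
void decodeBatch(marian::mlp::Output& output,
                 marian::Ptr<marian::data::Shortlist> shortlist,
                 marian::Expr hidden) {
  output.setShortlist(shortlist);                       // short-listed Wt_/b_ are built lazily
  marian::Logits logits = output.applyAsLogits(hidden);
  // ... run beam search / compute loss on 'logits' ...
  output.clear();                                       // keep in sync with graph->clear()
}
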
diff --git a/src/microsoft/quicksand.cpp b/src/microsoft/quicksand.cpp
index 6476df8f..6476df8f 100755..100644
--- a/src/microsoft/quicksand.cpp
+++ b/src/microsoft/quicksand.cpp
diff --git a/src/microsoft/quicksand.h b/src/microsoft/quicksand.h
index 87de1948..87de1948 100755..100644
--- a/src/microsoft/quicksand.h
+++ b/src/microsoft/quicksand.h
diff --git a/src/microsoft/shortlist/logging/LoggerMacros.h b/src/microsoft/shortlist/logging/LoggerMacros.h
new file mode 100644
index 00000000..ca74e737
--- /dev/null
+++ b/src/microsoft/shortlist/logging/LoggerMacros.h
@@ -0,0 +1,25 @@
+#pragma once
+
+// Do NOT include this file directly except in special circumstances.
+// (E.g., you want to define macros which call these but don't want to include Logger.h everywhere).
+// Normally you should include logging/Logger.h
+
+#define LOG_WRITE(format, ...) do {\
+ abort(); \
+} while (0)
+
+#define LOG_WRITE_STRING(str) do {\
+ abort(); \
+} while (0)
+
+#define LOG_ERROR(format, ...) do {\
+ abort(); \
+} while (0)
+
+#define LOG_ERROR_AND_THROW(format, ...) do {\
+ abort(); \
+} while (0)
+
+#define DECODING_LOGIC_ERROR(format, ...) do {\
+ abort(); \
+} while (0)
diff --git a/src/microsoft/shortlist/utils/Converter.cpp b/src/microsoft/shortlist/utils/Converter.cpp
new file mode 100644
index 00000000..c28178cd
--- /dev/null
+++ b/src/microsoft/shortlist/utils/Converter.cpp
@@ -0,0 +1,59 @@
+#include "microsoft/shortlist/utils/Converter.h"
+
+namespace quicksand {
+
+#include "microsoft/shortlist/logging/LoggerMacros.h"
+
+
+int64_t Converter::ToInt64(const std::string& str) {
+ return ConvertSingleInternal<int64_t>(str, "int64_t");
+}
+
+uint64_t Converter::ToUInt64(const std::string& str) {
+ return ConvertSingleInternal<uint64_t>(str, "uint64_t");
+}
+
+int32_t Converter::ToInt32(const std::string& str) {
+ return ConvertSingleInternal<int32_t>(str, "int32_t");
+}
+
+float Converter::ToFloat(const std::string& str) {
+ // In case the value is out of range of a 32-bit float, but in range of a 64-bit double,
+ // it's better to parse as a double and then narrow the result to float.
+ return (float)ConvertSingleInternal<double>(str, "float");
+}
+
+double Converter::ToDouble(const std::string& str) {
+ return ConvertSingleInternal<double>(str, "double");
+}
+
+bool Converter::ToBool(const std::string& str) {
+ bool value = false;
+ if (!TryConvert(str, /* out */ value)) {
+ LOG_ERROR_AND_THROW("The string '%s' is not interpretable as the type 'bool'", str.c_str());
+ }
+ return value;
+}
+
+std::vector<int32_t> Converter::ToInt32Vector(const std::vector<std::string>& items) {
+ return ConvertVectorInternal<int32_t, std::vector<std::string>::const_iterator>(items.begin(), items.end(), "int32_t");
+}
+
+std::vector<int64_t> Converter::ToInt64Vector(const std::vector<std::string>& items) {
+ return ConvertVectorInternal<int64_t, std::vector<std::string>::const_iterator>(items.begin(), items.end(), "int64_t");
+}
+
+std::vector<float> Converter::ToFloatVector(const std::vector<std::string>& items) {
+ return ConvertVectorInternal<float, std::vector<std::string>::const_iterator>(items.begin(), items.end(), "float");
+}
+
+std::vector<double> Converter::ToDoubleVector(const std::vector<std::string>& items) {
+ return ConvertVectorInternal<double, std::vector<std::string>::const_iterator>(items.begin(), items.end(), "double");
+}
+
+void Converter::HandleConversionError(const std::string& str, const char * type_name) {
+ str; type_name; // make compiler happy
+ LOG_ERROR_AND_THROW("The string '%s' is not interpretable as the type '%s'", str.c_str(), type_name);
+}
+
+} // namespace quicksand
diff --git a/src/microsoft/shortlist/utils/Converter.h b/src/microsoft/shortlist/utils/Converter.h
new file mode 100644
index 00000000..9d9dd96d
--- /dev/null
+++ b/src/microsoft/shortlist/utils/Converter.h
@@ -0,0 +1,83 @@
+#pragma once
+
+#include <stdint.h>
+#include <string>
+#include <vector>
+#include <sstream>
+
+namespace quicksand {
+
+class Converter {
+public:
+ static int32_t ToInt32(const std::string& str);
+
+ static int64_t ToInt64(const std::string& str);
+
+ static uint64_t ToUInt64(const std::string& str);
+
+ static float ToFloat(const std::string& str);
+
+ static double ToDouble(const std::string& str);
+
+ static bool ToBool(const std::string& str);
+
+ static std::vector<int32_t> ToInt32Vector(const std::vector<std::string>& items);
+
+ static std::vector<int64_t> ToInt64Vector(const std::vector<std::string>& items);
+
+ static std::vector<float> ToFloatVector(const std::vector<std::string>& items);
+
+ static std::vector<double> ToDoubleVector(const std::vector<std::string>& items);
+
+ static bool TryConvert(const std::string& str, /* out*/ bool& obj) {
+ if (str == "True" || str == "true" || str == "TRUE" || str == "Yes" || str == "yes" || str == "1") {
+ obj = true;
+ return true;
+ }
+ else if (str == "False" || str == "false" || str == "FALSE" || str == "No" || str == "no" || str == "0") {
+ obj = false;
+ return true;
+ }
+ return false;
+ }
+
+ template <typename T>
+ static bool TryConvert(const std::string& str, /* out*/ T& value) {
+ std::istringstream ss(str);
+ value = T();
+ if (!(ss >> value)) {
+ return false;
+ }
+ return true;
+ }
+
+private:
+ template <typename T>
+ static T ConvertSingleInternal(const std::string& str, const char * type_name);
+
+ template <typename T, typename I>
+ static std::vector<T> ConvertVectorInternal(I begin, I end, const char * type_name);
+
+ static void HandleConversionError(const std::string& str, const char * type_name);
+};
+
+template <typename T>
+T Converter::ConvertSingleInternal(const std::string& str, const char * type_name) {
+ std::istringstream ss(str);
+ T value = T();
+ if (!(ss >> value)) {
+ HandleConversionError(str, type_name);
+ }
+ return value;
+}
+
+template <typename T, typename I>
+std::vector<T> Converter::ConvertVectorInternal(I begin, I end, const char * type_name) {
+ std::vector<T> items;
+ for (I it = begin; it != end; it++) {
+ items.push_back(ConvertSingleInternal<T>(*it, type_name));
+ }
+ return items;
+}
+
+} // namespace quicksand
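
A brief usage sketch of the Converter API declared above; the wrapper function and literal values are illustrative only.

#include <vector>
#include "microsoft/shortlist/utils/Converter.h"

using quicksand::Converter;

void converterExample() {
  int32_t k   = Converter::ToInt32("17");     // 17
  bool flag   = Converter::ToBool("Yes");     // true, via the bool TryConvert overload
  float ratio = Converter::ToFloat("0.125");  // parsed as double, then cast to float

  double value = 0.0;
  if (!Converter::TryConvert("not-a-number", value)) {
    // TryConvert reports failure instead of throwing
  }

  std::vector<float> v = Converter::ToFloatVector({"1.0", "2.5"});
  (void)k; (void)flag; (void)ratio; (void)v;
}
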
diff --git a/src/microsoft/shortlist/utils/ParameterTree.cpp b/src/microsoft/shortlist/utils/ParameterTree.cpp
new file mode 100644
index 00000000..465d2e0d
--- /dev/null
+++ b/src/microsoft/shortlist/utils/ParameterTree.cpp
@@ -0,0 +1,417 @@
+#include "microsoft/shortlist/utils/ParameterTree.h"
+
+#include <string>
+
+#include "microsoft/shortlist/utils/StringUtils.h"
+#include "microsoft/shortlist/utils/Converter.h"
+
+namespace quicksand {
+
+#include "microsoft/shortlist/logging/LoggerMacros.h"
+
+std::shared_ptr<ParameterTree> ParameterTree::m_empty_tree = std::make_shared<ParameterTree>("params");
+
+ParameterTree::ParameterTree() {
+ m_name = "root";
+}
+
+ParameterTree::ParameterTree(const std::string& name) {
+ m_name = name;
+}
+
+ParameterTree::~ParameterTree() {
+}
+
+void ParameterTree::Clear() {
+
+}
+
+void ParameterTree::ReplaceVariables(
+ const std::unordered_map<std::string, std::string>& vars,
+ bool error_on_unknown_vars)
+{
+ ReplaceVariablesInternal(vars, error_on_unknown_vars);
+}
+
+void ParameterTree::RegisterInt32(const std::string& name, int32_t * param) {
+ RegisterItemInternal(name, PARAM_TYPE_INT32, (void *)param);
+}
+
+void ParameterTree::RegisterInt64(const std::string& name, int64_t * param) {
+ RegisterItemInternal(name, PARAM_TYPE_INT64, (void *)param);
+}
+
+void ParameterTree::RegisterFloat(const std::string& name, float * param) {
+ RegisterItemInternal(name, PARAM_TYPE_FLOAT, (void *)param);
+}
+
+void ParameterTree::RegisterDouble(const std::string& name, double * param) {
+ RegisterItemInternal(name, PARAM_TYPE_DOUBLE, (void *)param);
+}
+
+void ParameterTree::RegisterBool(const std::string& name, bool * param) {
+ RegisterItemInternal(name, PARAM_TYPE_BOOL, (void *)param);
+}
+
+void ParameterTree::RegisterString(const std::string& name, std::string * param) {
+ RegisterItemInternal(name, PARAM_TYPE_STRING, (void *)param);
+}
+
+std::shared_ptr<ParameterTree> ParameterTree::FromBinaryReader(const void*& current) {
+ std::shared_ptr<ParameterTree> root = std::make_shared<ParameterTree>();
+ root->ReadBinary(current);
+ return root;
+}
+
+void ParameterTree::SetRegisteredParams() {
+ for (std::size_t i = 0; i < m_registered_params.size(); i++) {
+ const RegisteredParam& rp = m_registered_params[i];
+ switch (rp.Type()) {
+ case PARAM_TYPE_INT32:
+ (*(int32_t *)rp.Data()) = GetInt32Req(rp.Name());
+ break;
+ case PARAM_TYPE_INT64:
+ (*(int64_t *)rp.Data()) = GetInt64Req(rp.Name());
+ break;
+ default:
+ LOG_ERROR_AND_THROW("Unknown ParameterType: %d", (int)rp.Type());
+ }
+ }
+}
+
+int32_t ParameterTree::GetInt32Or(const std::string& name, int32_t defaultValue) const {
+ const std::string * value = GetParamInternal(name);
+ if (value == nullptr) {
+ return defaultValue;
+ }
+ return Converter::ToInt32(*value);
+}
+
+int64_t ParameterTree::GetInt64Or(const std::string& name, int64_t defaultValue) const {
+ const std::string * value = GetParamInternal(name);
+ if (value == nullptr) {
+ return defaultValue;
+ }
+ return Converter::ToInt64(*value);
+}
+
+uint64_t ParameterTree::GetUInt64Or(const std::string& name, uint64_t defaultValue) const {
+ const std::string * value = GetParamInternal(name);
+ if (value == nullptr) {
+ return defaultValue;
+ }
+ return Converter::ToUInt64(*value);
+}
+
+double ParameterTree::GetDoubleOr(const std::string& name, double defaultValue) const {
+ const std::string * value = GetParamInternal(name);
+ if (value == nullptr) {
+ return defaultValue;
+ }
+ return Converter::ToDouble(*value);
+}
+
+float ParameterTree::GetFloatOr(const std::string& name, float defaultValue) const {
+ const std::string * value = GetParamInternal(name);
+ if (value == nullptr) {
+ return defaultValue;
+ }
+ return Converter::ToFloat(*value);
+}
+
+std::string ParameterTree::GetStringOr(const std::string& name, const std::string& defaultValue) const {
+ const std::string * value = GetParamInternal(name);
+ if (value == nullptr) {
+ return defaultValue;
+ }
+ return (*value);
+}
+
+bool ParameterTree::GetBoolOr(const std::string& name, bool defaultValue) const {
+ const std::string * value = GetParamInternal(name);
+ if (value == nullptr) {
+ return defaultValue;
+ }
+ return Converter::ToBool(*value);
+}
+
+int32_t ParameterTree::GetInt32Req(const std::string& name) const {
+ std::string value = GetStringReq(name);
+ return Converter::ToInt32(value);
+}
+
+uint64_t ParameterTree::GetUInt64Req(const std::string& name) const {
+ std::string value = GetStringReq(name);
+ return Converter::ToUInt64(value);
+}
+
+int64_t ParameterTree::GetInt64Req(const std::string& name) const {
+ std::string value = GetStringReq(name);
+ return Converter::ToInt64(value);
+}
+
+double ParameterTree::GetDoubleReq(const std::string& name) const {
+ std::string value = GetStringReq(name);
+ return Converter::ToDouble(value);
+}
+
+float ParameterTree::GetFloatReq(const std::string& name) const {
+ std::string value = GetStringReq(name);
+ return Converter::ToFloat(value);
+}
+
+bool ParameterTree::GetBoolReq(const std::string& name) const {
+ std::string value = GetStringReq(name);
+ return Converter::ToBool(value);
+}
+
+std::string ParameterTree::GetStringReq(const std::string& name) const {
+ const std::string * value = GetParamInternal(name);
+ if (value == nullptr) {
+ LOG_ERROR_AND_THROW("Required parameter <%s> not found in ParameterTree:\n%s", name.c_str(), ToString().c_str());
+ }
+ return (*value);
+}
+
+std::vector<std::string> ParameterTree::GetFileListReq(const std::string& name) const {
+ std::vector<std::string> output = GetFileListOptional(name);
+ if (output.size() == 0) {
+ LOG_ERROR_AND_THROW("No files were found for parameter: %s", name.c_str());
+ }
+ return output;
+}
+
+std::vector<std::string> ParameterTree::GetFileListOptional(const std::string& name) const {
+ const std::string * value = GetParamInternal(name);
+ if (value == nullptr || (*value).size() == 0) {
+ return std::vector<std::string>();
+ }
+ std::vector<std::string> all_files = StringUtils::Split(*value, ";");
+ return all_files;
+}
+
+std::vector<std::string> ParameterTree::GetStringListReq(const std::string& name, const std::string& sep) const {
+ std::string value = GetStringReq(name);
+ std::vector<std::string> output = StringUtils::Split(value, sep);
+ return output;
+}
+
+std::vector<std::string> ParameterTree::GetStringListOptional(const std::string& name, const std::string& sep) const {
+ std::string value = GetStringOr(name, "");
+ std::vector<std::string> output = StringUtils::Split(value, sep);
+ return output;
+}
+
+std::shared_ptr<ParameterTree> ParameterTree::GetChildReq(const std::string& name) const {
+ for (const auto& child : m_children) {
+ if (child->Name() == name) {
+ return child;
+ }
+ }
+ LOG_ERROR_AND_THROW("Unable to find child ParameterTree with name '%s'", name.c_str());
+ return nullptr; // never happens
+}
+
+
+std::shared_ptr<ParameterTree> ParameterTree::GetChildOrEmpty(const std::string& name) const {
+ for (const auto& child : m_children) {
+ if (child->Name() == name) {
+ return child;
+ }
+ }
+ return std::make_shared<ParameterTree>();
+}
+
+// cast current void pointer to T pointer and move forward by num elements
+template <typename T>
+const T* get(const void*& current, size_t num = 1) {
+ const T* ptr = (const T*)current;
+ current = (const T*)current + num;
+ return ptr;
+}
+
+void ParameterTree::ReadBinary(const void*& current) {
+ auto nameLength = *get<int32_t>(current);
+ auto nameBytes = get<char>(current, nameLength);
+ m_name = std::string(nameBytes, nameBytes + nameLength);
+
+ auto textLength = *get<int32_t>(current);
+ auto textBytes = get<char>(current, textLength);
+ m_text = std::string(textBytes, textBytes + textLength);
+
+ int32_t num_children = *get<int32_t>(current);
+ m_children.resize(num_children);
+ for (int32_t i = 0; i < num_children; i++) {
+ m_children[i].reset(new ParameterTree());
+ m_children[i]->ReadBinary(current);
+ }
+}
+
+std::vector< std::shared_ptr<ParameterTree> > ParameterTree::GetChildren(const std::string& name) const {
+ std::vector< std::shared_ptr<ParameterTree> > children;
+ for (std::shared_ptr<ParameterTree> child : m_children) {
+ if (child->Name() == name) {
+ children.push_back(child);
+ }
+ }
+ return children;
+}
+
+void ParameterTree::AddParam(const std::string& name, const std::string& text) {
+ std::shared_ptr<ParameterTree> child = std::make_shared<ParameterTree>(name);
+ child->SetText(text);
+ m_children.push_back(child);
+}
+
+void ParameterTree::SetParam(const std::string& name, const std::string& text) {
+ for (const auto& child : m_children) {
+ if (child->Name() == name) {
+ child->SetText(text);
+ return;
+ }
+ }
+ std::shared_ptr<ParameterTree> child = std::make_shared<ParameterTree>(name);
+ child->SetText(text);
+ m_children.push_back(child);
+}
+
+void ParameterTree::AddChild(std::shared_ptr<ParameterTree> child) {
+ m_children.push_back(child);
+}
+
+bool ParameterTree::HasParam(const std::string& name) const {
+ const std::string * value = GetParamInternal(name);
+ if (value == nullptr) {
+ return false;
+ }
+ return true;
+}
+
+bool ParameterTree::HasChild(const std::string& name) const {
+ for (const auto& child : m_children) {
+ if (child->Name() == name) {
+ return true;
+ }
+ }
+ return false;
+}
+
+std::string ParameterTree::ToString() const {
+ std::ostringstream ss;
+ ToStringInternal(0, ss);
+ return ss.str();
+}
+
+const std::string * ParameterTree::GetParamInternal(const std::string& name) const {
+ for (const auto& child : m_children) {
+ if (child->Name() == name) {
+ return &(child->Text());
+ }
+ }
+ return nullptr;
+}
+
+
+void ParameterTree::RegisterItemInternal(const std::string& name, ParameterType type, void * param) {
+ if (m_registered_param_names.find(name) != m_registered_param_names.end()) {
+ LOG_ERROR_AND_THROW("Unable to register duplicate parameter name: '%s'", name.c_str());
+ }
+ m_registered_params.push_back(RegisteredParam(name, type, param));
+ m_registered_param_names.insert(name);
+}
+
+void ParameterTree::ToStringInternal(int32_t depth, std::ostream& ss) const {
+ for (int32_t i = 0; i < 2*depth; i++) {
+ ss << " ";
+ }
+ ss << "<" << m_name << ">";
+ if (m_children.size() > 0) {
+ ss << "\n";
+ for (const std::shared_ptr<ParameterTree>& child : m_children) {
+ child->ToStringInternal(depth+1, ss);
+ }
+ for (int32_t i = 0; i < 2 * depth; i++) {
+ ss << " ";
+ }
+ ss << "</" << m_name << ">\n";
+ }
+ else {
+ ss << m_text << "</" << m_name << ">\n";
+ }
+}
+
+std::shared_ptr<ParameterTree> ParameterTree::Clone() const {
+ std::shared_ptr<ParameterTree> node = std::make_shared<ParameterTree>(m_name);
+ node->m_text = m_text;
+ for (auto& child : m_children) {
+ node->m_children.push_back(child->Clone());
+ }
+ return node;
+}
+
+void ParameterTree::Merge(const ParameterTree& other) {
+ m_name = other.m_name;
+ m_text = other.m_text;
+ for (auto& other_child : other.m_children) {
+ if (HasChild(other_child->Name())) {
+ auto my_child = GetChildReq(other_child->Name());
+ if (other_child->Text() != "" && my_child->Text() != "") {
+ my_child->SetText(other_child->Text());
+ }
+ else {
+ my_child->Merge(*other_child);
+ }
+ }
+ else {
+ m_children.push_back(other_child->Clone());
+ }
+ }
+}
+
+void ParameterTree::ReplaceVariablesInternal(
+ const std::unordered_map<std::string, std::string>& vars,
+ bool error_on_unknown_vars)
+{
+ std::size_t offset = 0;
+ std::ostringstream ss;
+ while (true) {
+ std::size_t s_pos = m_text.find("$$", offset);
+ if (s_pos == std::string::npos) {
+ break;
+ }
+ std::size_t e_pos = m_text.find("$$", s_pos + 2);
+ if (e_pos == std::string::npos) {
+ break;
+ }
+
+ if (offset != s_pos) {
+ ss << m_text.substr(offset, s_pos-offset);
+ }
+
+ std::string var_name = m_text.substr(s_pos+2, e_pos - (s_pos+2));
+ auto it = vars.find(var_name);
+ if (it != vars.end()) {
+ std::string value = it->second;
+ ss << value;
+ }
+ else {
+ if (error_on_unknown_vars) {
+ LOG_ERROR_AND_THROW("The variable $$%s$$ was not found", var_name.c_str());
+ }
+ else {
+ ss << "$$" << var_name << "$$";
+ }
+ }
+ offset = e_pos + 2;
+ }
+ ss << m_text.substr(offset);
+
+ m_text = ss.str();
+
+ for (auto& child : m_children) {
+ child->ReplaceVariablesInternal(vars, error_on_unknown_vars);
+ }
+}
+
+} // namespace quicksand
+
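
ReplaceVariables above substitutes $$NAME$$ markers in each node's text from the provided map, and throws on unknown variables unless told otherwise. A short usage sketch with illustrative parameter names:

#include <string>
#include <unordered_map>
#include "microsoft/shortlist/utils/ParameterTree.h"

void replaceVariablesExample() {
  quicksand::ParameterTree tree("params");
  tree.AddParam("model_path", "$$ROOT$$/model.bin");  // hypothetical parameter

  std::unordered_map<std::string, std::string> vars = {{"ROOT", "/data/models"}};
  tree.ReplaceVariables(vars, /*error_on_unknown_vars=*/true);

  // tree.GetStringReq("model_path") now yields "/data/models/model.bin"
}
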
diff --git a/src/microsoft/shortlist/utils/ParameterTree.h b/src/microsoft/shortlist/utils/ParameterTree.h
new file mode 100644
index 00000000..1474ff64
--- /dev/null
+++ b/src/microsoft/shortlist/utils/ParameterTree.h
@@ -0,0 +1,185 @@
+#pragma once
+
+#include <string>
+#include <vector>
+#include <unordered_set>
+#include <unordered_map>
+#include <memory>
+
+#include "microsoft/shortlist/utils/StringUtils.h"
+
+namespace quicksand {
+
+class ParameterTree {
+private:
+ enum ParameterType {
+ PARAM_TYPE_INT32,
+ PARAM_TYPE_INT64,
+ PARAM_TYPE_UINT64,
+ PARAM_TYPE_FLOAT,
+ PARAM_TYPE_DOUBLE,
+ PARAM_TYPE_BOOL,
+ PARAM_TYPE_STRING
+ };
+
+ class RegisteredParam {
+ private:
+ std::string m_name;
+ ParameterType m_type;
+ void * m_data;
+
+ public:
+ RegisteredParam() {}
+
+ RegisteredParam(const std::string& name,
+ ParameterType type,
+ void * data)
+ {
+ m_name = name;
+ m_type = type;
+ m_data = data;
+ }
+
+ const std::string& Name() const {return m_name;}
+ const ParameterType& Type() const {return m_type;}
+ void * Data() const {return m_data;}
+ };
+
+ static std::shared_ptr<ParameterTree> m_empty_tree;
+
+ std::string m_name;
+
+ std::string m_text;
+
+ std::vector< std::shared_ptr<ParameterTree> > m_children;
+
+ std::unordered_set<std::string> m_registered_param_names;
+
+ std::vector<RegisteredParam> m_registered_params;
+
+public:
+ ParameterTree();
+
+ ParameterTree(const std::string& name);
+
+ ~ParameterTree();
+
+ inline const std::string& Text() const { return m_text; }
+ inline void SetText(const std::string& text) { m_text = text; }
+
+ inline const std::string& Name() const { return m_name; }
+ inline void SetName(const std::string& name) { m_name = name; }
+
+ void Clear();
+
+ void ReplaceVariables(
+ const std::unordered_map<std::string, std::string>& vars,
+ bool error_on_unknown_vars = true);
+
+ void RegisterInt32(const std::string& name, int32_t * param);
+
+ void RegisterInt64(const std::string& name, int64_t * param);
+
+ void RegisterFloat(const std::string& name, float * param);
+
+ void RegisterDouble(const std::string& name, double * param);
+
+ void RegisterBool(const std::string& name, bool * param);
+
+ void RegisterString(const std::string& name, std::string * param);
+
+ static std::shared_ptr<ParameterTree> FromBinaryReader(const void*& current);
+
+ void SetRegisteredParams();
+
+ int32_t GetInt32Req(const std::string& name) const;
+
+ int64_t GetInt64Req(const std::string& name) const;
+
+ uint64_t GetUInt64Req(const std::string& name) const;
+
+ double GetDoubleReq(const std::string& name) const;
+
+ float GetFloatReq(const std::string& name) const;
+
+ std::string GetStringReq(const std::string& name) const;
+
+ bool GetBoolReq(const std::string& name) const;
+
+ int32_t GetInt32Or(const std::string& name, int32_t defaultValue) const;
+
+ int64_t GetInt64Or(const std::string& name, int64_t defaultValue) const;
+
+ uint64_t GetUInt64Or(const std::string& name, uint64_t defaultValue) const;
+
+ std::string GetStringOr(const std::string& name, const std::string& defaultValue) const;
+
+ double GetDoubleOr(const std::string& name, double defaultValue) const;
+
+ float GetFloatOr(const std::string& name, float defaultValue) const;
+
+ bool GetBoolOr(const std::string& name, bool defaultValue) const;
+
+ std::vector<std::string> GetFileListReq(const std::string& name) const;
+
+ std::vector<std::string> GetFileListOptional(const std::string& name) const;
+
+ std::vector<std::string> GetStringListReq(const std::string& name, const std::string& sep = " ") const;
+
+ std::vector<std::string> GetStringListOptional(const std::string& name, const std::string& sep = " ") const;
+
+ std::shared_ptr<ParameterTree> GetChildReq(const std::string& name) const;
+
+ std::shared_ptr<ParameterTree> GetChildOrEmpty(const std::string& name) const;
+
+ std::vector< std::shared_ptr<ParameterTree> > GetChildren(const std::string& name) const;
+
+ inline const std::vector< std::shared_ptr<ParameterTree> >& GetChildren() const { return m_children; }
+
+ void ReadBinary(const void*& current);
+
+ void AddParam(const std::string& name, const std::string& text);
+
+ template <typename T>
+ void AddParam(const std::string& name, const T& obj);
+
+ void SetParam(const std::string& name, const std::string& text);
+
+ template <typename T>
+ void SetParam(const std::string& name, const T& obj);
+
+ void AddChild(std::shared_ptr<ParameterTree> child);
+
+ std::string ToString() const;
+
+ bool HasChild(const std::string& name) const;
+
+ bool HasParam(const std::string& name) const;
+
+ std::shared_ptr<ParameterTree> Clone() const;
+
+ void Merge(const ParameterTree& other);
+
+private:
+ void ReplaceVariablesInternal(
+ const std::unordered_map<std::string, std::string>& vars,
+ bool error_on_unknown_vars);
+
+ void RegisterItemInternal(const std::string& name, ParameterType type, void * param);
+
+ const std::string * GetParamInternal(const std::string& name) const;
+
+ void ToStringInternal(int32_t depth, std::ostream& ss) const;
+};
+
+template <typename T>
+void ParameterTree::AddParam(const std::string& name, const T& obj) {
+ AddParam(name, StringUtils::ToString(obj));
+}
+
+template <typename T>
+void ParameterTree::SetParam(const std::string& name, const T& obj) {
+ SetParam(name, StringUtils::ToString(obj));
+}
+
+} // namespace quicksand
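
A sketch of how the registration interface above can be used together with FromBinaryReader: selected values are bound to local variables and filled in by SetRegisteredParams (which, per the .cpp above, currently handles int32/int64 only). The parameter names and the blob argument are hypothetical.

#include "microsoft/shortlist/utils/ParameterTree.h"

void loadParamsExample(const void* blob) {
  auto tree = quicksand::ParameterTree::FromBinaryReader(blob);

  int32_t beamSize  = 0;
  int64_t vocabSize = 0;
  tree->RegisterInt32("beam_size", &beamSize);   // hypothetical names
  tree->RegisterInt64("vocab_size", &vocabSize);
  tree->SetRegisteredParams();                   // fills the registered variables

  int32_t k = tree->GetInt32Or("k", 8);          // direct lookup with a default
  (void)beamSize; (void)vocabSize; (void)k;
}
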
diff --git a/src/microsoft/shortlist/utils/PrintTypes.h b/src/microsoft/shortlist/utils/PrintTypes.h
new file mode 100644
index 00000000..6bc1363d
--- /dev/null
+++ b/src/microsoft/shortlist/utils/PrintTypes.h
@@ -0,0 +1,16 @@
+#pragma once
+
+#include <inttypes.h>
+
+#ifdef QUICKSAND_WINDOWS_BUILD
+#define PI32 "d"
+#define PI64 "lld"
+#define PU32 "u"
+#define PU64 "llu"
+#else
+#define PI32 PRId32
+#define PI64 PRId64
+#define PU32 PRIu32
+#define PU64 PRIu64
+#endif
+
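
These macros provide printf-style format specifiers for fixed-width integers on both Windows and non-Windows builds. A minimal usage sketch:

#include <cstdint>
#include <cstdio>
#include "microsoft/shortlist/utils/PrintTypes.h"

void printTypesExample() {
  int64_t n  = 1234567890123LL;
  uint32_t m = 42u;
  std::printf("n=%" PI64 " m=%" PU32 "\n", n, m);  // string literals concatenate with the macros
}
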
diff --git a/src/microsoft/shortlist/utils/StringUtils.cpp b/src/microsoft/shortlist/utils/StringUtils.cpp
new file mode 100644
index 00000000..7870b542
--- /dev/null
+++ b/src/microsoft/shortlist/utils/StringUtils.cpp
@@ -0,0 +1,338 @@
+#include "microsoft/shortlist/utils/StringUtils.h"
+
+#include <stdio.h>
+#include <algorithm>
+#include <string>
+
+namespace quicksand {
+
+#include "microsoft/shortlist/logging/LoggerMacros.h"
+
+std::string StringUtils::VarArgsToString(const char * format, va_list args) {
+ if (format == nullptr) {
+ LOG_ERROR_AND_THROW("'format' cannot be null in StringUtils::VarArgsToString");
+ }
+
+ std::string output;
+ // Most of the time the stack buffer (5000 chars) will be sufficient.
+ // In cases where this is insufficient, dynamically allocate an appropriately sized buffer
+ char buffer[5000];
+#ifdef QUICKSAND_WINDOWS_BUILD
+ va_list copy;
+ va_copy(copy, args);
+ int ret = vsnprintf_s(buffer, sizeof(buffer), _TRUNCATE, format, copy);
+ va_end(copy);
+ if (ret >= 0) {
+ output = std::string(buffer, buffer + ret);
+ }
+ else {
+ va_list copy2;
+ va_copy(copy2, args);
+ int needed_size = _vscprintf(format, copy2);
+ va_end(copy2);
+
+ if (needed_size < 0) {
+ LOG_ERROR_AND_THROW("A call to vsnprintf_s() failed. This should never happen");
+ }
+ char * dynamic_buffer = new char[needed_size+1];
+ int ret2 = vsnprintf_s(dynamic_buffer, needed_size+1, _TRUNCATE, format, args);
+ if (ret2 >= 0) {
+ output = std::string(dynamic_buffer, dynamic_buffer + ret2);
+ delete[] dynamic_buffer;
+ }
+ else {
+ output = "";
+ delete[] dynamic_buffer;
+ LOG_ERROR_AND_THROW("A call to vsnprintf_s() failed. This should never happen, "
+ "since we made a call to _vscprintf() to check the dynamic buffer size. The call to _vscprintf() "
+ "returned %d bytes, but apparently that was not enough. This would imply a bug in MSVC's vsnprintf_s implementation.", needed_size);
+ }
+ }
+#else
+ va_list copy;
+ va_copy(copy, args);
+ int needed_size = vsnprintf(buffer, sizeof(buffer), format, copy);
+ va_end(copy);
+ if (needed_size < (int)sizeof(buffer)) {
+ output = std::string(buffer, buffer + needed_size);
+ }
+ else {
+ char * dynamic_buffer = new char[needed_size+1];
+ int ret = vsnprintf(dynamic_buffer, needed_size + 1, format, args);
+ if (ret >= 0 && ret < needed_size + 1) {
+ output = std::string(dynamic_buffer);
+ delete[] dynamic_buffer;
+ }
+ else {
+ output = "";
+ delete[] dynamic_buffer;
+ LOG_ERROR_AND_THROW("A call to vsnprintf() failed. Return value: %d.",
+ ret);
+ }
+ }
+#endif
+ return output;
+}
+
+std::vector<std::string> StringUtils::SplitIntoLines(const std::string& input) {
+ std::vector<std::string> output;
+ if (input.size() == 0) {
+ return output;
+ }
+ std::size_t start = 0;
+ for (std::size_t i = 0; i < input.size(); i++) {
+ char c = input[i];
+ if (c == '\r' || c == '\n') {
+ output.push_back(std::string(input.begin() + start, input.begin() + i));
+ start = i+1;
+ }
+ if (c == '\r' && i + 1 < input.size() && input[i+1] == '\n') {
+ i++;
+ start = i+1;
+ }
+ }
+ // do NOT emit a zero-length trailing line (but zero-length intermediate lines are fine)
+ if (input.begin() + start != input.end()) {
+ output.push_back(std::string(input.begin() + start, input.end()));
+ }
+ return output;
+}
+
+bool StringUtils::StartsWith(const std::string& str, const std::string& prefix) {
+ if (str.length() < prefix.length())
+ return false;
+
+ return std::equal(prefix.begin(), prefix.end(), str.begin());
+}
+
+bool StringUtils::EndsWith(const std::string& str, const std::string& suffix) {
+ if (str.length() < suffix.length())
+ return false;
+
+ return std::equal(suffix.begin(), suffix.end(), str.end() - suffix.length());
+}
+
+std::vector<std::string> StringUtils::SplitFileList(const std::string& input) {
+ std::vector<std::string> output;
+ for (const std::string& s : SplitIntoLines(input)) {
+ for (const std::string& t : Split(s, ";")) {
+ std::string f = CleanupWhitespace(t);
+ output.push_back(f);
+ }
+ }
+ return output;
+}
+
+std::vector<std::string> StringUtils::Split(const std::string& input, char splitter) {
+ std::vector<std::string> output;
+ if (input.size() == 0) {
+ return output;
+ }
+ std::size_t start = 0;
+ for (std::size_t i = 0; i < input.size(); i++) {
+ if (input[i] == splitter) {
+ output.push_back(std::string(input.begin() + start, input.begin() + i));
+ start = i+1;
+ }
+ }
+ output.push_back(std::string(input.begin() + start, input.end()));
+ return output;
+}
+
+std::vector<std::string> StringUtils::Split(const std::string& input, const std::string& splitter) {
+ std::vector<std::string> output;
+ if (input.size() == 0) {
+ return output;
+ }
+ std::size_t pos = 0;
+ while (true) {
+ std::size_t next_pos = input.find(splitter, pos);
+ if (next_pos == std::string::npos) {
+ output.push_back(std::string(input.begin() + pos, input.end()));
+ break;
+ }
+ else {
+ output.push_back(std::string(input.begin() + pos, input.begin() + next_pos));
+ }
+ pos = next_pos + splitter.size();
+ }
+ return output;
+}
+
+std::string StringUtils::Join(const std::string& joiner, const uint8_t * items, int32_t length) {
+ std::ostringstream ss;
+ for (int32_t i = 0; i < length; i++) {
+ if (i != 0) {
+ ss << joiner;
+ }
+ ss << (int32_t)(items[i]);
+ }
+ return ss.str();
+}
+
+std::string StringUtils::Join(const std::string& joiner, const int8_t * items, int32_t length) {
+ std::ostringstream ss;
+ for (int32_t i = 0; i < length; i++) {
+ if (i != 0) {
+ ss << joiner;
+ }
+ ss << (int32_t)(items[i]);
+ }
+ return ss.str();
+}
+
+std::string StringUtils::PrintString(const char * format, ...) {
+ va_list args;
+ va_start(args, format);
+ std::string output = StringUtils::VarArgsToString(format, args);
+ va_end(args);
+
+ return output;
+}
+
+std::vector<std::string> StringUtils::WhitespaceTokenize(const std::string& input) {
+ std::vector<std::string> output;
+ if (input.size() == 0) {
+ return output;
+ }
+ std::size_t size = input.size();
+ std::size_t start = 0;
+ std::size_t end = size;
+ for (std::size_t i = 0; i < size; i++) {
+ char c = input[i];
+ if (IsWhitespace(c)) {
+ start++;
+ }
+ else {
+ break;
+ }
+ }
+ for (std::size_t i = 0; i < size; i++) {
+ char c = input[size-1-i];
+ if (IsWhitespace(c)) {
+ end--;
+ }
+ else {
+ break;
+ }
+ }
+ if (end <= start) {
+ return output;
+ }
+ bool prev_is_whitespace = false;
+ std::size_t token_start = start;
+ for (std::size_t i = start; i < end; i++) {
+ char c = input[i];
+ if (IsWhitespace(c)) {
+ if (!prev_is_whitespace) {
+ output.push_back(std::string(input.begin() + token_start, input.begin() + i));
+ }
+ prev_is_whitespace = true;
+ token_start = i+1;
+ }
+ else {
+ prev_is_whitespace = false;
+ }
+ }
+ output.push_back(std::string(input.begin() + token_start, input.begin() + end));
+ return output;
+}
+
+std::string StringUtils::CleanupWhitespace(const std::string& input) {
+ if (input.size() == 0) {
+ return std::string("");
+ }
+ std::size_t size = input.size();
+ std::size_t start = 0;
+ std::size_t end = size;
+ for (std::size_t i = 0; i < size; i++) {
+ char c = input[i];
+ if (IsWhitespace(c)) {
+ start++;
+ }
+ else {
+ break;
+ }
+ }
+ for (std::size_t i = 0; i < size; i++) {
+ char c = input[size-1-i];
+ if (IsWhitespace(c)) {
+ end--;
+ }
+ else {
+ break;
+ }
+ }
+ if (end <= start) {
+ return std::string("");
+ }
+ std::ostringstream ss;
+ bool prev_is_whitespace = false;
+ for (std::size_t i = start; i < end; i++) {
+ char c = input[i];
+ if (IsWhitespace(c)) {
+ if (!prev_is_whitespace) {
+ ss << ' ';
+ }
+ prev_is_whitespace = true;
+ }
+ else {
+ ss << c;
+ prev_is_whitespace = false;
+ }
+ }
+ return ss.str();
+}
+
+std::string StringUtils::XmlEscape(const std::string& str) {
+ std::ostringstream ss;
+ for (std::size_t i = 0; i < str.size(); i++) {
+ char c = str[i];
+ if (c == '&') {
+ ss << "&amp;";
+ }
+ else if (c == '"') {
+ ss << "&quot;";
+ }
+ else if (c == '\'') {
+ ss << "&apos;";
+ }
+ else if (c == '<') {
+ ss << "&lt;";
+ }
+ else if (c == '>') {
+ ss << "&gt;";
+ }
+ else {
+ ss << c;
+ }
+ }
+ return ss.str();
+}
+
+std::string StringUtils::ToString(const std::string& str) {
+ return str;
+}
+
+std::string StringUtils::ToString(bool obj) {
+ return (obj)?"true":"false";
+}
+
+std::string StringUtils::ToUpper(const std::string& str) {
+ std::vector<char> output;
+ output.reserve(str.size());
+ for (char c : str) {
+ output.push_back((char)toupper((int)c));
+ }
+ return std::string(output.begin(), output.end());
+}
+
+std::string StringUtils::ToLower(const std::string& str) {
+    std::vector<char> output;
+    output.reserve(str.size());
+    for (char c : str) {
+        output.push_back((char)tolower((int)c));
+    }
+    return std::string(output.begin(), output.end());
+}
+
+} // namespace quicksand
diff --git a/src/microsoft/shortlist/utils/StringUtils.h b/src/microsoft/shortlist/utils/StringUtils.h
new file mode 100644
index 00000000..31bb1fcc
--- /dev/null
+++ b/src/microsoft/shortlist/utils/StringUtils.h
@@ -0,0 +1,98 @@
+#pragma once
+
+#include <string>
+#include <sstream>
+#include <stdarg.h>
+#include <vector>
+#include <stdint.h>
+
+#include "microsoft/shortlist/utils/PrintTypes.h"
+
+namespace quicksand {
+
+class StringUtils {
+public:
+ template <typename T>
+ static std::string Join(const std::string& joiner, const T& items);
+
+ template <typename T>
+ static std::string Join(const std::string& joiner, const T * items, int32_t length);
+
+ static std::string Join(const std::string& joiner, const uint8_t * items, int32_t length);
+
+ static std::string Join(const std::string& joiner, const int8_t * items, int32_t length);
+
+ static std::vector<std::string> Split(const std::string& input, char splitter);
+
+ static std::vector<std::string> Split(const std::string& input, const std::string& splitter);
+
+ static std::vector<std::string> SplitFileList(const std::string& input);
+
+ static std::string PrintString(const char * format, ...);
+
+ static std::string VarArgsToString(const char * format, va_list args);
+
+ static std::vector<std::string> WhitespaceTokenize(const std::string& input);
+
+ static std::string CleanupWhitespace(const std::string& input);
+
+ static std::string ToString(const std::string& str);
+
+ static std::string ToString(bool obj);
+
+ template <typename T>
+ static std::string ToString(const T& obj);
+
+ static std::string XmlEscape(const std::string& str);
+
+ static std::vector<std::string> SplitIntoLines(const std::string& input);
+
+ static bool StartsWith(const std::string& str, const std::string& prefix);
+
+ static bool EndsWith(const std::string& str, const std::string& suffix);
+
+ inline static bool IsWhitespace(char c) {
+ return (c == ' ' || c == '\t' || c == '\n' || c == '\r');
+ }
+
+ // This should only be used for ASCII, e.g., filenames, NOT for language data
+ static std::string ToLower(const std::string& str);
+
+ // This should only be used for ASCII, e.g., filenames, NOT for language data
+ static std::string ToUpper(const std::string& str);
+};
+
+template <typename T>
+std::string StringUtils::Join(const std::string& joiner, const T& items) {
+ std::ostringstream ss;
+ bool first = true;
+ for (auto it = items.begin(); it != items.end(); it++) {
+ if (!first) {
+ ss << joiner;
+ }
+ ss << (*it);
+ first = false;
+ }
+ return ss.str();
+}
+
+template <typename T>
+std::string StringUtils::Join(const std::string& joiner, const T * items, int32_t length) {
+ std::ostringstream ss;
+ for (int32_t i = 0; i < length; i++) {
+ if (i != 0) {
+ ss << joiner;
+ }
+ ss << items[i];
+ }
+ return ss.str();
+}
+
+template <typename T>
+std::string StringUtils::ToString(const T& obj) {
+ std::ostringstream ss;
+ ss << obj;
+ return ss.str();
+}
+
+} // namespace quicksand
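The inline Join and ToString templates above only require that the element type can be streamed into a std::ostringstream, so they work for arbitrary containers, raw arrays and scalars. A small sketch under the same assumptions as before (illustrative main(), arbitrary values):

#include <iostream>
#include <vector>

#include "microsoft/shortlist/utils/StringUtils.h"

int main() {
  using quicksand::StringUtils;
  std::vector<int> ids = {7, 13, 42};
  // Container overload: elements separated by the joiner.
  std::cout << StringUtils::Join(", ", ids) << std::endl;       // 7, 13, 42
  // Pointer + length overload for raw arrays.
  const float scores[] = {0.5f, 0.25f};
  std::cout << StringUtils::Join("|", scores, 2) << std::endl;  // 0.5|0.25
  // Generic ToString falls back to the type's operator<<.
  std::cout << StringUtils::ToString(3.14) << std::endl;        // 3.14
  return 0;
}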
diff --git a/src/models/amun.h b/src/models/amun.h
index 1bfda269..1bfda269 100755..100644
--- a/src/models/amun.h
+++ b/src/models/amun.h
diff --git a/src/models/bert.h b/src/models/bert.h
index 51427457..51427457 100755..100644
--- a/src/models/bert.h
+++ b/src/models/bert.h
diff --git a/src/models/char_s2s.h b/src/models/char_s2s.h
index 3b9bb2fa..3b9bb2fa 100755..100644
--- a/src/models/char_s2s.h
+++ b/src/models/char_s2s.h
diff --git a/src/models/classifier.h b/src/models/classifier.h
index 9faa907e..9faa907e 100755..100644
--- a/src/models/classifier.h
+++ b/src/models/classifier.h
diff --git a/src/models/costs.cpp b/src/models/costs.cpp
new file mode 100644
index 00000000..c688b211
--- /dev/null
+++ b/src/models/costs.cpp
@@ -0,0 +1,14 @@
+#include "costs.h"
+
+namespace marian {
+namespace models {
+
+Ptr<DecoderState> LogSoftmaxStep::apply(Ptr<DecoderState> state) {
+ // decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost)
+ state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax));
+ // @TODO: This is becoming more and more opaque ^^. Can we simplify this?
+ return state;
+}
+
+} // namespace models
+} // namespace marian
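LogSoftmaxStep::apply, now defined here instead of inline in costs.h, normalizes all factor groups of the Logits with the same unary function; the GumbelSoftmaxStep override in costs.h below uses the two-callback applyUnaryFunctions variant so that only the lemma group is noised. A hypothetical step in the same style, shown only as a sketch of the pattern and assuming applyUnaryFunction accepts any Expr-to-Expr callable, as the usage here suggests (TemperatureStep and its constant are not part of this change):

#include "models/costs.h"

namespace marian {
namespace models {

// Illustrative only: rescale logits by an inverse temperature before normalizing.
class TemperatureStep : public ILogProbStep {
public:
  virtual ~TemperatureStep() {}
  virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override {
    const float invTemp = 1.f / 0.8f;  // assumed constant, not an existing marian option
    state->setLogProbs(state->getLogProbs().applyUnaryFunction(
        [invTemp](Expr logits) { return logsoftmax(invTemp * logits); }));
    return state;
  }
};

}  // namespace models
}  // namespace marian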
diff --git a/src/models/costs.h b/src/models/costs.h
index 3d8f2c51..e5463bfd 100755..100644
--- a/src/models/costs.h
+++ b/src/models/costs.h
@@ -4,8 +4,8 @@
#include "layers/guided_alignment.h"
#include "layers/loss.h"
#include "layers/weight.h"
-#include "models/encoder_decoder.h"
#include "models/encoder_classifier.h"
+#include "models/encoder_decoder.h"
#include "models/encoder_pooler.h"
namespace marian {
@@ -22,10 +22,12 @@ namespace models {
class ICost {
public:
- virtual Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
- Ptr<ExpressionGraph> graph, // @TODO: why needed? Can it be gotten from model?
- Ptr<data::Batch> batch,
- bool clearGraph = true) = 0;
+ virtual Ptr<MultiRationalLoss> apply(
+ Ptr<IModel> model,
+ Ptr<ExpressionGraph> graph, // @TODO: why needed? Can it be gotten from model?
+ Ptr<data::Batch> batch,
+ bool clearGraph = true)
+ = 0;
virtual ~ICost() {}
};
@@ -45,10 +47,9 @@ public:
: options_(options), inference_(options->get<bool>("inference", false)) {
loss_ = newLoss(options_, inference_);
- toBeWeighted_
- = (options_->hasAndNotEmpty("data-weighting") && !inference_)
- || (options_->has("dynamic-weighting") && options_->get<bool>("dynamic-weighting")
- && !inference_);
+ toBeWeighted_ = (options_->hasAndNotEmpty("data-weighting") && !inference_)
+ || (options_->has("dynamic-weighting")
+ && options_->get<bool>("dynamic-weighting") && !inference_);
if(toBeWeighted_)
weighter_ = WeightingFactory(options_);
}
@@ -56,9 +57,9 @@ public:
virtual ~EncoderDecoderCECost() {}
Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
- Ptr<ExpressionGraph> graph,
- Ptr<data::Batch> batch,
- bool clearGraph = true) override {
+ Ptr<ExpressionGraph> graph,
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) override {
auto encdec = std::static_pointer_cast<EncoderDecoder>(model);
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
@@ -72,17 +73,17 @@ public:
Ptr<MultiRationalLoss> multiLoss = newMultiLoss(options_);
// @TODO: adapt to multi-objective training with multiple decoders
- auto partialLoss = loss_->apply(state->getLogProbs(),
- state->getTargetWords(),
- state->getTargetMask(),
- weights);
+ auto partialLoss = loss_->apply(
+ state->getLogProbs(), state->getTargetWords(), state->getTargetMask(), weights);
multiLoss->push_back(partialLoss);
if(options_->get("guided-alignment", std::string("none")) != "none" && !inference_) {
- auto attentionVectors = encdec->getDecoders()[0]->getAlignments(); // [tgt index][beam depth, max src length, batch size, 1]
+ auto attentionVectors
+ = encdec->getDecoders()[0]
+ ->getAlignments(); // [tgt index][beam depth, max src length, batch size, 1]
ABORT_IF(attentionVectors.empty(), "Model does not seem to support alignments");
- auto attention = concatenate(attentionVectors, /*axis =*/ -1);
+ auto attention = concatenate(attentionVectors, /*axis =*/-1);
auto alignmentLoss = guidedAlignmentCost(graph, corpusBatch, options_, attention);
multiLoss->push_back(alignmentLoss);
@@ -109,10 +110,9 @@ public:
}
Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
- Ptr<ExpressionGraph> graph,
- Ptr<data::Batch> batch,
- bool clearGraph = true) override {
-
+ Ptr<ExpressionGraph> graph,
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) override {
auto enccls = std::static_pointer_cast<EncoderClassifier>(model);
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
@@ -141,21 +141,20 @@ protected:
public:
EncoderPoolerRankCost(Ptr<Options> options)
- : options_(options),
- inference_(options->get<bool>("inference", false)) {
- auto trainEmbedderRank = options->get<std::vector<std::string>>("train-embedder-rank", {});
- ABORT_IF(trainEmbedderRank.empty(), "EncoderPoolerRankCost expects train-embedder-rank to be set");
-
- margin_ = std::stof(trainEmbedderRank[0]);
- if(trainEmbedderRank.size() > 1)
- normalizer_ = std::stof(trainEmbedderRank[1]);
+ : options_(options), inference_(options->get<bool>("inference", false)) {
+ auto trainEmbedderRank = options->get<std::vector<std::string>>("train-embedder-rank", {});
+ ABORT_IF(trainEmbedderRank.empty(),
+ "EncoderPoolerRankCost expects train-embedder-rank to be set");
+
+ margin_ = std::stof(trainEmbedderRank[0]);
+ if(trainEmbedderRank.size() > 1)
+ normalizer_ = std::stof(trainEmbedderRank[1]);
}
Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
-
auto encpool = std::static_pointer_cast<EncoderPooler>(model);
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
std::vector<Expr> dotProducts = encpool->apply(graph, corpusBatch, clearGraph);
@@ -167,28 +166,41 @@ public:
ABORT_IF(dotProducts.size() != 3, "Three dot products required for margin loss");
// multi-objective training
- auto maxDot = max(concatenate(dotProducts, -1), -1); // compute maximum for numeric stability
- auto exponent = dotProducts[0] - maxDot - margin_; // substract maximum and margin from dot product
+ auto maxDot = max(concatenate(dotProducts, -1), -1); // compute maximum for numeric stability
+ auto exponent
+ = dotProducts[0] - maxDot - margin_; // subtract maximum and margin from dot product
auto dp = exp(exponent);
Expr dn1, dn2;
- if(normalizer_ != 0.0f) { // the normalizer may be useful for fluctuating batch sizes since it limits the magnitude of the sum of negative examples in the denominator.
- dn1 = normalizer_ * mean(exp(dotProducts[1] - maxDot), -1); // dot product of anchor and first negative example
- dn2 = normalizer_ * mean(exp(dotProducts[2] - maxDot), -1); // dot product of positive examples and first negative example
+ if(normalizer_
+ != 0.0f) { // the normalizer may be useful for fluctuating batch sizes since it limits the
+ // magnitude of the sum of negative examples in the denominator.
+ dn1 = normalizer_
+ * mean(exp(dotProducts[1] - maxDot),
+ -1); // dot product of anchor and first negative example
+ dn2 = normalizer_
+ * mean(exp(dotProducts[2] - maxDot),
+ -1); // dot product of positive examples and first negative example
} else {
- dn1 = sum(exp(dotProducts[1] - maxDot), -1); // dot product of anchor and first negative example
- dn2 = sum(exp(dotProducts[2] - maxDot), -1); // dot product of positive examples and first negative example
+ dn1 = sum(exp(dotProducts[1] - maxDot),
+ -1); // dot product of anchor and first negative example
+ dn2 = sum(exp(dotProducts[2] - maxDot),
+ -1); // dot product of positive examples and first negative example
}
// We rewrite the loss so it looks more like a log-softmax, presumably more stable?
- // Let dp = exp(phi - m) then -log(dp / (dp + sum(dn))) = -log(dp) + log(dp + sum(dn)) = log(dp + sum(dn)) - log(dp) = log(dp + sum(dn)) - (phi - m)
- auto marginLoss1 = log(dp + dn1) - exponent; // softmax-margin loss for anchor vs negative examples
- auto marginLoss2 = log(dp + dn2) - exponent; // symmetric version of the above with positive example vs negative examples
- auto marginLoss = sum(marginLoss1 + marginLoss2, /*axis=*/-2);
-
+ // Let dp = exp(phi - m) then -log(dp / (dp + sum(dn))) = -log(dp) + log(dp + sum(dn)) = log(dp
+ // + sum(dn)) - log(dp) = log(dp + sum(dn)) - (phi - m)
+ auto marginLoss1
+ = log(dp + dn1) - exponent; // softmax-margin loss for anchor vs negative examples
+ auto marginLoss2
+ = log(dp + dn2)
+ - exponent; // symmetric version of the above with positive example vs negative examples
+ auto marginLoss = sum(marginLoss1 + marginLoss2, /*axis=*/-2);
+
RationalLoss loss(marginLoss, (float)dimBatch);
multiLoss->push_back(loss);
-
+
return multiLoss;
}
};
@@ -199,8 +211,7 @@ protected:
Ptr<ICost> cost_;
public:
- Trainer(Ptr<IModel> model, Ptr<ICost> cost)
- : model_(model), cost_(cost) {}
+ Trainer(Ptr<IModel> model, Ptr<ICost> cost) : model_(model), cost_(cost) {}
virtual ~Trainer() {}
@@ -219,8 +230,8 @@ public:
}
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
- Ptr<data::Batch> batch,
- bool clearGraph = true) override {
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) override {
return cost_->apply(model_, graph, batch, clearGraph);
};
@@ -230,24 +241,25 @@ public:
class ILogProb {
public:
virtual Logits apply(Ptr<IModel> model,
- Ptr<ExpressionGraph> graph,
- Ptr<data::Batch> batch,
- bool clearGraph = true) = 0;
+ Ptr<ExpressionGraph> graph,
+ Ptr<data::Batch> batch,
+ bool clearGraph = true)
+ = 0;
};
-// @TODO: Name 'scorer' is ambiguous: Does it compute scores for all classes, or the loss value for the ground truth?
-// Beam search uses it for the former meaning, while 'marian score' and validation in the latter.
-// This class is for the former use. The latter is done using Trainer.
+// @TODO: Name 'scorer' is ambiguous: Does it compute scores for all classes, or the loss value for
+// the ground truth?
+// Beam search uses it for the former meaning, while 'marian score' and validation in the
+// latter. This class is for the former use. The latter is done using Trainer.
class Scorer : public IModel {
protected:
Ptr<IModel> model_;
Ptr<ILogProb> logProb_;
public:
- Scorer(Ptr<IModel> model, Ptr<ILogProb> cost)
- : model_(model), logProb_(cost) {}
+ Scorer(Ptr<IModel> model, Ptr<ILogProb> cost) : model_(model), logProb_(cost) {}
- virtual ~Scorer(){}
+ virtual ~Scorer() {}
Ptr<IModel> getModel() { return model_; }
@@ -264,8 +276,8 @@ public:
}
virtual Logits build(Ptr<ExpressionGraph> graph,
- Ptr<data::Batch> batch,
- bool clearGraph = true) override {
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) override {
return logProb_->apply(model_, graph, batch, clearGraph);
};
@@ -282,12 +294,7 @@ public:
class LogSoftmaxStep : public ILogProbStep {
public:
virtual ~LogSoftmaxStep() {}
- virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override {
- // decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost)
- state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax));
- // @TODO: This is becoming more and more opaque ^^. Can we simplify this?
- return state;
- }
+ virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override;
};
// Gumbel-max noising for sampling during beam-search
@@ -298,10 +305,10 @@ public:
virtual ~GumbelSoftmaxStep() {}
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override {
state->setLogProbs(state->getLogProbs().applyUnaryFunctions(
- [](Expr logits){ // lemma gets gumbelled
- return logsoftmax(logits + constant_like(logits, inits::gumbel()));
- },
- logsoftmax)); // factors don't
+ [](Expr logits) { // lemma gets gumbelled
+ return logsoftmax(logits + constant_like(logits, inits::gumbel()));
+ },
+ logsoftmax)); // factors don't
return state;
}
};
@@ -316,8 +323,7 @@ protected:
Ptr<ILogProbStep> cost_;
public:
- Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost)
- : encdec_(encdec), cost_(cost) {}
+ Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost) : encdec_(encdec), cost_(cost) {}
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
@@ -351,12 +357,13 @@ public:
return encdec_->startState(graph, batch);
}
- virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
- Ptr<DecoderState> state,
- const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
- const Words& words, // [beamIndex * activeBatchSize + batchIndex]
- const std::vector<IndexType>& batchIndices, // [batchIndex]
- int beamSize) override {
+ virtual Ptr<DecoderState> step(
+ Ptr<ExpressionGraph> graph,
+ Ptr<DecoderState> state,
+ const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
+ const Words& words, // [beamIndex * activeBatchSize + batchIndex]
+ const std::vector<IndexType>& batchIndices, // [batchIndex]
+ int beamSize) override {
auto nextState = encdec_->step(graph, state, hypIndices, words, batchIndices, beamSize);
return cost_->apply(nextState);
}
@@ -374,9 +381,7 @@ public:
encdec_->setShortlistGenerator(shortlistGenerator);
};
- virtual Ptr<data::Shortlist> getShortlist() override {
- return encdec_->getShortlist();
- };
+ virtual Ptr<data::Shortlist> getShortlist() override { return encdec_->getShortlist(); };
virtual data::SoftAlignment getAlignment() override { return encdec_->getAlignment(); }
};
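The margin-loss rewrite inside EncoderPoolerRankCost::apply is easier to follow as a worked equation. Writing \phi for the anchor/positive dot product, m for margin_, M for the per-row maximum maxDot, and d_n for the aggregated negative terms (sum, or normalizer_ times mean), the derivation in the comment reads:

\begin{aligned}
d_p &= \exp(\phi - M - m), \\
-\log\frac{d_p}{d_p + d_n} &= -\log d_p + \log(d_p + d_n) \\
                           &= \log(d_p + d_n) - (\phi - M - m).
\end{aligned}

The common factor \exp(-M) cancels in the ratio, so subtracting maxDot leaves the loss unchanged mathematically while keeping the exponentials in a safe range. This is exactly marginLoss1 = log(dp + dn1) - exponent with exponent = dotProducts[0] - maxDot - margin_; marginLoss2 is the symmetric term with the second set of negatives, and the two are summed over axis -2 to form the batch-level loss.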
diff --git a/src/models/encoder_decoder.cpp b/src/models/encoder_decoder.cpp
index 8fc9321a..8fc9321a 100755..100644
--- a/src/models/encoder_decoder.cpp
+++ b/src/models/encoder_decoder.cpp
diff --git a/src/models/encoder_decoder.h b/src/models/encoder_decoder.h
index 92c1647f..92c1647f 100755..100644
--- a/src/models/encoder_decoder.h
+++ b/src/models/encoder_decoder.h
diff --git a/src/models/model_factory.cpp b/src/models/model_factory.cpp
index e176e6a4..e176e6a4 100755..100644
--- a/src/models/model_factory.cpp
+++ b/src/models/model_factory.cpp
diff --git a/src/models/model_factory.h b/src/models/model_factory.h
index 5403b966..5403b966 100755..100644
--- a/src/models/model_factory.h
+++ b/src/models/model_factory.h
diff --git a/src/models/nematus.h b/src/models/nematus.h
index 730418e5..730418e5 100755..100644
--- a/src/models/nematus.h
+++ b/src/models/nematus.h
diff --git a/src/models/s2s.h b/src/models/s2s.h
index 7009fad5..7009fad5 100755..100644
--- a/src/models/s2s.h
+++ b/src/models/s2s.h
diff --git a/src/models/states.h b/src/models/states.h
index c2f9ee05..20dd59c9 100755..100644
--- a/src/models/states.h
+++ b/src/models/states.h
@@ -1,7 +1,7 @@
#pragma once
+#include "layers/logits.h" // @HACK: for factored embeddings only so far
#include "marian.h"
-#include "layers/generic.h" // @HACK: for factored embeddings only so far
#include "rnn/types.h"
namespace marian {
@@ -9,7 +9,7 @@ namespace marian {
class EncoderState {
private:
Expr context_;
- Expr mask_; // [beam depth=1, max length, batch size, vector dim=1] source mask
+ Expr mask_; // [beam depth=1, max length, batch size, vector dim=1] source mask
Ptr<data::CorpusBatch> batch_;
public:
@@ -19,31 +19,34 @@ public:
EncoderState() {}
virtual ~EncoderState() {}
- virtual Expr getContext() const { return context_; }
- virtual Expr getAttended() const { return context_; }
- virtual Expr getMask() const { return mask_; } // source batch mask; may have additional positions suppressed
+ virtual Expr getContext() const { return context_; }
+ virtual Expr getAttended() const { return context_; }
+ virtual Expr getMask() const {
+ return mask_;
+ } // source batch mask; may have additional positions suppressed
- virtual const Words& getSourceWords() {
- return batch_->front()->data();
- }
+ virtual const Words& getSourceWords() { return batch_->front()->data(); }
// Sub-select active batch entries from encoder context and context mask
- Ptr<EncoderState> select(const std::vector<IndexType>& batchIndices) { // [batchIndex] indices of active batch entries
- // Dimension -2 is OK for both, RNN and Transformer models as the encoder context in Transformer gets transposed to the same dimension layout
- return New<EncoderState>(index_select(context_, -2, batchIndices), index_select(mask_, -2, batchIndices), batch_);
+ Ptr<EncoderState> select(
+ const std::vector<IndexType>& batchIndices) { // [batchIndex] indices of active batch entries
+ // Dimension -2 is OK for both, RNN and Transformer models as the encoder context in Transformer
+ // gets transposed to the same dimension layout
+ return New<EncoderState>(
+ index_select(context_, -2, batchIndices), index_select(mask_, -2, batchIndices), batch_);
}
};
class DecoderState {
protected:
- rnn::States states_; // states of individual decoder layers
+ rnn::States states_; // states of individual decoder layers
Logits logProbs_;
std::vector<Ptr<EncoderState>> encStates_;
Ptr<data::CorpusBatch> batch_;
- Expr targetHistoryEmbeddings_; // decoder history (teacher-forced or from decoding), embedded
+ Expr targetHistoryEmbeddings_; // decoder history (teacher-forced or from decoding), embedded
Expr targetMask_;
- Words targetWords_; // target labels
+ Words targetWords_; // target labels
// Keep track of current target token position during translation
size_t position_{0};
@@ -57,26 +60,30 @@ public:
virtual ~DecoderState() {}
// @TODO: Do we need all these to be virtual?
- virtual const std::vector<Ptr<EncoderState>>& getEncoderStates() const {
- return encStates_;
- }
+ virtual const std::vector<Ptr<EncoderState>>& getEncoderStates() const { return encStates_; }
virtual Logits getLogProbs() const { return logProbs_; }
virtual void setLogProbs(Logits logProbs) { logProbs_ = logProbs; }
- // @TODO: should this be a constructor? Then derived classes can call this without the New<> in the loop
- virtual Ptr<DecoderState> select(const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
- const std::vector<IndexType>& batchIndices, // [batchIndex]
- int beamSize) const {
-
+ // @TODO: should this be a constructor? Then derived classes can call this without the New<> in
+ // the loop
+ virtual Ptr<DecoderState> select(
+ const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
+ const std::vector<IndexType>& batchIndices, // [batchIndex]
+ int beamSize) const {
std::vector<Ptr<EncoderState>> newEncStates;
for(auto& es : encStates_)
- // If the size of the batch dimension of the encoder state context changed, subselect the correct batch entries
- newEncStates.push_back(es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));
+ // If the size of the batch dimension of the encoder state context changed, subselect the
+ // correct batch entries
+ newEncStates.push_back(
+ es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));
// hypIndices matches batchIndices in terms of batch dimension, so we only need hypIndices
- auto selectedState = New<DecoderState>(
- states_.select(hypIndices, beamSize, /*isBatchMajor=*/false), logProbs_, newEncStates, batch_);
+ auto selectedState
+ = New<DecoderState>(states_.select(hypIndices, beamSize, /*isBatchMajor=*/false),
+ logProbs_,
+ newEncStates,
+ batch_);
// Set position of new state based on the target token position of current state
selectedState->setPosition(getPosition());
@@ -86,7 +93,9 @@ public:
virtual const rnn::States& getStates() const { return states_; }
virtual Expr getTargetHistoryEmbeddings() const { return targetHistoryEmbeddings_; };
- virtual void setTargetHistoryEmbeddings(Expr targetHistoryEmbeddings) { targetHistoryEmbeddings_ = targetHistoryEmbeddings; }
+ virtual void setTargetHistoryEmbeddings(Expr targetHistoryEmbeddings) {
+ targetHistoryEmbeddings_ = targetHistoryEmbeddings;
+ }
virtual const Words& getTargetWords() const { return targetWords_; };
virtual void setTargetWords(const Words& targetWords) { targetWords_ = targetWords; }
@@ -94,9 +103,7 @@ public:
virtual Expr getTargetMask() const { return targetMask_; };
virtual void setTargetMask(Expr targetMask) { targetMask_ = targetMask; }
- virtual const Words& getSourceWords() const {
- return getEncoderStates()[0]->getSourceWords();
- }
+ virtual const Words& getSourceWords() const { return getEncoderStates()[0]->getSourceWords(); }
Ptr<data::CorpusBatch> getBatch() const { return batch_; }
@@ -111,7 +118,8 @@ public:
/**
* Classifier output based on DecoderState
- * @TODO: should be unified with DecoderState or not be used at all as Classifier do not really have stateful output.
+ * @TODO: should be unified with DecoderState or not be used at all as Classifiers do not really have
+ * stateful output.
*/
class ClassifierState {
private:
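EncoderState::select above sub-selects the active batch entries along dimension -2 of both the encoder context and the source mask, which is the batch axis in both the RNN and the Transformer layout, as its comment notes. A fragment-level sketch of the pattern (the concrete indices are illustrative, not taken from this change):

// Dimension -2 of context_ and mask_ is the batch axis for both RNN and Transformer encoders.
// Keep only batch entries 0 and 2 of the original batch:
std::vector<IndexType> batchIndices = {0, 2};
auto selected = New<EncoderState>(index_select(context_, /*axis=*/-2, batchIndices),
                                  index_select(mask_, /*axis=*/-2, batchIndices),
                                  batch_);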
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 6368cc6a..6368cc6a 100755..100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
diff --git a/src/models/transformer_factory.h b/src/models/transformer_factory.h
index b282d819..b282d819 100755..100644
--- a/src/models/transformer_factory.h
+++ b/src/models/transformer_factory.h
diff --git a/src/models/transformer_stub.cpp b/src/models/transformer_stub.cpp
index 871ee009..871ee009 100755..100644
--- a/src/models/transformer_stub.cpp
+++ b/src/models/transformer_stub.cpp
diff --git a/src/optimizers/exponential_smoothing.cpp b/src/optimizers/exponential_smoothing.cpp
index 1120e7e4..1120e7e4 100755..100644
--- a/src/optimizers/exponential_smoothing.cpp
+++ b/src/optimizers/exponential_smoothing.cpp
diff --git a/src/optimizers/exponential_smoothing.h b/src/optimizers/exponential_smoothing.h
index 5ef12ca1..5ef12ca1 100755..100644
--- a/src/optimizers/exponential_smoothing.h
+++ b/src/optimizers/exponential_smoothing.h
diff --git a/src/rnn/attention.h b/src/rnn/attention.h
index 6b30cb55..6b30cb55 100755..100644
--- a/src/rnn/attention.h
+++ b/src/rnn/attention.h
diff --git a/src/rnn/cells.h b/src/rnn/cells.h
index cddfd26e..cddfd26e 100755..100644
--- a/src/rnn/cells.h
+++ b/src/rnn/cells.h
diff --git a/src/rnn/constructors.h b/src/rnn/constructors.h
index beb1fce1..beb1fce1 100755..100644
--- a/src/rnn/constructors.h
+++ b/src/rnn/constructors.h
diff --git a/src/tensors/rand.cpp b/src/tensors/rand.cpp
index e6dbc46e..e6dbc46e 100755..100644
--- a/src/tensors/rand.cpp
+++ b/src/tensors/rand.cpp
diff --git a/src/tensors/tensor.cpp b/src/tensors/tensor.cpp
index 02de17bc..02de17bc 100755..100644
--- a/src/tensors/tensor.cpp
+++ b/src/tensors/tensor.cpp
diff --git a/src/tensors/tensor.h b/src/tensors/tensor.h
index 10c3e7f1..10c3e7f1 100755..100644
--- a/src/tensors/tensor.h
+++ b/src/tensors/tensor.h
diff --git a/src/training/graph_group_sync.cpp b/src/training/graph_group_sync.cpp
index 8c06761e..8c06761e 100755..100644
--- a/src/training/graph_group_sync.cpp
+++ b/src/training/graph_group_sync.cpp
diff --git a/src/training/graph_group_sync.h b/src/training/graph_group_sync.h
index df7865a7..df7865a7 100755..100644
--- a/src/training/graph_group_sync.h
+++ b/src/training/graph_group_sync.h
diff --git a/src/training/scheduler.h b/src/training/scheduler.h
index 9d2500f9..9d2500f9 100755..100644
--- a/src/training/scheduler.h
+++ b/src/training/scheduler.h
diff --git a/src/training/validator.h b/src/training/validator.h
index d6e64d69..d6e64d69 100755..100644
--- a/src/training/validator.h
+++ b/src/training/validator.h
diff --git a/src/translator/beam_search.cpp b/src/translator/beam_search.cpp
index 5c1989a6..5c1989a6 100755..100644
--- a/src/translator/beam_search.cpp
+++ b/src/translator/beam_search.cpp
diff --git a/src/translator/output_printer.h b/src/translator/output_printer.h
index 603eedba..603eedba 100755..100644
--- a/src/translator/output_printer.h
+++ b/src/translator/output_printer.h
diff --git a/src/translator/scorers.h b/src/translator/scorers.h
index a5a0be2c..a5a0be2c 100755..100644
--- a/src/translator/scorers.h
+++ b/src/translator/scorers.h
diff --git a/src/translator/translator.h b/src/translator/translator.h
index 1ff19a4a..82d9343d 100755..100644
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -60,8 +60,7 @@ public:
auto srcVocab = corpus_->getVocabs()[0];
if(options_->hasAndNotEmpty("shortlist"))
- shortlistGenerator_ = New<data::LexicalShortlistGenerator>(
- options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back());
+ shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back());
auto devices = Config::getDevices(options_);
numDevices_ = devices.size();