diff options
author | Marcin Junczys-Dowmunt <marcinjd@microsoft.com> | 2021-07-03 22:13:26 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <marcinjd@microsoft.com> | 2021-07-03 22:13:26 +0300 |
commit | 9772aa293f574aef5fb1a2756ae28ef7428b3dde (patch) | |
tree | 8b17ec0430a469e54bd1bdd867ff0a527f9643a5 /src | |
parent | 8bfa6a44e307a58b80784040d741058855f5b519 (diff) |
remaining comments
Diffstat (limited to 'src')
-rw-r--r-- | src/data/factored_vocab.cpp | 2 | ||||
-rw-r--r-- | src/data/shortlist.cpp | 8 | ||||
-rw-r--r-- | src/data/shortlist.h | 18 | ||||
-rw-r--r-- | src/graph/expression_operators.h | 4 |
4 files changed, 18 insertions, 14 deletions
diff --git a/src/data/factored_vocab.cpp b/src/data/factored_vocab.cpp index 4c5207dd..cc715993 100644 --- a/src/data/factored_vocab.cpp +++ b/src/data/factored_vocab.cpp @@ -275,7 +275,7 @@ void FactoredVocab::constructGroupInfoFromFactorVocab() { groupCounts[g]++; } - // required by LSH shortlist + // required by LSH shortlist. Factored segmenter encodes the number of lemmas in the first factor group, this corresponds to actual surface forms lemmaSize_ = groupCounts[0]; for (size_t g = 0; g < numGroups; g++) { // detect non-overlapping groups diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp index 9f4a4ebd..f7e229ff 100644 --- a/src/data/shortlist.cpp +++ b/src/data/shortlist.cpp @@ -19,8 +19,8 @@ const T* get(const void*& current, size_t num = 1) { ////////////////////////////////////////////////////////////////////////////////////// Shortlist::Shortlist(const std::vector<WordIndex>& indices) - : indices_(indices) - , done_(false) {} + : indices_(indices), + initialized_(false) {} Shortlist::~Shortlist() {} @@ -35,7 +35,7 @@ WordIndex Shortlist::tryForwardMap(WordIndex wIdx) const { } void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) { - if (done_) { + if (initialized_) { return; } @@ -49,7 +49,7 @@ void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Exp //std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl; createCachedTensors(weights, isLegacyUntransposedW, b, lemmaEt, k); - done_ = true; + initialized_ = true; } Expr Shortlist::getIndicesExpr() const { diff --git a/src/data/shortlist.h b/src/data/shortlist.h index 1ce8fbf4..a75d2c4b 100644 --- a/src/data/shortlist.h +++ b/src/data/shortlist.h @@ -29,13 +29,13 @@ protected: Expr cachedShortWt_; // short-listed version, cached (cleared by clear()) Expr cachedShortb_; // these match the current value of shortlist_ Expr cachedShortLemmaEt_; - bool done_; // used by batch-level shortlist. Only initialize with 1st call then skip all subsequent calls for same batch + bool initialized_; // used by batch-level shortlist. Only initialize with 1st call then skip all subsequent calls for same batch void createCachedTensors(Expr weights, - bool isLegacyUntransposedW, - Expr b, - Expr lemmaEt, - int k); + bool isLegacyUntransposedW, + Expr b, + Expr lemmaEt, + int k); public: static constexpr WordIndex npos{std::numeric_limits<WordIndex>::max()}; // used to identify invalid shortlist entries similar to std::string::npos @@ -77,10 +77,10 @@ private: static std::mutex mutex_; void createCachedTensors(Expr weights, - bool isLegacyUntransposedW, - Expr b, - Expr lemmaEt, - int k); + bool isLegacyUntransposedW, + Expr b, + Expr lemmaEt, + int k); public: LSHShortlist(int k, int nbits, size_t lemmaSize); diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h index c1570eff..6c7e5758 100644 --- a/src/graph/expression_operators.h +++ b/src/graph/expression_operators.h @@ -478,6 +478,10 @@ Expr bdot(Expr a, bool transB = false, float scalar = 1.f); +/** + * bdot_legacy is an old implemetation of bdot without correct broadcasting on the batch dimensions, + * to be removed once the behavior can be correctly replicated with normal bdot on 5 dimensions. + */ Expr bdot_legacy(Expr a, Expr b, bool transA = false, |