Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <marcinjd@microsoft.com>2021-07-03 22:13:26 +0300
committerMarcin Junczys-Dowmunt <marcinjd@microsoft.com>2021-07-03 22:13:26 +0300
commit9772aa293f574aef5fb1a2756ae28ef7428b3dde (patch)
tree8b17ec0430a469e54bd1bdd867ff0a527f9643a5 /src
parent8bfa6a44e307a58b80784040d741058855f5b519 (diff)
remaining comments
Diffstat (limited to 'src')
-rw-r--r--src/data/factored_vocab.cpp2
-rw-r--r--src/data/shortlist.cpp8
-rw-r--r--src/data/shortlist.h18
-rw-r--r--src/graph/expression_operators.h4
4 files changed, 18 insertions, 14 deletions
diff --git a/src/data/factored_vocab.cpp b/src/data/factored_vocab.cpp
index 4c5207dd..cc715993 100644
--- a/src/data/factored_vocab.cpp
+++ b/src/data/factored_vocab.cpp
@@ -275,7 +275,7 @@ void FactoredVocab::constructGroupInfoFromFactorVocab() {
groupCounts[g]++;
}
- // required by LSH shortlist
+ // required by LSH shortlist. The factored segmenter encodes the number of lemmas in the first factor group; this corresponds to the actual surface forms.
lemmaSize_ = groupCounts[0];
for (size_t g = 0; g < numGroups; g++) { // detect non-overlapping groups
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
index 9f4a4ebd..f7e229ff 100644
--- a/src/data/shortlist.cpp
+++ b/src/data/shortlist.cpp
@@ -19,8 +19,8 @@ const T* get(const void*& current, size_t num = 1) {
//////////////////////////////////////////////////////////////////////////////////////
Shortlist::Shortlist(const std::vector<WordIndex>& indices)
- : indices_(indices)
- , done_(false) {}
+ : indices_(indices),
+ initialized_(false) {}
Shortlist::~Shortlist() {}
@@ -35,7 +35,7 @@ WordIndex Shortlist::tryForwardMap(WordIndex wIdx) const {
}
void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
- if (done_) {
+ if (initialized_) {
return;
}
@@ -49,7 +49,7 @@ void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Exp
//std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
createCachedTensors(weights, isLegacyUntransposedW, b, lemmaEt, k);
- done_ = true;
+ initialized_ = true;
}
Expr Shortlist::getIndicesExpr() const {
diff --git a/src/data/shortlist.h b/src/data/shortlist.h
index 1ce8fbf4..a75d2c4b 100644
--- a/src/data/shortlist.h
+++ b/src/data/shortlist.h
@@ -29,13 +29,13 @@ protected:
Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
Expr cachedShortb_; // these match the current value of shortlist_
Expr cachedShortLemmaEt_;
- bool done_; // used by batch-level shortlist. Only initialize with 1st call then skip all subsequent calls for same batch
+ bool initialized_; // used by batch-level shortlist. Only initialize with 1st call then skip all subsequent calls for same batch
void createCachedTensors(Expr weights,
- bool isLegacyUntransposedW,
- Expr b,
- Expr lemmaEt,
- int k);
+ bool isLegacyUntransposedW,
+ Expr b,
+ Expr lemmaEt,
+ int k);
public:
static constexpr WordIndex npos{std::numeric_limits<WordIndex>::max()}; // used to identify invalid shortlist entries similar to std::string::npos
@@ -77,10 +77,10 @@ private:
static std::mutex mutex_;
void createCachedTensors(Expr weights,
- bool isLegacyUntransposedW,
- Expr b,
- Expr lemmaEt,
- int k);
+ bool isLegacyUntransposedW,
+ Expr b,
+ Expr lemmaEt,
+ int k);
public:
LSHShortlist(int k, int nbits, size_t lemmaSize);
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index c1570eff..6c7e5758 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -478,6 +478,10 @@ Expr bdot(Expr a,
bool transB = false,
float scalar = 1.f);
+/**
+ * bdot_legacy is an old implementation of bdot without correct broadcasting on the batch dimensions,
+ * to be removed once the behavior can be correctly replicated with normal bdot on 5 dimensions.
+ */
Expr bdot_legacy(Expr a,
Expr b,
bool transA = false,