Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorHieu Hoang <hihoan@microsoft.com>2021-04-29 10:08:21 +0300
committerHieu Hoang <hihoan@microsoft.com>2021-04-29 10:08:21 +0300
commit67fe82f8401e83efffb6286893c9d2ea8d967115 (patch)
treef81d7f36a22ba396b6f0eb7ca07e4775733fd3fd /src
parent592854f571e5c114c2e1f9d0469b07f0652381ce (diff)
start broadcast
Diffstat (limited to 'src')
-rw-r--r--src/data/shortlist.cpp44
-rw-r--r--src/data/shortlist.h3
2 files changed, 44 insertions, 3 deletions
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp
index 67317f4b..886e74fe 100644
--- a/src/data/shortlist.cpp
+++ b/src/data/shortlist.cpp
@@ -29,16 +29,56 @@ WordIndex Shortlist::tryForwardMap(WordIndex wIdx) {
}
void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
- int k = indices_.size();
+ //if (indicesExpr_) return;
int currBeamSize = input->shape()[0];
int batchSize = input->shape()[2];
std::cerr << "currBeamSize=" << currBeamSize << std::endl;
std::cerr << "batchSize=" << batchSize << std::endl;
- Expr indicesExprBC;
+ auto forward = [this](Expr out, const std::vector<Expr>& inputs) {
+ out->val()->set(indices_);
+ };
+
+ int k = indices_.size();
+ Shape kShape({k});
+ indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
+
+ Expr indicesExprBC = getIndicesExpr(batchSize, currBeamSize);
broadcast(weights, isLegacyUntransposedW, b, lemmaEt, indicesExprBC, k);
}
+Expr Shortlist::getIndicesExpr(int batchSize, int beamSize) const {
+ int k = indicesExpr_->shape()[0];
+ Expr ones = indicesExpr_->graph()->constant({batchSize, beamSize, 1}, inits::ones(), Type::float32);
+
+ Expr tmp = reshape(indicesExpr_, {1, k});
+ tmp = cast(tmp, Type::float32);
+
+ Expr out = ones * tmp;
+ //debug(out, "out.1");
+
+ auto forward = [](Expr out, const std::vector<Expr>& inputs) {
+ Expr in = inputs[0];
+ const Shape &shape = in->shape();
+ const float *inPtr = in->val()->data();
+ uint32_t *outPtr = out->val()->data<uint32_t>();
+
+ for (int i = 0; i < shape.elements(); ++i) {
+ const float &val = inPtr[i];
+ uint32_t valConv = (uint32_t)val;
+ uint32_t &valOut = outPtr[i];
+ valOut = valConv;
+ //std::cerr << val << " " << valConv << " " << valOut << std::endl;
+ }
+ };
+ out = lambda({out}, out->shape(), Type::uint32, forward);
+ //debug(out, "out.2");
+ //out = cast(out, Type::uint32);
+ //std::cerr << "getIndicesExpr.2=" << out->shape() << std::endl;
+ //out = reshape(out, {k});
+
+ return out;
+}
void Shortlist::broadcast(Expr weights,
bool isLegacyUntransposedW,
diff --git a/src/data/shortlist.h b/src/data/shortlist.h
index 44da6faa..67a8b74c 100644
--- a/src/data/shortlist.h
+++ b/src/data/shortlist.h
@@ -21,7 +21,7 @@ namespace data {
class Shortlist {
protected:
std::vector<WordIndex> indices_; // // [packed shortlist index] -> word index, used to select columns from output embeddings
-
+ Expr indicesExpr_;
Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
Expr cachedShortb_; // these match the current value of shortlist_
Expr cachedShortLemmaEt_;
@@ -42,6 +42,7 @@ public:
WordIndex tryForwardMap(WordIndex wIdx);
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt);
+ virtual Expr getIndicesExpr(int batchSize, int currBeamSize) const;
virtual Expr getCachedShortWt() const { return cachedShortWt_; }
virtual Expr getCachedShortb() const { return cachedShortb_; }
virtual Expr getCachedShortLemmaEt() const { return cachedShortLemmaEt_; }