diff options
author | Hieu Hoang <hihoan@microsoft.com> | 2021-04-29 10:08:21 +0300 |
---|---|---|
committer | Hieu Hoang <hihoan@microsoft.com> | 2021-04-29 10:08:21 +0300 |
commit | 67fe82f8401e83efffb6286893c9d2ea8d967115 (patch) | |
tree | f81d7f36a22ba396b6f0eb7ca07e4775733fd3fd /src | |
parent | 592854f571e5c114c2e1f9d0469b07f0652381ce (diff) |
start broadcast
Diffstat (limited to 'src')
-rw-r--r-- | src/data/shortlist.cpp | 44 | ||||
-rw-r--r-- | src/data/shortlist.h | 3 |
2 files changed, 44 insertions, 3 deletions
diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp index 67317f4b..886e74fe 100644 --- a/src/data/shortlist.cpp +++ b/src/data/shortlist.cpp @@ -29,16 +29,56 @@ WordIndex Shortlist::tryForwardMap(WordIndex wIdx) { } void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) { - int k = indices_.size(); + //if (indicesExpr_) return; int currBeamSize = input->shape()[0]; int batchSize = input->shape()[2]; std::cerr << "currBeamSize=" << currBeamSize << std::endl; std::cerr << "batchSize=" << batchSize << std::endl; - Expr indicesExprBC; + auto forward = [this](Expr out, const std::vector<Expr>& inputs) { + out->val()->set(indices_); + }; + + int k = indices_.size(); + Shape kShape({k}); + indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward); + + Expr indicesExprBC = getIndicesExpr(batchSize, currBeamSize); broadcast(weights, isLegacyUntransposedW, b, lemmaEt, indicesExprBC, k); } +Expr Shortlist::getIndicesExpr(int batchSize, int beamSize) const { + int k = indicesExpr_->shape()[0]; + Expr ones = indicesExpr_->graph()->constant({batchSize, beamSize, 1}, inits::ones(), Type::float32); + + Expr tmp = reshape(indicesExpr_, {1, k}); + tmp = cast(tmp, Type::float32); + + Expr out = ones * tmp; + //debug(out, "out.1"); + + auto forward = [](Expr out, const std::vector<Expr>& inputs) { + Expr in = inputs[0]; + const Shape &shape = in->shape(); + const float *inPtr = in->val()->data(); + uint32_t *outPtr = out->val()->data<uint32_t>(); + + for (int i = 0; i < shape.elements(); ++i) { + const float &val = inPtr[i]; + uint32_t valConv = (uint32_t)val; + uint32_t &valOut = outPtr[i]; + valOut = valConv; + //std::cerr << val << " " << valConv << " " << valOut << std::endl; + } + }; + out = lambda({out}, out->shape(), Type::uint32, forward); + //debug(out, "out.2"); + //out = cast(out, Type::uint32); + //std::cerr << "getIndicesExpr.2=" << out->shape() << std::endl; + //out = reshape(out, {k}); + + return out; +} void Shortlist::broadcast(Expr weights, bool isLegacyUntransposedW, diff --git a/src/data/shortlist.h b/src/data/shortlist.h index 44da6faa..67a8b74c 100644 --- a/src/data/shortlist.h +++ b/src/data/shortlist.h @@ -21,7 +21,7 @@ namespace data { class Shortlist { protected: std::vector<WordIndex> indices_; // // [packed shortlist index] -> word index, used to select columns from output embeddings - + Expr indicesExpr_; Expr cachedShortWt_; // short-listed version, cached (cleared by clear()) Expr cachedShortb_; // these match the current value of shortlist_ Expr cachedShortLemmaEt_; @@ -42,6 +42,7 @@ public: WordIndex tryForwardMap(WordIndex wIdx); virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt); + virtual Expr getIndicesExpr(int batchSize, int currBeamSize) const; virtual Expr getCachedShortWt() const { return cachedShortWt_; } virtual Expr getCachedShortb() const { return cachedShortb_; } virtual Expr getCachedShortLemmaEt() const { return cachedShortLemmaEt_; } |