Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/src
diff options
context:
space:
mode:
author: Marcin Junczys-Dowmunt <marcinjd@microsoft.com> 2021-02-12 08:06:39 +0300
committer: Marcin Junczys-Dowmunt <marcinjd@microsoft.com> 2021-02-12 08:06:39 +0300
commit: 147d9dc25be59b195ff612648e03a4bbb442fe7b (patch)
tree: 261fcb6cea0b8989f49ae512b32296f20451f7e9 /src
parent: 93cdfdcc9a228a0932b9a1691a48a6ced875f160 (diff)
do not do dropout at inference
Diffstat (limited to 'src')
-rwxr-xr-x src/layers/generic.cpp | 6
1 files changed, 4 insertions, 2 deletions
diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp
index 62752285..d44f4020 100755
--- a/src/layers/generic.cpp
+++ b/src/layers/generic.cpp
@@ -480,7 +480,8 @@ namespace marian {
auto offsets = graph->constant({ (int)factoredData.offsets.size() }, inits::fromVector(factoredData.offsets), Type::uint32);
// apply dropout
// We apply it to the weights, i.e. factors get dropped out separately, but always as entire vectors.
- weights = dropout(weights, dropProb);
+ if(!inference_)
+ weights = dropout(weights, dropProb);
// perform the product
return csr_dot(factoredData.shape, weights, indices, offsets, E_);
}
@@ -552,7 +553,8 @@ namespace marian {
auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E]
selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
// @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() (test that separately)
- selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 });
+ if(!inference_)
+ selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 });
return selectedEmbs;
}