Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/amun/cpu/dl4mt/decoder.h')
-rw-r--r--src/amun/cpu/dl4mt/decoder.h12
1 files changed, 5 insertions, 7 deletions
diff --git a/src/amun/cpu/dl4mt/decoder.h b/src/amun/cpu/dl4mt/decoder.h
index 607d41d8..b1c10e97 100644
--- a/src/amun/cpu/dl4mt/decoder.h
+++ b/src/amun/cpu/dl4mt/decoder.h
@@ -182,14 +182,13 @@ class Decoder {
auto t = blaze::forEach(T1_ + T2_ + T3_, Tanh());
if(!filtered_) {
- Probs_ = t * w_.W4_;
- AddBiasVector<byRow>(Probs_, w_.B4_);
+ Probs = t * w_.W4_;
+ AddBiasVector<byRow>(Probs, w_.B4_);
} else {
- Probs_ = t * FilteredW4_;
- AddBiasVector<byRow>(Probs_, FilteredB4_);
+ Probs = t * FilteredW4_;
+ AddBiasVector<byRow>(Probs, FilteredB4_);
}
- mblas::Softmax(Probs_);
- Probs = blaze::forEach(Probs_, Log());
+ LogSoftmax(Probs);
}
void Filter(const std::vector<size_t>& ids) {
@@ -209,7 +208,6 @@ class Decoder {
mblas::Matrix T1_;
mblas::Matrix T2_;
mblas::Matrix T3_;
- mblas::Matrix Probs_;
};
public: