diff options
Diffstat (limited to 'src/amun/cpu/dl4mt/decoder.h')
-rw-r--r-- | src/amun/cpu/dl4mt/decoder.h | 12 |
1 files changed, 5 insertions, 7 deletions
diff --git a/src/amun/cpu/dl4mt/decoder.h b/src/amun/cpu/dl4mt/decoder.h index 607d41d8..b1c10e97 100644 --- a/src/amun/cpu/dl4mt/decoder.h +++ b/src/amun/cpu/dl4mt/decoder.h @@ -182,14 +182,13 @@ class Decoder { auto t = blaze::forEach(T1_ + T2_ + T3_, Tanh()); if(!filtered_) { - Probs_ = t * w_.W4_; - AddBiasVector<byRow>(Probs_, w_.B4_); + Probs = t * w_.W4_; + AddBiasVector<byRow>(Probs, w_.B4_); } else { - Probs_ = t * FilteredW4_; - AddBiasVector<byRow>(Probs_, FilteredB4_); + Probs = t * FilteredW4_; + AddBiasVector<byRow>(Probs, FilteredB4_); } - mblas::Softmax(Probs_); - Probs = blaze::forEach(Probs_, Log()); + LogSoftmax(Probs); } void Filter(const std::vector<size_t>& ids) { @@ -209,7 +208,6 @@ class Decoder { mblas::Matrix T1_; mblas::Matrix T2_; mblas::Matrix T3_; - mblas::Matrix Probs_; }; public: |