diff options
author | Tomasz Dwojak <t.dwojak@amu.edu.pl> | 2017-06-19 12:51:18 +0300 |
---|---|---|
committer | Tomasz Dwojak <t.dwojak@amu.edu.pl> | 2017-06-19 12:51:18 +0300 |
commit | aea00408732a3d7d35f98ef05e05676207105305 (patch) | |
tree | 4949f97b3d132eb1bd13c590d3a00fcdb648aaa3 | |
parent | 49bdf1d91e386c95a10ffcd935be2bc41a64d05b (diff) |
Refactor Transition
-rw-r--r-- | src/amun/cpu/dl4mt/decoder.h | 4 | ||||
-rw-r--r-- | src/amun/cpu/nematus/transition.cpp | 82 | ||||
-rw-r--r-- | src/amun/cpu/nematus/transition.h | 84 |
3 files changed, 91 insertions, 79 deletions
diff --git a/src/amun/cpu/dl4mt/decoder.h b/src/amun/cpu/dl4mt/decoder.h index cbda3953..0e14d6e5 100644 --- a/src/amun/cpu/dl4mt/decoder.h +++ b/src/amun/cpu/dl4mt/decoder.h @@ -61,10 +61,6 @@ class Decoder { // Repeat mean batchSize times by broadcasting Temp1_ = Mean<byRow, Matrix>(SourceContext); - std::cerr << "CTX: " << std::endl; - for (int i = 0; i < 5; ++i) std::cerr << Temp1_(0, i) << " "; - std::cerr << std::endl; - Temp2_.resize(batchSize, SourceContext.columns()); Temp2_ = 0.0f; AddBiasVector<byRow>(Temp2_, Temp1_); diff --git a/src/amun/cpu/nematus/transition.cpp b/src/amun/cpu/nematus/transition.cpp new file mode 100644 index 00000000..4244a3aa --- /dev/null +++ b/src/amun/cpu/nematus/transition.cpp @@ -0,0 +1,82 @@ +#include "transition.h" + +namespace amunmt { +namespace CPU { +namespace Nematus { + +Transition::Transition(const Weights::Transition& model) + : w_(model), + layerNormalization_(false) +{ + if (w_.U_lns_.size() > 1 && w_.U_lns_[0].rows() > 1) { + layerNormalization_ = true; + } +} + + +void Transition::GetNextState(mblas::Matrix& state) const +{ + if (layerNormalization_) { + for (int i = 0; i < w_.size(); ++i) { + Temp_1_ = state * w_.U_[i]; + Temp_2_ = state * w_.Ux_[i]; + + switch(w_.type()) { + case Weights::Transition::TransitionType::Encoder: + LayerNormalization(Temp_1_, w_.U_lns_[i], w_.U_lnb_[i]); + mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.B_[i]); + + LayerNormalization(Temp_2_, w_.Ux_lns_[i], w_.Ux_lnb_[i]); + break; + + case Weights::Transition::TransitionType::Decoder: + mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.B_[i]); + LayerNormalization(Temp_1_, w_.U_lns_[i], w_.U_lnb_[i]); + + mblas::AddBiasVector<mblas::byRow>(Temp_2_, w_.Bx1_[i]); + LayerNormalization(Temp_2_, w_.Ux_lns_[i], w_.Ux_lnb_[i]); + break; + } + ElementwiseOps(state, i); + } + } else { + for (int i = 0; i < w_.size(); ++i) { + Temp_1_ = state * w_.U_[i]; + Temp_2_ = state * w_.Ux_[i]; + mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.B_[i]); + mblas::AddBiasVector<mblas::byRow>(Temp_2_, w_.Bx1_[i]); + ElementwiseOps(state, i); + } + } +} + + +void Transition::ElementwiseOps(mblas::Matrix& state, int idx) const { + using namespace mblas; + using namespace blaze; + + for (int j = 0; j < (int)state.Rows(); ++j) { + auto rowState = row(state, j); + auto rowT = row(Temp_1_, j); + auto rowT2 = row(Temp_2_, j); + + for (int i = 0; i < (int)state.Cols(); ++i) { + float ev1 = expapprox(-(rowT[i])); // + w_.B_[idx](0, i))); + float r = 1.0f / (1.0f + ev1); + + int k = i + state.Cols(); + float ev2 = expapprox(-(rowT[k])); // + w_.B_[idx](0, k))); + float u = 1.0f / (1.0f + ev2); + + float hv = w_.Bx2_[idx](0, i); + float t2v = rowT2[i]; // + w_.Bx1_[idx](0, i); + hv = tanhapprox(hv + r * t2v); + rowState[i] = (1.0f - u) * hv + u * rowState[i]; + } + } +} + +} // namespace Nematus +} // namespace CPU +} // namespace amunmt + diff --git a/src/amun/cpu/nematus/transition.h b/src/amun/cpu/nematus/transition.h index f4d342d7..3db3c72e 100644 --- a/src/amun/cpu/nematus/transition.h +++ b/src/amun/cpu/nematus/transition.h @@ -1,91 +1,24 @@ #pragma once + #include "cpu/mblas/matrix.h" -#include <iomanip> +#include "model.h" namespace amunmt { namespace CPU { +namespace Nematus { -template <class Weights> class Transition { public: - Transition(const Weights& model) - : w_(model), - layerNormalization_(false) - { - if (w_.U_lns_.size() > 1 && w_.U_lns_[0].rows() > 1) { - layerNormalization_ = true; - } - } - - void GetNextState(mblas::Matrix& state) const - { - if (layerNormalization_) { - for (int i = 0; i < w_.size(); ++i) { - Temp_1_ = state * w_.U_[i]; - Temp_2_ = state * w_.Ux_[i]; - - switch(w_.type()) { - case Weights::Transition::TransitionType::Encoder: - LayerNormalization(Temp_1_, w_.U_lns_[i], w_.U_lnb_[i]); - mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.B_[i]); - - LayerNormalization(Temp_2_, w_.Ux_lns_[i], w_.Ux_lnb_[i]); - break; - - case Weights::Transition::TransitionType::Decoder: - mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.B_[i]); - LayerNormalization(Temp_1_, w_.U_lns_[i], w_.U_lnb_[i]); - - mblas::AddBiasVector<mblas::byRow>(Temp_2_, w_.Bx1_[i]); - LayerNormalization(Temp_2_, w_.Ux_lns_[i], w_.Ux_lnb_[i]); - break; - } - ElementwiseOps(state, i); - } - } else { - for (int i = 0; i < w_.size(); ++i) { - Temp_1_ = state * w_.U_[i]; - Temp_2_ = state * w_.Ux_[i]; - mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.B_[i]); - mblas::AddBiasVector<mblas::byRow>(Temp_2_, w_.Bx1_[i]); - ElementwiseOps(state, i); - } - } - } + Transition(const Weights::Transition& model); - void ElementwiseOps(mblas::Matrix& state, int idx) const { - using namespace mblas; - using namespace blaze; - - for (int j = 0; j < (int)state.Rows(); ++j) { - auto rowState = row(state, j); - auto rowT = row(Temp_1_, j); - auto rowT2 = row(Temp_2_, j); - - for (int i = 0; i < (int)state.Cols(); ++i) { - float ev1 = expapprox(-(rowT[i])); // + w_.B_[idx](0, i))); - float r = 1.0f / (1.0f + ev1); - - int k = i + state.Cols(); - float ev2 = expapprox(-(rowT[k])); // + w_.B_[idx](0, k))); - float u = 1.0f / (1.0f + ev2); - - float hv = w_.Bx2_[idx](0, i); - float t2v = rowT2[i]; // + w_.Bx1_[idx](0, i); - hv = tanhapprox(hv + r * t2v); - rowState[i] = (1.0f - u) * hv + u * rowState[i]; - } - } - } - - size_t GetStateLength() const { - return w_.U_.rows(); - } + void GetNextState(mblas::Matrix& state) const; + protected: + void ElementwiseOps(mblas::Matrix& state, int idx) const; private: // Model matrices - const Weights& w_; + const Weights::Transition& w_; // reused to avoid allocation mutable mblas::Matrix UUx_; @@ -101,4 +34,5 @@ class Transition { } } +} |