Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author    Tomasz Dwojak <t.dwojak@amu.edu.pl>  2017-06-19 12:51:18 +0300
committer Tomasz Dwojak <t.dwojak@amu.edu.pl>  2017-06-19 12:51:18 +0300
commit    aea00408732a3d7d35f98ef05e05676207105305 (patch)
tree      4949f97b3d132eb1bd13c590d3a00fcdb648aaa3
parent    49bdf1d91e386c95a10ffcd935be2bc41a64d05b (diff)
Refactor Transition
-rw-r--r--  src/amun/cpu/dl4mt/decoder.h        |  4
-rw-r--r--  src/amun/cpu/nematus/transition.cpp | 82
-rw-r--r--  src/amun/cpu/nematus/transition.h   | 84
3 files changed, 91 insertions(+), 79 deletions(-)
diff --git a/src/amun/cpu/dl4mt/decoder.h b/src/amun/cpu/dl4mt/decoder.h
index cbda3953..0e14d6e5 100644
--- a/src/amun/cpu/dl4mt/decoder.h
+++ b/src/amun/cpu/dl4mt/decoder.h
@@ -61,10 +61,6 @@ class Decoder {
// Repeat mean batchSize times by broadcasting
Temp1_ = Mean<byRow, Matrix>(SourceContext);
- std::cerr << "CTX: " << std::endl;
- for (int i = 0; i < 5; ++i) std::cerr << Temp1_(0, i) << " ";
- std::cerr << std::endl;
-
Temp2_.resize(batchSize, SourceContext.columns());
Temp2_ = 0.0f;
AddBiasVector<byRow>(Temp2_, Temp1_);
diff --git a/src/amun/cpu/nematus/transition.cpp b/src/amun/cpu/nematus/transition.cpp
new file mode 100644
index 00000000..4244a3aa
--- /dev/null
+++ b/src/amun/cpu/nematus/transition.cpp
@@ -0,0 +1,82 @@
+#include "transition.h"
+
+namespace amunmt {
+namespace CPU {
+namespace Nematus {
+
+Transition::Transition(const Weights::Transition& model)
+ : w_(model),
+ layerNormalization_(false)
+{
+ if (w_.U_lns_.size() > 1 && w_.U_lns_[0].rows() > 1) {
+ layerNormalization_ = true;
+ }
+}
+
+
+void Transition::GetNextState(mblas::Matrix& state) const
+{
+ if (layerNormalization_) {
+ for (int i = 0; i < w_.size(); ++i) {
+ Temp_1_ = state * w_.U_[i];
+ Temp_2_ = state * w_.Ux_[i];
+
+ switch(w_.type()) {
+ case Weights::Transition::TransitionType::Encoder:
+ LayerNormalization(Temp_1_, w_.U_lns_[i], w_.U_lnb_[i]);
+ mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.B_[i]);
+
+ LayerNormalization(Temp_2_, w_.Ux_lns_[i], w_.Ux_lnb_[i]);
+ break;
+
+ case Weights::Transition::TransitionType::Decoder:
+ mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.B_[i]);
+ LayerNormalization(Temp_1_, w_.U_lns_[i], w_.U_lnb_[i]);
+
+ mblas::AddBiasVector<mblas::byRow>(Temp_2_, w_.Bx1_[i]);
+ LayerNormalization(Temp_2_, w_.Ux_lns_[i], w_.Ux_lnb_[i]);
+ break;
+ }
+ ElementwiseOps(state, i);
+ }
+ } else {
+ for (int i = 0; i < w_.size(); ++i) {
+ Temp_1_ = state * w_.U_[i];
+ Temp_2_ = state * w_.Ux_[i];
+ mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.B_[i]);
+ mblas::AddBiasVector<mblas::byRow>(Temp_2_, w_.Bx1_[i]);
+ ElementwiseOps(state, i);
+ }
+ }
+}
+
+
+void Transition::ElementwiseOps(mblas::Matrix& state, int idx) const {
+ using namespace mblas;
+ using namespace blaze;
+
+ for (int j = 0; j < (int)state.Rows(); ++j) {
+ auto rowState = row(state, j);
+ auto rowT = row(Temp_1_, j);
+ auto rowT2 = row(Temp_2_, j);
+
+ for (int i = 0; i < (int)state.Cols(); ++i) {
+ float ev1 = expapprox(-(rowT[i])); // + w_.B_[idx](0, i)));
+ float r = 1.0f / (1.0f + ev1);
+
+ int k = i + state.Cols();
+ float ev2 = expapprox(-(rowT[k])); // + w_.B_[idx](0, k)));
+ float u = 1.0f / (1.0f + ev2);
+
+ float hv = w_.Bx2_[idx](0, i);
+ float t2v = rowT2[i]; // + w_.Bx1_[idx](0, i);
+ hv = tanhapprox(hv + r * t2v);
+ rowState[i] = (1.0f - u) * hv + u * rowState[i];
+ }
+ }
+}
+
+} // namespace Nematus
+} // namespace CPU
+} // namespace amunmt
+
diff --git a/src/amun/cpu/nematus/transition.h b/src/amun/cpu/nematus/transition.h
index f4d342d7..3db3c72e 100644
--- a/src/amun/cpu/nematus/transition.h
+++ b/src/amun/cpu/nematus/transition.h
@@ -1,91 +1,24 @@
#pragma once
+
#include "cpu/mblas/matrix.h"
-#include <iomanip>
+#include "model.h"
namespace amunmt {
namespace CPU {
+namespace Nematus {
-template <class Weights>
class Transition {
public:
- Transition(const Weights& model)
- : w_(model),
- layerNormalization_(false)
- {
- if (w_.U_lns_.size() > 1 && w_.U_lns_[0].rows() > 1) {
- layerNormalization_ = true;
- }
- }
-
- void GetNextState(mblas::Matrix& state) const
- {
- if (layerNormalization_) {
- for (int i = 0; i < w_.size(); ++i) {
- Temp_1_ = state * w_.U_[i];
- Temp_2_ = state * w_.Ux_[i];
-
- switch(w_.type()) {
- case Weights::Transition::TransitionType::Encoder:
- LayerNormalization(Temp_1_, w_.U_lns_[i], w_.U_lnb_[i]);
- mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.B_[i]);
-
- LayerNormalization(Temp_2_, w_.Ux_lns_[i], w_.Ux_lnb_[i]);
- break;
-
- case Weights::Transition::TransitionType::Decoder:
- mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.B_[i]);
- LayerNormalization(Temp_1_, w_.U_lns_[i], w_.U_lnb_[i]);
-
- mblas::AddBiasVector<mblas::byRow>(Temp_2_, w_.Bx1_[i]);
- LayerNormalization(Temp_2_, w_.Ux_lns_[i], w_.Ux_lnb_[i]);
- break;
- }
- ElementwiseOps(state, i);
- }
- } else {
- for (int i = 0; i < w_.size(); ++i) {
- Temp_1_ = state * w_.U_[i];
- Temp_2_ = state * w_.Ux_[i];
- mblas::AddBiasVector<mblas::byRow>(Temp_1_, w_.B_[i]);
- mblas::AddBiasVector<mblas::byRow>(Temp_2_, w_.Bx1_[i]);
- ElementwiseOps(state, i);
- }
- }
- }
+ Transition(const Weights::Transition& model);
- void ElementwiseOps(mblas::Matrix& state, int idx) const {
- using namespace mblas;
- using namespace blaze;
-
- for (int j = 0; j < (int)state.Rows(); ++j) {
- auto rowState = row(state, j);
- auto rowT = row(Temp_1_, j);
- auto rowT2 = row(Temp_2_, j);
-
- for (int i = 0; i < (int)state.Cols(); ++i) {
- float ev1 = expapprox(-(rowT[i])); // + w_.B_[idx](0, i)));
- float r = 1.0f / (1.0f + ev1);
-
- int k = i + state.Cols();
- float ev2 = expapprox(-(rowT[k])); // + w_.B_[idx](0, k)));
- float u = 1.0f / (1.0f + ev2);
-
- float hv = w_.Bx2_[idx](0, i);
- float t2v = rowT2[i]; // + w_.Bx1_[idx](0, i);
- hv = tanhapprox(hv + r * t2v);
- rowState[i] = (1.0f - u) * hv + u * rowState[i];
- }
- }
- }
-
- size_t GetStateLength() const {
- return w_.U_.rows();
- }
+ void GetNextState(mblas::Matrix& state) const;
+ protected:
+ void ElementwiseOps(mblas::Matrix& state, int idx) const;
private:
// Model matrices
- const Weights& w_;
+ const Weights::Transition& w_;
// reused to avoid allocation
mutable mblas::Matrix UUx_;
@@ -101,4 +34,5 @@ class Transition {
}
}
+}