github.com/marian-nmt/marian.git
author    Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2018-08-07 07:23:04 +0300
committer Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2018-08-07 07:23:04 +0300
commit    34aa79d17ee0c119e03762162fb742f24f94ab4b
tree      b27d706ade5fd9be07552e6d851046c50fe4cac2
parent    b51a2a7e4121562c70b755be618812d1ff4173d1
add rnn layer to transformer
 src/models/transformer.h |  42 +
 src/rnn/cells.h          | 368 +
 src/rnn/constructors.h   |  12 +
 3 files changed, 422 insertions(+), 0 deletions(-)
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 25fd1414..e495d620 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -437,6 +437,41 @@ public:
return LayerAAN(prefix, input, output);
}
+
+ Expr DecoderLayerRNN(rnn::State& decoderState,
+ const rnn::State& prevDecoderState,
+ std::string prefix,
+ Expr input,
+ Expr selfMask,
+ int startPos) const {
+ using namespace keywords;
+
+ float dropoutRnn = inference_ ? 0.f : opt<float>("dropout-rnn");
+
+ auto rnn = rnn::rnn(graph_) //
+ ("type", opt<std::string>("dec-cell")) //
+ ("prefix", prefix) //
+ ("dimInput", opt<int>("dim-emb")) //
+ ("dimState", opt<int>("dim-emb")) //
+ ("dropout", dropoutRnn) //
+ ("layer-normalization", opt<bool>("layer-normalization")) //
+ .push_back(rnn::cell(graph_)) //
+ .construct();
+
+ float dropProb = inference_ ? 0 : opt<float>("transformer-dropout");
+ auto opsPre = opt<std::string>("transformer-preprocess");
+ auto output = preProcess(prefix, opsPre, input, dropProb);
+
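+ // The RNN transducer iterates over the time axis, while the transformer
+ // keeps batch and time swapped relative to that layout, so transpose on
+ // the way in and back on the way out.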
+ output = transposeTimeBatch(output);
+ output = rnn->transduce(output, prevDecoderState);
+ decoderState = rnn->lastCellStates()[0];
+ output = transposeTimeBatch(output);
+
+ auto opsPost = opt<std::string>("transformer-postprocess");
+ output = postProcess(prefix + "_ffn", opsPost, output, input, dropProb);
+
+ return output;
+ }
};
class EncoderTransformer : public Transformer<EncoderBase> {
@@ -738,6 +773,13 @@ public:
query,
selfMask,
startPos);
+ else if(layerType == "rnn")
+ query = DecoderLayerRNN(decoderState,
+ prevDecoderState,
+ prefix_ + "_l" + std::to_string(i) + "_rnn",
+ query,
+ selfMask,
+ startPos);
else
ABORT("Unknown auto-regressive layer type in transformer decoder {}",
layerType);
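In summary, the new decoder layer wraps a single RNN cell in the transformer's
usual pre-/postprocessing. As a sketch of the per-step computation, assuming
preProcess and postProcess apply the dropout/residual/normalization operations
configured via --transformer-preprocess and --transformer-postprocess:

\begin{aligned}
\tilde{x}_t &= \mathrm{preProcess}(x_t) \\
h_t &= \mathrm{RNNCell}(h_{t-1}, \tilde{x}_t) \\
y_t &= \mathrm{postProcess}(h_t,\, x_t)
\end{aligned}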
diff --git a/src/rnn/cells.h b/src/rnn/cells.h
index 67750ac4..d245f891 100644
--- a/src/rnn/cells.h
+++ b/src/rnn/cells.h
@@ -108,6 +108,104 @@ public:
/******************************************************************************/
+class ReLU : public Cell {
+private:
+ Expr U_, W_, b_;
+ Expr gamma1_;
+ Expr gamma2_;
+
+ bool layerNorm_;
+ float dropout_;
+
+ Expr dropMaskX_;
+ Expr dropMaskS_;
+
+public:
+ ReLU(Ptr<ExpressionGraph> graph, Ptr<Options> options) : Cell(options) {
+ int dimInput = options_->get<int>("dimInput");
+ int dimState = options_->get<int>("dimState");
+ std::string prefix = options_->get<std::string>("prefix");
+
+ layerNorm_ = options_->get<bool>("layer-normalization", false);
+ dropout_ = options_->get<float>("dropout", 0);
+
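+ // The recurrent matrix U starts as the identity (diag(1.f)), the
+ // IRNN-style initialization commonly used for ReLU RNNs.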
+ U_ = graph->param(prefix + "_U",
+ {dimState, dimState},
+ inits::diag(1.f));
+
+ if(dimInput)
+ W_ = graph->param(prefix + "_W",
+ {dimInput, dimState},
+ inits::glorot_uniform);
+
+ b_ = graph->param(prefix + "_b", {1, dimState}, inits::zeros);
+
+ if(dropout_ > 0.0f) {
+ if(dimInput)
+ dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
+ dropMaskS_ = graph->dropout(dropout_, {1, dimState});
+ }
+
+ if(layerNorm_) {
+ if(dimInput)
+ gamma1_ = graph->param(prefix + "_gamma1",
+ {1, dimState},
+ inits::ones);
+ gamma2_ = graph->param(prefix + "_gamma2",
+ {1, dimState},
+ inits::ones);
+ }
+ }
+
+ State apply(std::vector<Expr> inputs, State states, Expr mask = nullptr) {
+ return applyState(applyInput(inputs), states, mask);
+ }
+
+ std::vector<Expr> applyInput(std::vector<Expr> inputs) {
+ Expr input;
+ if(inputs.size() == 0)
+ return {};
+ else if(inputs.size() > 1)
+ input = concatenate(inputs, keywords::axis = -1);
+ else
+ input = inputs.front();
+
+ if(dropMaskX_)
+ input = dropout(input, dropMaskX_);
+
+ auto xW = dot(input, W_);
+
+ if(layerNorm_)
+ xW = layerNorm(xW, gamma1_);
+
+ return {xW};
+ }
+
+ State applyState(std::vector<Expr> xWs, State state, Expr mask = nullptr) {
+ Expr recState = state.output;
+
+ auto stateDropped = recState;
+ if(dropMaskS_)
+ stateDropped = dropout(recState, dropMaskS_);
+ auto sU = dot(stateDropped, U_);
+ if(layerNorm_)
+ sU = layerNorm(sU, gamma2_);
+
+ Expr output;
+ if(xWs.empty())
+ output = relu(sU + b_);
+ else
+ output = relu(xWs.front() + sU + b_);
+ if(mask)
+ return {output * mask, state.cell};
+ else
+ return {output, state.cell};
+ }
+};
+
+/******************************************************************************/
+
Expr gruOps(const std::vector<Expr>& nodes, bool final = false);
class GRU : public Cell {
@@ -802,5 +900,275 @@ public:
return {nextRecState, nextCellState};
}
};
+
+class SRU : public Cell {
+private:
+ Expr W_;
+ Expr Wr_, br_;
+ Expr Wf_, bf_;
+
+ float dropout_;
+ Expr dropMaskX_;
+
+ bool layerNorm_;
+ Expr gamma_, gammaf_, gammar_;
+
+public:
+ SRU(Ptr<ExpressionGraph> graph, Ptr<Options> options) : Cell(options) {
+ int dimInput = opt<int>("dimInput");
+ int dimState = opt<int>("dimState");
+ std::string prefix = opt<std::string>("prefix");
+
+ ABORT_IF(dimInput != dimState, "For SRU, state and input dimensions have to be equal");
+
+ dropout_ = opt<float>("dropout", 0);
+ layerNorm_ = opt<bool>("layer-normalization", false);
+
+ W_ = graph->param(prefix + "_W",
+ {dimInput, dimInput},
+ inits::glorot_uniform);
+
+ Wf_ = graph->param(prefix + "_Wf",
+ {dimInput, dimInput},
+ inits::glorot_uniform);
+ bf_ = graph->param(
+ prefix + "_bf", {1, dimInput}, inits::zeros);
+
+ Wr_ = graph->param(prefix + "_Wr",
+ {dimInput, dimInput},
+ inits::glorot_uniform);
+ br_ = graph->param(
+ prefix + "_br", {1, dimInput}, inits::zeros);
+
+ if(dropout_ > 0.0f) {
+ dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
+ }
+
+ if(layerNorm_) {
+ if(dimInput)
+ gamma_ = graph->param(prefix + "_gamma",
+ {1, dimState},
+ inits::ones);
+ gammar_ = graph->param(prefix + "_gammar",
+ {1, dimState},
+ inits::ones);
+ gammaf_ = graph->param(prefix + "_gammaf",
+ {1, dimState},
+ inits::ones);
+ }
+ }
+
+ State apply(std::vector<Expr> inputs, State state, Expr mask = nullptr) {
+ return applyState(applyInput(inputs), state, mask);
+ }
+
+ std::vector<Expr> applyInput(std::vector<Expr> inputs) {
+ ABORT_IF(inputs.empty(), "SRU expects input");
+
+ Expr input;
+ if(inputs.size() > 1)
+ input = concatenate(inputs, keywords::axis = -1);
+ else
+ input = inputs.front();
+
+ auto inputDropped = dropMaskX_ ? dropout(input, dropMaskX_) : input;
+
+ Expr x, f, r;
+ if(layerNorm_) {
+ x = layerNorm(dot(inputDropped, W_), gamma_);
+ f = layerNorm(dot(inputDropped, Wf_), gammaf_, bf_);
+ r = layerNorm(dot(inputDropped, Wr_), gammar_, br_);
+ } else {
+ x = dot(inputDropped, W_);
+ f = affine(inputDropped, Wf_, bf_);
+ r = affine(inputDropped, Wr_, br_);
+ }
+
+ return {x, f, r, input};
+ }
+
+ State applyState(std::vector<Expr> xWs, State state, Expr mask = nullptr) {
+ auto recState = state.output;
+ auto cellState = state.cell;
+
+ auto x = xWs[0];
+ auto f = xWs[1];
+ auto r = xWs[2];
+ auto input = xWs[3];
+
+ auto nextCellState = highway(cellState, x, f); // rename to "gate"?
+ auto nextState = highway(tanh(nextCellState), input, r);
+
+ auto maskedCellState = mask ? mask * nextCellState : nextCellState;
+ auto maskedState = mask ? mask * nextState : nextState;
+
+ return {maskedState, maskedCellState};
+ }
+};
+
+class SSRU : public Cell {
+private:
+ Expr W_;
+ Expr Wf_, bf_;
+
+ float dropout_;
+ Expr dropMaskX_;
+
+ bool layerNorm_;
+ Expr gamma_, gammaf_;
+
+public:
+ SSRU(Ptr<ExpressionGraph> graph, Ptr<Options> options) : Cell(options) {
+ int dimInput = opt<int>("dimInput");
+ int dimState = opt<int>("dimState");
+ std::string prefix = opt<std::string>("prefix");
+
+ ABORT_IF(dimInput != dimState, "For SSRU, state and input dimensions have to be equal");
+
+ dropout_ = opt<float>("dropout", 0);
+ layerNorm_ = opt<bool>("layer-normalization", false);
+
+ W_ = graph->param(prefix + "_W",
+ {dimInput, dimInput},
+ inits::glorot_uniform);
+
+ Wf_ = graph->param(prefix + "_Wf",
+ {dimInput, dimInput},
+ inits::glorot_uniform);
+ bf_ = graph->param(
+ prefix + "_bf", {1, dimInput}, inits::zeros);
+
+ if(dropout_ > 0.0f) {
+ dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
+ }
+
+ if(layerNorm_) {
+ if(dimInput)
+ gamma_ = graph->param(prefix + "_gamma",
+ {1, dimState},
+ inits::ones);
+ gammaf_ = graph->param(prefix + "_gammaf",
+ {1, dimState},
+ inits::ones);
+ }
+ }
+
+ State apply(std::vector<Expr> inputs, State state, Expr mask = nullptr) {
+ return applyState(applyInput(inputs), state, mask);
+ }
+
+ std::vector<Expr> applyInput(std::vector<Expr> inputs) {
+ ABORT_IF(inputs.empty(), "SSRU expects input");
+
+ Expr input;
+ if(inputs.size() > 1)
+ input = concatenate(inputs, keywords::axis = -1);
+ else
+ input = inputs.front();
+
+ auto inputDropped = dropMaskX_ ? dropout(input, dropMaskX_) : input;
+
+ Expr x, f;
+ if(layerNorm_) {
+ x = layerNorm(dot(inputDropped, W_), gamma_);
+ f = layerNorm(dot(inputDropped, Wf_), gammaf_, bf_);
+ } else {
+ x = dot(inputDropped, W_);
+ f = affine(inputDropped, Wf_, bf_);
+ }
+
+ return {x, f};
+ }
+
+ State applyState(std::vector<Expr> xWs, State state, Expr mask = nullptr) {
+ auto recState = state.output;
+ auto cellState = state.cell;
+
+ auto x = xWs[0];
+ auto f = xWs[1];
+
+ auto nextCellState = highway(cellState, x, f); // rename to "gate"?
+ auto nextState = relu(nextCellState);
+
+ auto maskedCellState = mask ? mask * nextCellState : nextCellState;
+ auto maskedState = mask ? mask * nextState : nextState;
+
+ return {maskedState, maskedCellState};
+ }
+};
+
+// class LSSRU : public Cell {
+// private:
+// Expr W_;
+// Expr Wf_, bf_;
+
+// float dropout_;
+// Expr dropMaskX_;
+
+// public:
+// LSSRU(Ptr<ExpressionGraph> graph, Ptr<Options> options) : Cell(options) {
+// int dimInput = options_->get<int>("dimInput");
+// int dimState = options_->get<int>("dimState");
+// std::string prefix = options->get<std::string>("prefix");
+
+// ABORT_IF(dimInput != dimState, "For SRU state and input dims have to be equal");
+
+// dropout_ = opt<float>("dropout", 0);
+
+// W_ = graph->param(prefix + "_W",
+// {dimInput, dimInput},
+// inits::glorot_uniform);
+
+// Wf_ = graph->param(prefix + "_Wf",
+// {dimInput, dimInput},
+// inits::glorot_uniform);
+// bf_ = graph->param(
+// prefix + "_bf", {1, dimInput}, inits::zeros);
+
+
+// if(dropout_ > 0.0f) {
+// dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
+// }
+// }
+
+// State apply(std::vector<Expr> inputs, State state, Expr mask = nullptr) {
+// return applyState(applyInput(inputs), state, mask);
+// }
+
+// std::vector<Expr> applyInput(std::vector<Expr> inputs) {
+// ABORT_IF(inputs.empty(), "Slow SRU expects input");
+
+// Expr input;
+// if(inputs.size() > 1)
+// input = concatenate(inputs, keywords::axis = -1);
+// else
+// input = inputs.front();
+
+// auto inputDropped = dropMaskX_ ? dropout(input, dropMaskX_) : input;
+
+// auto x = dot(inputDropped, W_);
+// auto f = affine(inputDropped, Wf_, bf_);
+
+// return {x, f};
+// }
+
+// State applyState(std::vector<Expr> xWs, State state, Expr mask = nullptr) {
+// auto recState = state.output;
+// auto cellState = state.cell;
+
+// auto x = xWs[0];
+// auto f = xWs[1];
+
+// auto nextCellState = highwayLinear(cellState, x, f, 2.f); // rename to "gate"?
+// auto nextState = relu(nextCellState);
+// //auto nextState = nextCellState;
+
+// auto maskedCellState = mask ? mask * nextCellState : nextCellState;
+// auto maskedState = mask ? mask * nextState : nextState;
+
+// return {maskedState, maskedCellState};
+// }
+// };
+
+
} // namespace rnn
} // namespace marian
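For reference, the gated recurrences implemented above, assuming the helper
highway(a, b, g) computes \sigma(g) \odot a + (1 - \sigma(g)) \odot b (this
matches how the gate pre-activations f and r are used, though the helper's
definition is not part of this diff). The SRU cell:

\begin{aligned}
f_t &= \sigma(W_f x_t + b_f) \\
r_t &= \sigma(W_r x_t + b_r) \\
c_t &= f_t \odot c_{t-1} + (1 - f_t) \odot (W x_t) \\
h_t &= r_t \odot \tanh(c_t) + (1 - r_t) \odot x_t
\end{aligned}

The SSRU drops the reset gate r and replaces the tanh-highway output with a
plain ReLU:

\begin{aligned}
f_t &= \sigma(W_f x_t + b_f) \\
c_t &= f_t \odot c_{t-1} + (1 - f_t) \odot (W x_t) \\
h_t &= \mathrm{relu}(c_t)
\end{aligned}

The ReLU cell above is the simple recurrence
h_t = \mathrm{relu}(W x_t + U h_{t-1} + b), with U initialized to the identity.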
diff --git a/src/rnn/constructors.h b/src/rnn/constructors.h
index 8aad788a..30351dc2 100644
--- a/src/rnn/constructors.h
+++ b/src/rnn/constructors.h
@@ -63,6 +63,18 @@ public:
auto cell = New<Tanh>(graph_, options_);
cell->setLazyInputs(inputs_);
return cell;
+ } else if(type == "relu") {
+ auto cell = New<ReLU>(graph_, options_);
+ cell->setLazyInputs(inputs_);
+ return cell;
+ } else if(type == "sru") {
+ auto cell = New<SRU>(graph_, options_);
+ cell->setLazyInputs(inputs_);
+ return cell;
+ } else if(type == "ssru") {
+ auto cell = New<SSRU>(graph_, options_);
+ cell->setLazyInputs(inputs_);
+ return cell;
} else {
ABORT("Unknown RNN cell type");
}
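As a usage note, the newly registered cell types plug into the same factory
chain that DecoderLayerRNN uses. A minimal sketch (not part of the commit;
dimensions, dropout rate, and the prefix are made up for illustration):

  // assumes a valid Ptr<ExpressionGraph> graph, a time-major Expr input,
  // and an rnn::State prevState from the previous decoding step
  auto rnn = rnn::rnn(graph)                  //
      ("type", "ssru")                        // or "sru" / "relu"
      ("prefix", "decoder_l1_rnn")            // hypothetical name scope
      ("dimInput", 512)                       // SRU/SSRU require
      ("dimState", 512)                       //   dimInput == dimState
      ("dropout", 0.1f)                       //
      ("layer-normalization", true)           //
      .push_back(rnn::cell(graph))            //
      .construct();

  // transduce over all time steps, then keep the final cell state
  auto output = rnn->transduce(input, prevState);
  auto nextState = rnn->lastCellStates()[0];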