| | | |
|---|---|---|
| author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-08-07 07:23:04 +0300 |
| committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-08-07 07:23:04 +0300 |
| commit | 34aa79d17ee0c119e03762162fb742f24f94ab4b (patch) | |
| tree | b27d706ade5fd9be07552e6d851046c50fe4cac2 | |
| parent | b51a2a7e4121562c70b755be618812d1ff4173d1 (diff) | |
add rnn layer to transformer
| | | |
|---|---|---|
| -rw-r--r-- | src/models/transformer.h | 42 |
| -rw-r--r-- | src/rnn/cells.h | 368 |
| -rw-r--r-- | src/rnn/constructors.h | 12 |

3 files changed, 422 insertions, 0 deletions
```diff
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 25fd1414..e495d620 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -437,6 +437,41 @@ public:
     return LayerAAN(prefix, input, output);
   }
 
+  Expr DecoderLayerRNN(rnn::State& decoderState,
+                       const rnn::State& prevDecoderState,
+                       std::string prefix,
+                       Expr input,
+                       Expr selfMask,
+                       int startPos) const {
+    using namespace keywords;
+
+    float dropoutRnn = inference_ ? 0.f : opt<float>("dropout-rnn");
+
+    auto rnn = rnn::rnn(graph_)                                    //
+        ("type", opt<std::string>("dec-cell"))                     //
+        ("prefix", prefix)                                         //
+        ("dimInput", opt<int>("dim-emb"))                          //
+        ("dimState", opt<int>("dim-emb"))                          //
+        ("dropout", dropoutRnn)                                    //
+        ("layer-normalization", opt<bool>("layer-normalization"))  //
+        .push_back(rnn::cell(graph_))                              //
+        .construct();
+
+    float dropProb = inference_ ? 0 : opt<float>("transformer-dropout");
+    auto opsPre = opt<std::string>("transformer-preprocess");
+    auto output = preProcess(prefix, opsPre, input, dropProb);
+
+    output = transposeTimeBatch(output);
+    output = rnn->transduce(output, prevDecoderState);
+    decoderState = rnn->lastCellStates()[0];
+    output = transposeTimeBatch(output);
+
+    auto opsPost = opt<std::string>("transformer-postprocess");
+    output = postProcess(prefix + "_ffn", opsPost, output, input, dropProb);
+
+    return output;
+  }
 };
 
 class EncoderTransformer : public Transformer<EncoderBase> {
@@ -738,6 +773,13 @@ public:
                                        query,
                                        selfMask,
                                        startPos);
+      else if(layerType == "rnn")
+        query = DecoderLayerRNN(decoderState,
+                                prevDecoderState,
+                                prefix_ + "_l" + std::to_string(i) + "_rnn",
+                                query,
+                                selfMask,
+                                startPos);
       else
         ABORT("Unknown auto-regressive layer type in transformer decoder {}",
               layerType);
```
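The `DecoderLayerRNN` added above is a drop-in alternative to the decoder's self-attention and average-attention blocks: it builds a single-cell `rnn::rnn` whose cell type comes from the `dec-cell` option, applies the configured `transformer-preprocess` operations to the layer input, runs the RNN over the time/batch-transposed sequence (keeping the final cell state in `decoderState` for step-wise decoding), transposes back, and applies `transformer-postprocess` with the original layer input as the residual. Schematically, as a rough summary rather than the exact operation order (which depends on the configured pre-/postprocess strings):

```latex
\tilde{x} = \mathrm{pre}(x), \qquad
h_t = \mathrm{Cell}\bigl(h_{t-1}, \tilde{x}_t\bigr), \qquad
\mathrm{out} = \mathrm{post}(h,\, x)
```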
```diff
diff --git a/src/rnn/cells.h b/src/rnn/cells.h
index 67750ac4..d245f891 100644
--- a/src/rnn/cells.h
+++ b/src/rnn/cells.h
@@ -108,6 +108,104 @@ public:
 
 /******************************************************************************/
 
+class ReLU : public Cell {
+private:
+  Expr U_, W_, b_;
+  Expr gamma1_;
+  Expr gamma2_;
+
+  bool layerNorm_;
+  float dropout_;
+
+  Expr dropMaskX_;
+  Expr dropMaskS_;
+
+public:
+  ReLU(Ptr<ExpressionGraph> graph, Ptr<Options> options) : Cell(options) {
+    int dimInput = options_->get<int>("dimInput");
+    int dimState = options_->get<int>("dimState");
+    std::string prefix = options_->get<std::string>("prefix");
+
+    layerNorm_ = options_->get<bool>("layer-normalization", false);
+    dropout_ = options_->get<float>("dropout", 0);
+
+    U_ = graph->param(prefix + "_U",
+                      {dimState, dimState},
+                      inits::diag(1.f));
+
+    if(dimInput)
+      W_ = graph->param(prefix + "_W",
+                        {dimInput, dimState},
+                        inits::glorot_uniform);
+
+    b_ = graph->param(prefix + "_b", {1, dimState}, inits::zeros);
+
+    if(dropout_ > 0.0f) {
+      if(dimInput)
+        dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
+      dropMaskS_ = graph->dropout(dropout_, {1, dimState});
+    }
+
+    if(layerNorm_) {
+      if(dimInput)
+        gamma1_ = graph->param(prefix + "_gamma1",
+                               {1, dimState},
+                               inits::ones);
+      gamma2_ = graph->param(prefix + "_gamma2",
+                             {1, dimState},
+                             inits::ones);
+    }
+  }
+
+  State apply(std::vector<Expr> inputs, State states, Expr mask = nullptr) {
+    return applyState(applyInput(inputs), states, mask);
+  }
+
+  std::vector<Expr> applyInput(std::vector<Expr> inputs) {
+    Expr input;
+    if(inputs.size() == 0)
+      return {};
+    else if(inputs.size() > 1)
+      input = concatenate(inputs, keywords::axis = -1);
+    else
+      input = inputs.front();
+
+    if(dropMaskX_)
+      input = dropout(input, dropMaskX_);
+
+    auto xW = dot(input, W_);
+
+    if(layerNorm_)
+      xW = layerNorm(xW, gamma1_);
+
+    return {xW};
+  }
+
+  State applyState(std::vector<Expr> xWs, State state, Expr mask = nullptr) {
+    Expr recState = state.output;
+
+    auto stateDropped = recState;
+    if(dropMaskS_)
+      stateDropped = dropout(recState, dropMaskS_);
+    auto sU = dot(stateDropped, U_);
+    if(layerNorm_)
+      sU = layerNorm(sU, gamma2_);
+
+    Expr output;
+    if(xWs.empty())
+      output = relu(sU + b_);
+    else {
+      output = relu(xWs.front() + sU + b_);
+    }
+    if(mask)
+      return {output * mask, state.cell};
+    else
+      return {output, state.cell};
+  }
+};
+
+/******************************************************************************/
+
 Expr gruOps(const std::vector<Expr>& nodes, bool final = false);
 
 class GRU : public Cell {
@@ -802,5 +900,275 @@ public:
     return {nextRecState, nextCellState};
   }
 };
+
+class SRU : public Cell {
+private:
+  Expr W_;
+  Expr Wr_, br_;
+  Expr Wf_, bf_;
+
+  float dropout_;
+  Expr dropMaskX_;
+
+  float layerNorm_;
+  Expr gamma_, gammaf_, gammar_;
+
+public:
+  SRU(Ptr<ExpressionGraph> graph, Ptr<Options> options) : Cell(options) {
+    int dimInput = opt<int>("dimInput");
+    int dimState = opt<int>("dimState");
+    std::string prefix = opt<std::string>("prefix");
+
+    ABORT_IF(dimInput != dimState,
+             "For SRU state and input dims have to be equal");
+
+    dropout_ = opt<float>("dropout", 0);
+
+    W_ = graph->param(prefix + "_W",
+                      {dimInput, dimInput},
+                      inits::glorot_uniform);
+
+    Wf_ = graph->param(prefix + "_Wf",
+                       {dimInput, dimInput},
+                       inits::glorot_uniform);
+    bf_ = graph->param(
+        prefix + "_bf", {1, dimInput}, inits::zeros);
+
+    Wr_ = graph->param(prefix + "_Wr",
+                       {dimInput, dimInput},
+                       inits::glorot_uniform);
+    br_ = graph->param(
+        prefix + "_br", {1, dimInput}, inits::zeros);
+
+    if(dropout_ > 0.0f) {
+      dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
+    }
+
+    if(layerNorm_) {
+      if(dimInput)
+        gamma_ = graph->param(prefix + "_gamma",
+                              {1, dimState},
+                              inits::ones);
+      gammar_ = graph->param(prefix + "_gammar",
+                             {1, dimState},
+                             inits::ones);
+      gammaf_ = graph->param(prefix + "_gammaf",
+                             {1, dimState},
+                             inits::ones);
+    }
+  }
+
+  State apply(std::vector<Expr> inputs, State state, Expr mask = nullptr) {
+    return applyState(applyInput(inputs), state, mask);
+  }
+
+  std::vector<Expr> applyInput(std::vector<Expr> inputs) {
+    ABORT_IF(inputs.empty(), "SRU expects input");
+
+    Expr input;
+    if(inputs.size() > 1)
+      input = concatenate(inputs, keywords::axis = -1);
+    else
+      input = inputs.front();
+
+    auto inputDropped = dropMaskX_ ? dropout(input, dropMaskX_) : input;
+
+    Expr x, f, r;
+    if(layerNorm_) {
+      x = layerNorm(dot(inputDropped, W_), gamma_);
+      f = layerNorm(dot(inputDropped, Wf_), gammaf_, bf_);
+      r = layerNorm(dot(inputDropped, Wr_), gammar_, br_);
+    } else {
+      x = dot(inputDropped, W_);
+      f = affine(inputDropped, Wf_, bf_);
+      r = affine(inputDropped, Wr_, br_);
+    }
+
+    return {x, f, r, input};
+  }
+
+  State applyState(std::vector<Expr> xWs, State state, Expr mask = nullptr) {
+    auto recState = state.output;
+    auto cellState = state.cell;
+
+    auto x = xWs[0];
+    auto f = xWs[1];
+    auto r = xWs[2];
+    auto input = xWs[3];
+
+    auto nextCellState = highway(cellState, x, f);  // rename to "gate"?
+    auto nextState = highway(tanh(nextCellState), input, r);
+
+    auto maskedCellState = mask ? mask * nextCellState : nextCellState;
+    auto maskedState = mask ? mask * nextState : nextState;
+
+    return {maskedState, maskedCellState};
+  }
+};
+
+class SSRU : public Cell {
+private:
+  Expr W_;
+  Expr Wf_, bf_;
+
+  float dropout_;
+  Expr dropMaskX_;
+
+  float layerNorm_;
+  Expr gamma_, gammaf_;
+
+public:
+  SSRU(Ptr<ExpressionGraph> graph, Ptr<Options> options) : Cell(options) {
+    int dimInput = options_->get<int>("dimInput");
+    int dimState = options_->get<int>("dimState");
+    std::string prefix = options->get<std::string>("prefix");
+
+    ABORT_IF(dimInput != dimState,
+             "For SSRU state and input dims have to be equal");
+
+    dropout_ = opt<float>("dropout", 0);
+
+    W_ = graph->param(prefix + "_W",
+                      {dimInput, dimInput},
+                      inits::glorot_uniform);
+
+    Wf_ = graph->param(prefix + "_Wf",
+                       {dimInput, dimInput},
+                       inits::glorot_uniform);
+    bf_ = graph->param(
+        prefix + "_bf", {1, dimInput}, inits::zeros);
+
+    if(dropout_ > 0.0f) {
+      dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
+    }
+
+    if(layerNorm_) {
+      if(dimInput)
+        gamma_ = graph->param(prefix + "_gamma",
+                              {1, dimState},
+                              inits::ones);
+      gammaf_ = graph->param(prefix + "_gammaf",
+                             {1, dimState},
+                             inits::ones);
+    }
+  }
+
+  State apply(std::vector<Expr> inputs, State state, Expr mask = nullptr) {
+    return applyState(applyInput(inputs), state, mask);
+  }
+
+  std::vector<Expr> applyInput(std::vector<Expr> inputs) {
+    ABORT_IF(inputs.empty(), "SSRU expects input");
+
+    Expr input;
+    if(inputs.size() > 1)
+      input = concatenate(inputs, keywords::axis = -1);
+    else
+      input = inputs.front();
+
+    auto inputDropped = dropMaskX_ ? dropout(input, dropMaskX_) : input;
+
+    Expr x, f;
+    if(layerNorm_) {
+      x = layerNorm(dot(inputDropped, W_), gamma_);
+      f = layerNorm(dot(inputDropped, Wf_), gammaf_, bf_);
+    } else {
+      x = dot(inputDropped, W_);
+      f = affine(inputDropped, Wf_, bf_);
+    }
+
+    return {x, f};
+  }
+
+  State applyState(std::vector<Expr> xWs, State state, Expr mask = nullptr) {
+    auto recState = state.output;
+    auto cellState = state.cell;
+
+    auto x = xWs[0];
+    auto f = xWs[1];
+
+    auto nextCellState = highway(cellState, x, f);  // rename to "gate"?
+    auto nextState = relu(nextCellState);
+
+    auto maskedCellState = mask ? mask * nextCellState : nextCellState;
+    auto maskedState = mask ? mask * nextState : nextState;
+
+    return {maskedState, maskedCellState};
+  }
+};
+
+// class LSSRU : public Cell {
+// private:
+//   Expr W_;
+//   Expr Wf_, bf_;
+
+//   float dropout_;
+//   Expr dropMaskX_;
+
+// public:
+//   LSSRU(Ptr<ExpressionGraph> graph, Ptr<Options> options) : Cell(options) {
+//     int dimInput = options_->get<int>("dimInput");
+//     int dimState = options_->get<int>("dimState");
+//     std::string prefix = options->get<std::string>("prefix");
+
+//     ABORT_IF(dimInput != dimState, "For SRU state and input dims have to be equal");
+
+//     dropout_ = opt<float>("dropout", 0);
+
+//     W_ = graph->param(prefix + "_W",
+//                       {dimInput, dimInput},
+//                       inits::glorot_uniform);
+
+//     Wf_ = graph->param(prefix + "_Wf",
+//                        {dimInput, dimInput},
+//                        inits::glorot_uniform);
+//     bf_ = graph->param(
+//         prefix + "_bf", {1, dimInput}, inits::zeros);
+
+//     if(dropout_ > 0.0f) {
+//       dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
+//     }
+//   }
+
+//   State apply(std::vector<Expr> inputs, State state, Expr mask = nullptr) {
+//     return applyState(applyInput(inputs), state, mask);
+//   }
+
+//   std::vector<Expr> applyInput(std::vector<Expr> inputs) {
+//     ABORT_IF(inputs.empty(), "Slow SRU expects input");
+
+//     Expr input;
+//     if(inputs.size() > 1)
+//       input = concatenate(inputs, keywords::axis = -1);
+//     else
+//       input = inputs.front();
+
+//     auto inputDropped = dropMaskX_ ? dropout(input, dropMaskX_) : input;
+
+//     auto x = dot(inputDropped, W_);
+//     auto f = affine(inputDropped, Wf_, bf_);
+
+//     return {x, f};
+//   }
+
+//   State applyState(std::vector<Expr> xWs, State state, Expr mask = nullptr) {
+//     auto recState = state.output;
+//     auto cellState = state.cell;
+
+//     auto x = xWs[0];
+//     auto f = xWs[1];
+
+//     auto nextCellState = highwayLinear(cellState, x, f, 2.f);  // rename to "gate"?
+//     auto nextState = relu(nextCellState);
+//     //auto nextState = nextCellState;
+
+//     auto maskedCellState = mask ? mask * nextCellState : nextCellState;
+//     auto maskedState = mask ? mask * nextState : nextState;
+
+//     return {maskedState, maskedCellState};
+//   }
+// };
+
+
 }  // namespace rnn
 }  // namespace marian
```
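Reading `highway(y, x, t)` as σ(t) ⊙ y + (1 − σ(t)) ⊙ x (its definition elsewhere in Marian; not part of this diff), the three new cells implement the recurrences below, where all weight matrices and biases are the parameters created in the constructors above. Note that in `SRU` and `SSRU` the `layerNorm_` member is declared as a `float` and never set from the options in this commit, so their layer-normalization branches are effectively unreachable as committed.

```latex
\begin{aligned}
\textbf{ReLU cell:}\quad & h_t = \mathrm{ReLU}\bigl(x_t W + h_{t-1} U + b\bigr),
  \qquad U \text{ initialized to the identity} \\[6pt]
\textbf{SRU:}\quad & f_t = \sigma(x_t W_f + b_f), \qquad r_t = \sigma(x_t W_r + b_r) \\
& c_t = f_t \odot c_{t-1} + (1 - f_t) \odot (x_t W) \\
& h_t = r_t \odot \tanh(c_t) + (1 - r_t) \odot x_t \\[6pt]
\textbf{SSRU:}\quad & c_t = f_t \odot c_{t-1} + (1 - f_t) \odot (x_t W) \\
& h_t = \mathrm{ReLU}(c_t)
\end{aligned}
```

SSRU thus drops SRU's reset gate and highway connection to the input and replaces tanh with ReLU, which is why it only needs the parameters W, W_f and b_f.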
```diff
diff --git a/src/rnn/constructors.h b/src/rnn/constructors.h
index 8aad788a..30351dc2 100644
--- a/src/rnn/constructors.h
+++ b/src/rnn/constructors.h
@@ -63,6 +63,18 @@ public:
       auto cell = New<Tanh>(graph_, options_);
       cell->setLazyInputs(inputs_);
       return cell;
+    } else if(type == "relu") {
+      auto cell = New<ReLU>(graph_, options_);
+      cell->setLazyInputs(inputs_);
+      return cell;
+    } else if(type == "sru") {
+      auto cell = New<SRU>(graph_, options_);
+      cell->setLazyInputs(inputs_);
+      return cell;
+    } else if(type == "ssru") {
+      auto cell = New<SSRU>(graph_, options_);
+      cell->setLazyInputs(inputs_);
+      return cell;
     } else {
       ABORT("Unknown RNN cell type");
     }
```
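With the factory cases above in place, the new cell types become reachable through the existing `rnn::rnn` builder, which is exactly how `DecoderLayerRNN` constructs them. A minimal usage sketch follows; the dimensions, the parameter prefix, the helper name `applySSRU`, and the assumption that `input` is already time-major (and that the single-argument `transduce` overload is available) are illustrative, not taken from the diff:

```cpp
#include "rnn/constructors.h"  // rnn::rnn() and rnn::cell() factory functions

using namespace marian;

// Hypothetical helper: builds a single-cell SSRU-based RNN, mirroring the
// builder chain used in DecoderLayerRNN above, and runs it over a sequence.
Expr applySSRU(Ptr<ExpressionGraph> graph, Expr input /* assumed time-major */) {
  auto rnn = rnn::rnn(graph)                 //
      ("type", "ssru")                       // "relu", "sru" and "ssru" now resolve here
      ("prefix", "decoder_l1_rnn")           // illustrative parameter name prefix
      ("dimInput", 512)                      // illustrative dimensions
      ("dimState", 512)                      // SRU/SSRU abort unless dimInput == dimState
      ("dropout", 0.1f)                      //
      ("layer-normalization", false)         //
      .push_back(rnn::cell(graph))           //
      .construct();

  auto output = rnn->transduce(input);       // whole-sequence application
  auto last = rnn->lastCellStates()[0];      // final cell state, as kept for step-wise decoding
  (void)last;
  return output;
}
```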