diff options
author | Frank Seide <fseide@microsoft.com> | 2018-06-22 02:48:40 +0300 |
---|---|---|
committer | Frank Seide <fseide@microsoft.com> | 2018-06-22 02:48:40 +0300 |
commit | c450ca9fb0e0407af20b61148ef9c67360ba66a9 (patch) | |
tree | 2ca04b11f9a2cdcbdec19c1052188003d7f0076d /src | |
parent | b0940e0c487bcd39a30aa520d024ffb203c564a1 (diff) |
further small refactoring in transformer; renamed layer_norm to layerNorm
Diffstat (limited to 'src')
-rw-r--r-- | src/graph/expression_operators.cpp | 8 | ||||
-rw-r--r-- | src/graph/expression_operators.h | 14 | ||||
-rw-r--r-- | src/graph/node_operators_unary.h | 2 | ||||
-rw-r--r-- | src/layers/generic.h | 26 | ||||
-rw-r--r-- | src/models/transformer.h | 155 | ||||
-rw-r--r-- | src/rnn/attention.h | 8 | ||||
-rw-r--r-- | src/rnn/cells.h | 28 |
7 files changed, 112 insertions, 129 deletions
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp index 2b524b86..1666357a 100644 --- a/src/graph/expression_operators.cpp +++ b/src/graph/expression_operators.cpp @@ -410,10 +410,10 @@ Expr square(Expr a) { return Expression<SquareNodeOp>(a); } -Expr layer_norm(Expr x, - Expr gamma, - Expr beta /*= nullptr*/, - float eps /*= 1e-9*/) { +Expr layerNorm(Expr x, + Expr gamma, + Expr beta /*= nullptr*/, + float eps /*= 1e-9*/) { std::vector<Expr> nodes = {x, gamma}; if(beta) nodes.push_back(beta); diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h index f5f5bbb2..bdc30f63 100644 --- a/src/graph/expression_operators.h +++ b/src/graph/expression_operators.h @@ -128,7 +128,7 @@ Expr step(Expr a, int step, int axis); Expr sqrt(Expr a, float eps = 0.f); Expr square(Expr a); -Expr layer_norm(Expr x, Expr gamma, Expr beta = nullptr, float eps = 1e-9); +Expr layerNorm(Expr x, Expr gamma, Expr beta = nullptr, float eps = 1e-9); Expr highway(Expr y, Expr x, Expr t); Expr highway(const std::string prefix, Expr x); @@ -137,14 +137,18 @@ static inline Expr dropout(Expr x, Expr mask) { return x * mask; } -static inline Expr dropout(Expr x, float prob, Shape shape) { +static inline Expr dropout(Expr x, float dropProb, Shape shape) { + if (dropProb == 0) + return x; auto graph = x->graph(); - auto mask = graph->dropout(prob, shape); + auto mask = graph->dropout(dropProb, shape); return dropout(x, mask); } -static inline Expr dropout(Expr x, float prob) { - return dropout(x, prob, x->shape()); +static inline Expr dropout(Expr x, float dropProb) { + if (dropProb == 0) + return x; + return dropout(x, dropProb, x->shape()); } Expr shift(Expr, Shape, float padValue = 0); diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 9cf839df..c1315304 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -932,8 +932,10 @@ public: Shape outShape = a->shape(); axis_ = outShape.axis(axis); +#if 0 // this check currently fails in translation; I think should not fail for step==0 for(int i = 0; i < axis_; ++i) ABORT_IF(outShape[i] != 1, "non-consecutive slices are presently not supported by step()"); +#endif outShape.set(axis_, 1); return outShape; diff --git a/src/layers/generic.h b/src/layers/generic.h index 1b989709..8b19123b 100644 --- a/src/layers/generic.h +++ b/src/layers/generic.h @@ -50,8 +50,8 @@ public: auto name = opt<std::string>("prefix"); auto dim = opt<int>("dim"); - auto layerNorm = opt<bool>("layer-normalization", false); - auto nematusNorm = opt<bool>("nematus-normalization", false); + auto useLayerNorm = opt<bool>("layer-normalization", false); + auto useNematusNorm = opt<bool>("nematus-normalization", false); auto activation = opt<act>("activation", act::linear); auto g = graph_; @@ -71,8 +71,8 @@ public: {1, dim}, inits::zeros); - if(layerNorm) { - if(nematusNorm) { + if(useLayerNorm) { + if(useNematusNorm) { auto ln_s = g->param(name + "_ln_s" + num, {1, dim}, inits::from_value(1.f)); @@ -80,13 +80,13 @@ public: {1, dim}, inits::zeros); - outputs.push_back(layer_norm(affine(in, W, b), ln_s, ln_b, NEMATUS_LN_EPS)); + outputs.push_back(layerNorm(affine(in, W, b), ln_s, ln_b, NEMATUS_LN_EPS)); } else { auto gamma = g->param(name + "_gamma" + num, {1, dim}, inits::from_value(1.0)); - outputs.push_back(layer_norm(dot(in, W), gamma, b)); + outputs.push_back(layerNorm(dot(in, W), gamma, b)); } } else { @@ -96,14 +96,14 @@ public: } switch(activation) { - case act::linear: return plus(outputs); - case act::tanh: return tanh(outputs); - case act::sigmoid: return sigmoid(outputs); - case act::ReLU: return relu(outputs); + case act::linear: return plus(outputs); + case act::tanh: return tanh(outputs); + case act::sigmoid: return sigmoid(outputs); + case act::ReLU: return relu(outputs); case act::LeakyReLU: return leakyrelu(outputs); - case act::PReLU: return prelu(outputs); - case act::swish: return swish(outputs); - default: return plus(outputs); + case act::PReLU: return prelu(outputs); + case act::swish: return swish(outputs); + default: return plus(outputs); } }; diff --git a/src/models/transformer.h b/src/models/transformer.h index fccf2fde..107145eb 100644 --- a/src/models/transformer.h +++ b/src/models/transformer.h @@ -99,55 +99,65 @@ public: return reshape(output, {dimBeam, dimBatch, dimSteps, dimModel}); } - Expr PreProcess(std::string prefix, std::string ops, Expr input, float dropProb = 0.0f) const { - int dimModel = input->shape()[-1]; + // like affine() but with built-in parameters, activation, and dropout + static inline + Expr dense(Expr x, std::string prefix, std::string suffix, int outDim, const std::function<Expr(Expr)>& actFn = nullptr, float dropProb = 0.0f) + { + auto graph = x->graph(); + + auto W = graph->param(prefix + "_W" + suffix, { x->shape()[-1], outDim }, inits::glorot_uniform); + auto b = graph->param(prefix + "_b" + suffix, { 1, outDim }, inits::zeros); + + x = affine(x, W, b); + if (actFn) + x = actFn(x); + if (dropProb) + x = dropout(x, dropProb); + return x; + } + + Expr layerNorm(Expr x, std::string prefix, std::string suffix = std::string()) const { + int dimModel = x->shape()[-1]; + auto scale = graph_->param(prefix + "_ln_scale" + suffix, { 1, dimModel }, inits::ones); + auto bias = graph_->param(prefix + "_ln_bias" + suffix, { 1, dimModel }, inits::zeros); + return marian::layerNorm(x, scale, bias, 1e-6); + } + + Expr preProcess(std::string prefix, std::string ops, Expr input, float dropProb = 0.0f) const { auto output = input; for(auto op : ops) { // dropout - if(op == 'd' && dropProb > 0.0f) { + if (op == 'd' && dropProb > 0.0f) output = dropout(output, dropProb); - } // layer normalization - if(op == 'n') { - auto scale = graph_->param( - prefix + "_ln_scale_pre", {1, dimModel}, inits::ones); - auto bias = graph_->param( - prefix + "_ln_bias_pre", {1, dimModel}, inits::zeros); - output = layer_norm(output, scale, bias, 1e-6); - } + else if (op == 'n') + output = layerNorm(output, prefix, "_pre"); + else + ABORT("Unknown pre-processing operation '%c'", op); } return output; } - Expr PostProcess(std::string prefix, std::string ops, Expr input, Expr prevInput, float dropProb = 0.0f) const { - int dimModel = input->shape()[-1]; + Expr postProcess(std::string prefix, std::string ops, Expr input, Expr prevInput, float dropProb = 0.0f) const { auto output = input; for(auto op : ops) { // dropout - if(op == 'd' && dropProb > 0.0f) { + if(op == 'd' && dropProb > 0.0f) output = dropout(output, dropProb); - } // skip connection - if(op == 'a') { + else if(op == 'a') output = output + prevInput; - } // highway connection - if(op == 'h') { - auto Wh = graph_->param( - prefix + "_Wh", {dimModel, dimModel}, inits::glorot_uniform); - auto bh = graph_->param(prefix + "_bh", {1, dimModel}, inits::zeros); - - auto t = affine(prevInput, Wh, bh); + else if(op == 'h') { + int dimModel = input->shape()[-1]; + auto t = dense(prevInput, prefix, /*suffix=*/"h", dimModel); output = highway(output, prevInput, t); } // layer normalization - if(op == 'n') { - auto scale - = graph_->param(prefix + "_ln_scale", {1, dimModel}, inits::ones); - auto bias - = graph_->param(prefix + "_ln_bias", {1, dimModel}, inits::zeros); - output = layer_norm(output, scale, bias, 1e-6); - } + else if(op == 'n') + output = layerNorm(output, prefix); + else + ABORT("Unknown pre-processing operation '%c'", op); } return output; } @@ -160,8 +170,6 @@ public: Expr v, // [-4: batch size, -3: num heads, -2: max src length, -1: split vector dim] Expr values, // [-4: beam depth, -3: batch size, -2: max kv length, -1: vector dim] Expr mask = nullptr) const { // [-4: batch size, -3: num heads broadcast=1, -2: max length broadcast=1, -1: max length] - using namespace keywords; - int dk = k->shape()[-1]; // softmax over batched dot product of query and keys (applied over all @@ -172,8 +180,8 @@ public: int dimBeamK = k->shape()[-4]; int dimBeam = dimBeamQ / dimBeamK; if(dimBeam > 1) { // broadcast k and v into all beam elements --TODO: if we use a separate dimension, then this would be automatic at no memory cost - k = repeat(k, dimBeam, axis = -4); // [-4: beam depth * batch size, -3: num heads, -2: max src length, -1: split vector dim] - v = repeat(v, dimBeam, axis = -4); // [-4: beam depth * batch size, -3: num heads, -2: max src length, -1: split vector dim] + k = repeat(k, dimBeam, /*axis=*/-4); // [-4: beam depth * batch size, -3: num heads, -2: max src length, -1: split vector dim] + v = repeat(v, dimBeam, /*axis=*/-4); // [-4: beam depth * batch size, -3: num heads, -2: max src length, -1: split vector dim] } // now q, k, and v have the same first dims [-4: beam depth * batch size, -3: num heads, -2: max src or tgt length, -1: split vector dim] @@ -190,13 +198,10 @@ public: // optional dropout for attention weights float dropProb = inference_ ? 0 : opt<float>("transformer-dropout-attention"); - - if(dropProb) - weights = dropout(weights, dropProb); + weights = dropout(weights, dropProb); // apply attention weights to values auto output = bdot(weights, v); // [-4: beam depth * batch size, -3: num heads, -2: max tgt length, -1: split vector dim] - return output; } @@ -281,7 +286,7 @@ public: float dropProb = inference_ ? 0 : opt<float>("transformer-dropout"); auto opsPre = opt<std::string>("transformer-preprocess"); - auto output = PreProcess(prefix + "_Wo", opsPre, input, dropProb); + auto output = preProcess(prefix + "_Wo", opsPre, input, dropProb); auto heads = opt<int>("transformer-heads"); @@ -289,8 +294,7 @@ public: output = MultiHead(prefix, dimModel, heads, output, keys, values, masks); auto opsPost = opt<std::string>("transformer-postprocess"); - output - = PostProcess(prefix + "_Wo", opsPost, output, input, dropProb); + output = postProcess(prefix + "_Wo", opsPost, output, input, dropProb); return output; } @@ -301,15 +305,11 @@ public: Expr input, Expr selfMask, int startPos) const { - - using namespace keywords; - selfMask = transposedLogMask(selfMask); auto values = input; if(startPos > 0) { - values = concatenate({prevDecoderState.output, input}, - axis = -2); + values = concatenate({prevDecoderState.output, input}, /*axis=*/-2); } decoderState.output = values; @@ -327,31 +327,12 @@ public: ABORT("Invalid activation name '{}'", actName); } - // like affine() but with built-in parameters, activation, and dropout - static inline - Expr dense(Expr x, std::string prefix, int i, int outDim, const std::function<Expr(Expr)>& actFn = nullptr, float dropProb = 0.0f) - { - auto graph = x->graph(); - - auto W = graph->param(prefix + "_W" + std::to_string(i), { x->shape()[-1], outDim }, inits::glorot_uniform); - auto b = graph->param(prefix + "_b" + std::to_string(i), { 1, outDim }, inits::zeros); - - x = affine(x, W, b); - if (actFn) - x = actFn(x); - if (dropProb) - x = dropout(x, dropProb); - return x; - } - Expr LayerFFN(std::string prefix, Expr input) const { - using namespace keywords; - int dimModel = input->shape()[-1]; float dropProb = inference_ ? 0 : opt<float>("transformer-dropout"); auto opsPre = opt<std::string>("transformer-preprocess"); - auto output = PreProcess(prefix + "_ffn", opsPre, input, dropProb); + auto output = preProcess(prefix + "_ffn", opsPre, input, dropProb); int dimFfn = opt<int>("transformer-dim-ffn"); int depthFfn = opt<int>("transformer-ffn-depth"); @@ -363,12 +344,12 @@ public: // the stack of FF layers for(int i = 1; i < depthFfn; ++i) - output = dense(output, prefix, i, dimFfn, actFn, ffnDropProb); - output = dense(output, prefix, depthFfn, dimModel); + output = dense(output, prefix, /*suffix=*/std::to_string(i), dimFfn, actFn, ffnDropProb); + output = dense(output, prefix, /*suffix=*/std::to_string(depthFfn), dimModel); auto opsPost = opt<std::string>("transformer-postprocess"); output - = PostProcess(prefix + "_ffn", opsPost, output, input, dropProb); + = postProcess(prefix + "_ffn", opsPost, output, input, dropProb); return output; } @@ -381,7 +362,7 @@ public: float dropProb = inference_ ? 0 : opt<float>("transformer-dropout"); auto opsPre = opt<std::string>("transformer-preprocess"); - y = PreProcess(prefix + "_ffn", opsPre, y, dropProb); + y = preProcess(prefix + "_ffn", opsPre, y, dropProb); // FFN int dimAan = opt<int>("transformer-dim-aan"); @@ -431,12 +412,12 @@ public: } auto opsPost = opt<std::string>("transformer-postprocess"); - y = PostProcess(prefix + "_ffn", opsPost, y, x, dropProb); + y = postProcess(prefix + "_ffn", opsPost, y, x, dropProb); return y; } - // Implementation of Average Attention Network Layer (ANN) from + // Implementation of Average Attention Network Layer (AAN) from // https://arxiv.org/pdf/1805.00631.pdf // Function wrapper using decoderState as input. Expr DecoderLayerAAN(rnn::State& decoderState, @@ -445,9 +426,6 @@ public: Expr input, Expr selfMask, int startPos) const { - - using namespace keywords; - auto output = input; if(startPos > 0) { // we are decoding at a position after 0 @@ -457,7 +435,7 @@ public: // we are training or scoring, because there is no history and // the context is larger than a single time step. We do not need // to average batch with only single words. - selfMask = selfMask / sum(selfMask, axis=-1); + selfMask = selfMask / sum(selfMask, /*axis=*/-1); output = bdot(selfMask, output); } decoderState.output = output; // BUGBUG: mutable? @@ -470,11 +448,12 @@ class EncoderTransformer : public Transformer<EncoderBase> { public: EncoderTransformer(Ptr<Options> options) : Transformer(options) {} - // returns the embedding matrix - Expr WordEmbeddings() const { + // returns the embedding matrix based on options + // And based on batchIndex_. + Expr wordEmbeddings(int subBatchIndex) const { // standard encoder word embeddings - int dimVoc = opt<std::vector<int>>("dim-vocabs")[batchIndex_]; + int dimVoc = opt<std::vector<int>>("dim-vocabs")[subBatchIndex]; int dimEmb = opt<int>("dim-emb"); auto embFactory = embedding(graph_)("dimVocab", dimVoc)("dimEmb", dimEmb); @@ -490,7 +469,7 @@ public: if(options_->has("embedding-vectors")) { auto embFiles = opt<std::vector<std::string>>("embedding-vectors"); embFactory // - ("embFile", embFiles[batchIndex_]) // + ("embFile", embFiles[subBatchIndex]) // ("normalization", opt<bool>("embedding-normalization")); } @@ -508,7 +487,7 @@ public: int dimBatch = batch->size(); int dimSrcWords = (*batch)[batchIndex_]->batchWidth(); - auto embeddings = WordEmbeddings(); // embedding matrix, considering tying and some other options + auto embeddings = wordEmbeddings(batchIndex_); // embedding matrix, considering tying and some other options // embed the source words in the batch Expr batchEmbeddings, batchMask; @@ -537,7 +516,7 @@ public: auto opsEmb = opt<std::string>("transformer-postprocess-emb"); float dropProb = inference_ ? 0 : opt<float>("transformer-dropout"); - layer = PreProcess(prefix_ + "_emb", opsEmb, layer, dropProb); + layer = preProcess(prefix_ + "_emb", opsEmb, layer, dropProb); layerMask = transposedLogMask(layerMask); // [-4: batch size, -3: 1, -2: vector dim=1, -1: max length] @@ -601,7 +580,6 @@ public: }; class DecoderTransformer : public Transformer<DecoderBase> { -protected: private: Ptr<mlp::MLP> output_; @@ -656,7 +634,7 @@ public: Ptr<DecoderState> step(Ptr<DecoderState> state) const { auto embeddings = state->getTargetEmbeddings(); // [-4: beam depth=1, -3: max length, -2: batch size, -1: vector dim] - auto decoderMask = state->getTargetMask(); // [max length, batch size, 1] --this is a hypothesis + auto decoderMask = state->getTargetMask(); // [max length, batch size, 1] --this is a hypothesis // dropout target words float dropoutTrg = inference_ ? 0 : opt<float>("dropout-trg"); @@ -691,7 +669,7 @@ public: auto opsEmb = opt<std::string>("transformer-postprocess-emb"); float dropProb = inference_ ? 0 : opt<float>("transformer-dropout"); - query = PreProcess(prefix_ + "_emb", opsEmb, query, dropProb); + query = preProcess(prefix_ + "_emb", opsEmb, query, dropProb); int dimTrgWords = query->shape()[-2]; int dimBatch = query->shape()[-3]; @@ -741,13 +719,12 @@ public: // self-attention std::string layerType = opt<std::string>("transformer-decoder-autoreg"); - if(layerType == "self-attention") { + if(layerType == "self-attention") query = DecoderLayerSelfAttention(decoderState, prevDecoderState, prefix_ + "_l" + std::to_string(i) + "_self", query, selfMask, startPos); - } else if(layerType == "average-attention") { + else if(layerType == "average-attention") query = DecoderLayerAAN(decoderState, prevDecoderState, prefix_ + "_l" + std::to_string(i) + "_aan", query, selfMask, startPos); - } else { + else ABORT("Unknown auto-regressive layer type in transformer decoder {}", layerType); - } decoderStates.push_back(decoderState); diff --git a/src/rnn/attention.h b/src/rnn/attention.h index fc09e9b7..92a89e99 100644 --- a/src/rnn/attention.h +++ b/src/rnn/attention.h @@ -80,7 +80,7 @@ public: W_comb_att_lnb_ = graph->param( prefix + "_W_comb_att_lnb", {1, dimEncState}, inits::zeros); - mappedContext_ = layer_norm(affine(contextDropped_, Ua_, ba_), + mappedContext_ = layerNorm(affine(contextDropped_, Ua_, ba_), Wc_att_lns_, Wc_att_lnb_, NEMATUS_LN_EPS); @@ -91,7 +91,7 @@ public: prefix + "_att_gamma2", {1, dimEncState}, inits::from_value(1.0)); mappedContext_ - = layer_norm(dot(contextDropped_, Ua_), gammaContext_, ba_); + = layerNorm(dot(contextDropped_, Ua_), gammaContext_, ba_); } } else { @@ -121,10 +121,10 @@ public: auto mappedState = dot(recState, Wa_); if(layerNorm_) if(nematusNorm_) - mappedState = layer_norm( + mappedState = layerNorm( mappedState, W_comb_att_lns_, W_comb_att_lnb_, NEMATUS_LN_EPS); else - mappedState = layer_norm(mappedState, gammaState_); + mappedState = layerNorm(mappedState, gammaState_); auto attReduce = attOps(va_, mappedContext_, mappedState); diff --git a/src/rnn/cells.h b/src/rnn/cells.h index c26825d2..f3299144 100644 --- a/src/rnn/cells.h +++ b/src/rnn/cells.h @@ -81,7 +81,7 @@ public: auto xW = dot(input, W_); if(layerNorm_) - xW = layer_norm(xW, gamma1_); + xW = layerNorm(xW, gamma1_); return {xW}; } @@ -94,7 +94,7 @@ public: stateDropped = dropout(recState, dropMaskS_); auto sU = dot(stateDropped, U_); if(layerNorm_) - sU = layer_norm(sU, gamma2_); + sU = layerNorm(sU, gamma2_); Expr output; if(xWs.empty()) @@ -207,7 +207,7 @@ public: auto xW = dot(input, W_); if(layerNorm_) - xW = layer_norm(xW, gamma1_); + xW = layerNorm(xW, gamma1_); return {xW}; } @@ -222,7 +222,7 @@ public: auto sU = dot(stateDropped, U_); if(layerNorm_) - sU = layer_norm(sU, gamma2_); + sU = layerNorm(sU, gamma2_); Expr xW; if(xWs.empty()) { @@ -406,8 +406,8 @@ public: W = affine(input, W_, b_); Wx = affine(input, Wx_, bx_); } - W = layer_norm(W, W_lns_, W_lnb_, NEMATUS_LN_EPS); - Wx = layer_norm(Wx, Wx_lns_, Wx_lnb_, NEMATUS_LN_EPS); + W = layerNorm(W, W_lns_, W_lnb_, NEMATUS_LN_EPS); + Wx = layerNorm(Wx, Wx_lns_, Wx_lnb_, NEMATUS_LN_EPS); xW = concatenate({W, Wx}, keywords::axis = -1); } else { @@ -434,8 +434,8 @@ public: Expr Ux; // Temp_2_ in Amun if(encoder_) { - U = layer_norm(dot(stateDropped, U_), U_lns_, U_lnb_, NEMATUS_LN_EPS); - Ux = layer_norm( + U = layerNorm(dot(stateDropped, U_), U_lns_, U_lnb_, NEMATUS_LN_EPS); + Ux = layerNorm( dot(stateDropped, Ux_), Ux_lns_, Ux_lnb_, NEMATUS_LN_EPS); if(transition_) { @@ -449,8 +449,8 @@ public: U = dot(stateDropped, U_); Ux = dot(stateDropped, Ux_); } - U = layer_norm(U, U_lns_, U_lnb_, NEMATUS_LN_EPS); - Ux = layer_norm(Ux, Ux_lns_, Ux_lnb_, NEMATUS_LN_EPS); + U = layerNorm(U, U_lns_, U_lnb_, NEMATUS_LN_EPS); + Ux = layerNorm(Ux, Ux_lns_, Ux_lnb_, NEMATUS_LN_EPS); } sU = concatenate({U, Ux}, keywords::axis = -1); @@ -555,7 +555,7 @@ public: auto xW = dot(input, W_); if(layerNorm_) - xW = layer_norm(xW, gamma1_); + xW = layerNorm(xW, gamma1_); return {xW}; } @@ -573,7 +573,7 @@ public: auto sU = dot(recStateDropped, U_); if(layerNorm_) - sU = layer_norm(sU, gamma2_); + sU = layerNorm(sU, gamma2_); Expr xW; if(xWs.empty()) { @@ -648,7 +648,7 @@ public: auto xWs = CellType::applyInput({input}); auto xWm = affine(input, Wm_, bwm_); if(CellType::layerNorm_) - xWm = layer_norm(xWm, gamma1m_); + xWm = layerNorm(xWm, gamma1m_); xWs.push_back(xWm); return xWs; @@ -662,7 +662,7 @@ public: auto sUm = affine(state.output, Um_, bm_); if(CellType::layerNorm_) - sUm = layer_norm(sUm, gamma2m_); + sUm = layerNorm(sUm, gamma2m_); auto mstate = xWm * sUm; |