Diffstat (limited to 'src/models/transformer.h')
-rw-r--r--  src/models/transformer.h | 36
1 file changed, 25 insertions(+), 11 deletions(-)
diff --git a/src/models/transformer.h b/src/models/transformer.h
index ec68b801..95a55d3a 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -148,8 +148,7 @@ public:
     int dimDepth = dimModel / dimHeads;
 
-    auto output
-        = reshape(input, {dimBatch * dimBeam, dimSteps, dimHeads, dimDepth});
+    auto output = reshape(input, {dimBatch * dimBeam, dimSteps, dimHeads, dimDepth});
 
     return transpose(output, {0, 2, 1, 3}); // [dimBatch*dimBeam, dimHeads, dimSteps, dimDepth]
   }
 
@@ -364,9 +363,9 @@ public:
 
   Expr LayerAttention(std::string prefix,
                       Expr input,          // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
-                      const Expr& keys,    // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
-                      const Expr& values,  // ...?
-                      const Expr& mask,    // [-4: batch size, -3: num heads broadcast=1, -2: max length broadcast=1, -1: max length]
+                      Expr keys,           // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
+                      Expr values,         // ...?
+                      Expr mask,           // [-4: batch size, -3: num heads broadcast=1, -2: max length broadcast=1, -1: max length]
                       int dimHeads,
                       bool cache = false,
                       bool saveAttentionWeights = false) {
@@ -376,6 +375,12 @@ public:
     auto opsPre = opt<std::string>("transformer-preprocess");
     auto output = preProcess(prefix + "_Wo", opsPre, input, dropProb);
 
+    // fixes missing norm for keys and values in self-attention with pre-norm
+    if(input == keys)
+      keys = output;
+    if(input == values)
+      values = output;
+
     // multi-head self-attention over previous input
     output = MultiHead(prefix, dimModel, dimHeads, output, keys, values, mask, cache, saveAttentionWeights);
 
@@ -403,7 +408,7 @@ public:
                           opt<int>("transformer-heads"), /*cache=*/false);
   }
 
-  Expr LayerFFN(std::string prefix, Expr input) const {
+  Expr LayerFFN(std::string prefix, Expr input, bool isDecoder=false) const {
     int dimModel = input->shape()[-1];
 
     float dropProb = inference_ ? 0 : opt<float>("transformer-dropout");
@@ -411,13 +416,22 @@ public:
     auto output = preProcess(prefix + "_ffn", opsPre, input, dropProb);
 
     auto actName = opt<std::string>("transformer-ffn-activation");
+
     int dimFfn = opt<int>("transformer-dim-ffn");
     int depthFfn = opt<int>("transformer-ffn-depth");
-    float ffnDropProb
-      = inference_ ? 0 : opt<float>("transformer-dropout-ffn");
-
+    if(isDecoder) {
+      int decDimFfn = opt<int>("transformer-decoder-dim-ffn", 0);
+      if(decDimFfn != 0)
+        dimFfn = decDimFfn;
+
+      int decDepthFfn = opt<int>("transformer-decoder-ffn-depth", 0);
+      if(decDepthFfn != 0)
+        depthFfn = decDepthFfn;
+    }
+
     ABORT_IF(depthFfn < 1, "Filter depth {} is smaller than 1", depthFfn);
-
+
+    float ffnDropProb = inference_ ? 0 : opt<float>("transformer-dropout-ffn");
+
     auto initFn = inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f);
 
     // the stack of FF layers
@@ -866,7 +880,7 @@ public:
       // remember decoder state
      decoderStates.push_back(decoderState);
 
-      query = LayerFFN(prefix_ + "_l" + layerNo + "_ffn", query); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
+      query = LayerFFN(prefix_ + "_l" + layerNo + "_ffn", query, /*isDecoder=*/true); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
 
       checkpoint(query);
     }
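
The first functional change re-routes keys and values through the pre-norm output when they alias the attention input. Below is a minimal standalone sketch of that identity check; Node, Expr and preNorm are hypothetical stand-ins (in Marian, Expr is a shared-pointer-like handle, so == compares graph-node identity, not values):

    #include <memory>

    struct Node {};                      // hypothetical stand-in for a graph node
    using Expr = std::shared_ptr<Node>;  // Expr behaves like a shared pointer

    Expr preNorm(Expr /*x*/) { return std::make_shared<Node>(); }  // stand-in for preProcess()

    Expr layerAttention(Expr input, Expr keys, Expr values) {
      Expr output = preNorm(input);  // pre-norm was previously applied to the query path only
      // Self-attention passes the same expression as input, keys and values;
      // the pointer comparison detects that case and routes keys/values
      // through the normalized output instead of the raw residual input:
      if(input == keys)
        keys = output;
      if(input == values)
        values = output;
      return output;  // MultiHead(output, keys, values, ...) would follow here
    }

For cross-attention the keys/values come from the encoder and are distinct nodes, so the check leaves them untouched.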
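The second change lets the decoder's feed-forward block use its own width and depth. The fallback logic is sketched below with a std::map standing in for Marian's options object; the getOpt helper and the default values are illustrative, while the option names come from the diff. A decoder-specific value of 0 means "unset", so the shared setting is kept:

    #include <map>
    #include <string>

    // illustrative helper; Marian's real accessor is opt<int>(name, default)
    static int getOpt(const std::map<std::string, int>& opts,
                      const std::string& name, int def) {
      auto it = opts.find(name);
      return it != opts.end() ? it->second : def;
    }

    void resolveFfn(const std::map<std::string, int>& opts, bool isDecoder,
                    int& dimFfn, int& depthFfn) {
      dimFfn   = getOpt(opts, "transformer-dim-ffn", 2048);  // defaults illustrative
      depthFfn = getOpt(opts, "transformer-ffn-depth", 2);
      if(isDecoder) {
        // decoder-specific values win only when explicitly set (non-zero)
        if(int d = getOpt(opts, "transformer-decoder-dim-ffn", 0))
          dimFfn = d;
        if(int d = getOpt(opts, "transformer-decoder-ffn-depth", 0))
          depthFfn = d;
      }
    }

Only the decoder call site passes /*isDecoder=*/true, so encoder layers are unaffected unless the decoder-specific options are set.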