github.com/marian-nmt/marian.git
Diffstat (limited to 'src/models/transformer.h')
-rw-r--r--  src/models/transformer.h | 36
1 file changed, 25 insertions(+), 11 deletions(-)
diff --git a/src/models/transformer.h b/src/models/transformer.h
index ec68b801..95a55d3a 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -148,8 +148,7 @@ public:
int dimDepth = dimModel / dimHeads;
- auto output
- = reshape(input, {dimBatch * dimBeam, dimSteps, dimHeads, dimDepth});
+ auto output = reshape(input, {dimBatch * dimBeam, dimSteps, dimHeads, dimDepth});
return transpose(output, {0, 2, 1, 3}); // [dimBatch*dimBeam, dimHeads, dimSteps, dimDepth]
}
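
The hunk above only reflows the reshape onto a single line; the head split itself is unchanged: the input is reshaped to [dimBatch*dimBeam, dimSteps, dimHeads, dimDepth] and then transposed so the heads axis precedes the steps axis. A minimal standalone sketch of that index arithmetic (toy dimensions, not Marian code):

#include <cassert>
#include <cstdio>

int main() {
  const int dimSteps = 3, dimHeads = 4, dimModel = 8;
  assert(dimModel % dimHeads == 0);
  const int dimDepth = dimModel / dimHeads; // per-head feature size
  // After reshape to [batch*beam, steps, heads, depth] and transpose
  // {0,2,1,3}, the element at (b, h, s, d) comes from position
  // (b, s, h, d) of the reshaped tensor:
  auto srcOffset = [&](int b, int h, int s, int d) {
    return ((b * dimSteps + s) * dimHeads + h) * dimDepth + d;
  };
  std::printf("flat source offset of (b=1, h=2, s=0, d=1): %d\n",
              srcOffset(1, 2, 0, 1));
  return 0;
}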
@@ -364,9 +363,9 @@ public:
Expr LayerAttention(std::string prefix,
Expr input, // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
- const Expr& keys, // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
- const Expr& values, // ...?
- const Expr& mask, // [-4: batch size, -3: num heads broadcast=1, -2: max length broadcast=1, -1: max length]
+ Expr keys, // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
+ Expr values, // ...?
+ Expr mask, // [-4: batch size, -3: num heads broadcast=1, -2: max length broadcast=1, -1: max length]
int dimHeads,
bool cache = false,
bool saveAttentionWeights = false) {
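
Dropping const Expr& in favour of by-value Expr is what enables the fix in the next hunk: Expr is a reference-counted handle to a graph node, so a by-value parameter can be rebound inside the function without affecting the caller. A self-contained sketch of those semantics, with std::shared_ptr standing in for Expr (illustrative names, not Marian's API):

#include <cstdio>
#include <memory>

using ExprLike = std::shared_ptr<int>; // stand-in for Marian's Expr handle

void layer(ExprLike input, ExprLike keys) {
  auto output = std::make_shared<int>(*input + 1); // stands in for preProcess()
  if(input == keys)  // handles compare equal when they point to the same node
    keys = output;   // rebinding only changes the local copy
  std::printf("inside layer, keys = %d\n", *keys);
}

int main() {
  auto x = std::make_shared<int>(41);
  layer(x, x);                          // self-attention: keys alias the input
  std::printf("caller's x = %d\n", *x); // still 41, unaffected
  return 0;
}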
@@ -376,6 +375,12 @@ public:
auto opsPre = opt<std::string>("transformer-preprocess");
auto output = preProcess(prefix + "_Wo", opsPre, input, dropProb);
+ // fixes missing norm for keys and values in self-attention with pre-norm
+ if(input == keys)
+ keys = output;
+ if(input == values)
+ values = output;
+
// multi-head self-attention over previous input
output = MultiHead(prefix, dimModel, dimHeads, output, keys, values, mask, cache, saveAttentionWeights);
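
The new branch is the substance of this commit: when the pre-process ops include normalization (pre-norm), preProcess applies the layer norm before the sublayer, so in self-attention the keys and values must also see the normalized tensor, not the raw input. A self-contained toy sketch of the ordering (assumed names preNorm/attend, scalars standing in for tensors):

#include <cstdio>

using Expr = double; // scalar stand-in for Marian's graph expression

Expr preNorm(Expr x) { return x / 2; }                    // toy "layer norm"
Expr attend(Expr q, Expr k, Expr v) { return q + k + v; } // toy attention

Expr selfAttnBefore(Expr input) {  // pre-fix: keys/values skip the norm
  Expr q = preNorm(input);
  return attend(q, input, input);
}

Expr selfAttnAfter(Expr input) {   // post-fix: all three paths are normalized
  Expr q = preNorm(input);
  return attend(q, q, q);
}

int main() {
  std::printf("before fix: %g, after fix: %g\n",
              selfAttnBefore(4.0), selfAttnAfter(4.0));
  return 0;
}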
@@ -403,7 +408,7 @@ public:
opt<int>("transformer-heads"), /*cache=*/false);
}
- Expr LayerFFN(std::string prefix, Expr input) const {
+ Expr LayerFFN(std::string prefix, Expr input, bool isDecoder=false) const {
int dimModel = input->shape()[-1];
float dropProb = inference_ ? 0 : opt<float>("transformer-dropout");
@@ -411,13 +416,22 @@ public:
auto output = preProcess(prefix + "_ffn", opsPre, input, dropProb);
auto actName = opt<std::string>("transformer-ffn-activation");
+
int dimFfn = opt<int>("transformer-dim-ffn");
int depthFfn = opt<int>("transformer-ffn-depth");
- float ffnDropProb
- = inference_ ? 0 : opt<float>("transformer-dropout-ffn");
-
+ if(isDecoder) {
+ int decDimFfn = opt<int>("transformer-decoder-dim-ffn", 0);
+ if(decDimFfn != 0)
+ dimFfn = decDimFfn;
+
+ int decDepthFfn = opt<int>("transformer-decoder-ffn-depth", 0);
+ if(decDepthFfn != 0)
+ depthFfn = decDepthFfn;
+ }
+
ABORT_IF(depthFfn < 1, "Filter depth {} is smaller than 1", depthFfn);
-
+
+ float ffnDropProb = inference_ ? 0 : opt<float>("transformer-dropout-ffn");
auto initFn = inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f);
// the stack of FF layers
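
The decoder branch above implements a simple fallback: a decoder-specific option value of 0 (the default) means "inherit transformer-dim-ffn / transformer-ffn-depth", and any non-zero value overrides it for decoder layers only. A self-contained sketch of that pattern, with the option lookup simulated by a map (opt here is a stand-in, not Marian's accessor):

#include <cstdio>
#include <map>
#include <string>

int opt(const std::map<std::string, int>& opts, const std::string& key) {
  auto it = opts.find(key);
  return it != opts.end() ? it->second : 0;
}

int main() {
  std::map<std::string, int> opts = {
    {"transformer-dim-ffn", 2048},
    {"transformer-decoder-dim-ffn", 0}, // 0 -> fall back to 2048
  };
  int dimFfn = opt(opts, "transformer-dim-ffn");
  bool isDecoder = true;
  if(isDecoder) {
    int decDimFfn = opt(opts, "transformer-decoder-dim-ffn");
    if(decDimFfn != 0)
      dimFfn = decDimFfn; // override only when explicitly set
  }
  std::printf("effective FFN dim: %d\n", dimFfn);
  return 0;
}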
@@ -866,7 +880,7 @@ public:
// remember decoder state
decoderStates.push_back(decoderState);
- query = LayerFFN(prefix_ + "_l" + layerNo + "_ffn", query); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
+ query = LayerFFN(prefix_ + "_l" + layerNo + "_ffn", query, /*isDecoder=*/true); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
checkpoint(query);
}
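
Design note: because isDecoder defaults to false, every other LayerFFN call site keeps its existing behaviour; only this decoder call opts in, so the new transformer-decoder-dim-ffn and transformer-decoder-ffn-depth options take effect in decoder layers alone.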