github.com/marian-nmt/marian.git
author     Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>  2021-09-28 20:19:07 +0300
committer  Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>  2021-09-28 20:19:07 +0300
commit     03fe1758763c99dd55bcf6c1c5e0e1dd60ae4e1a (patch)
tree       24044b77c9e6d320dfb25d9eecb576367d42e41e
parent     d796a3c3b7779993660e672f2a47f5cdd685a174 (diff)
Merged PR 20879: Adjustable ffn width and depth in transformer decoder
-rw-r--r--  src/common/config_parser.cpp    |  8
-rw-r--r--  src/models/encoder_decoder.cpp  |  2
-rw-r--r--  src/models/transformer.h        | 21
3 files changed, 24 insertions(+), 7 deletions(-)
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index d7818afb..b3e8950b 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -255,10 +255,16 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
"Pool encoder states instead of using cross attention (selects first encoder state, best used with special token)");
cli.add<int>("--transformer-dim-ffn",
"Size of position-wise feed-forward network (transformer)",
- 2048);
+ 2048);
+ cli.add<int>("--transformer-decoder-dim-ffn",
+ "Size of position-wise feed-forward network in decoder (transformer). Uses --transformer-dim-ffn if 0.",
+ 0);
cli.add<int>("--transformer-ffn-depth",
"Depth of filters (transformer)",
2);
+ cli.add<int>("--transformer-decoder-ffn-depth",
+ "Depth of filters in decoder (transformer). Uses --transformer-ffn-depth if 0",
+ 0);
cli.add<std::string>("--transformer-ffn-activation",
"Activation between filters: swish or relu (transformer)",
"swish");
diff --git a/src/models/encoder_decoder.cpp b/src/models/encoder_decoder.cpp
index 8fc9321a..a7a398e7 100644
--- a/src/models/encoder_decoder.cpp
+++ b/src/models/encoder_decoder.cpp
@@ -38,7 +38,9 @@ EncoderDecoder::EncoderDecoder(Ptr<ExpressionGraph> graph, Ptr<Options> options)
modelFeatures_.insert("transformer-heads");
modelFeatures_.insert("transformer-no-projection");
modelFeatures_.insert("transformer-dim-ffn");
+ modelFeatures_.insert("transformer-decoder-dim-ffn");
modelFeatures_.insert("transformer-ffn-depth");
+ modelFeatures_.insert("transformer-decoder-ffn-depth");
modelFeatures_.insert("transformer-ffn-activation");
modelFeatures_.insert("transformer-dim-aan");
modelFeatures_.insert("transformer-aan-depth");
diff --git a/src/models/transformer.h b/src/models/transformer.h
index a792de8b..2393ad73 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -400,7 +400,7 @@ public:
opt<int>("transformer-heads"), /*cache=*/false);
}
- Expr LayerFFN(std::string prefix, Expr input) const {
+ Expr LayerFFN(std::string prefix, Expr input, bool isDecoder=false) const {
int dimModel = input->shape()[-1];
float dropProb = inference_ ? 0 : opt<float>("transformer-dropout");
@@ -408,13 +408,22 @@ public:
auto output = preProcess(prefix + "_ffn", opsPre, input, dropProb);
auto actName = opt<std::string>("transformer-ffn-activation");
+
int dimFfn = opt<int>("transformer-dim-ffn");
int depthFfn = opt<int>("transformer-ffn-depth");
- float ffnDropProb
- = inference_ ? 0 : opt<float>("transformer-dropout-ffn");
-
+ if(isDecoder) {
+ int decDimFfn = opt<int>("transformer-decoder-dim-ffn", 0);
+ if(decDimFfn != 0)
+ dimFfn = decDimFfn;
+
+ int decDepthFfn = opt<int>("transformer-decoder-ffn-depth", 0);
+ if(decDepthFfn != 0)
+ depthFfn = decDepthFfn;
+ }
+
ABORT_IF(depthFfn < 1, "Filter depth {} is smaller than 1", depthFfn);
-
+
+ float ffnDropProb = inference_ ? 0 : opt<float>("transformer-dropout-ffn");
auto initFn = inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f);
// the stack of FF layers
@@ -861,7 +870,7 @@ public:
// remember decoder state
decoderStates.push_back(decoderState);
- query = LayerFFN(prefix_ + "_l" + layerNo + "_ffn", query); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
+ query = LayerFFN(prefix_ + "_l" + layerNo + "_ffn", query, /*isDecoder=*/true); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
checkpoint(query);
}
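
Taken together: only the decoder call site passes isDecoder=true, so encoder
layers keep the generic FFN geometry. A compact, self-contained sketch of the
resulting dispatch (the class and its members are stand-ins, not Marian's
real types):

// Stand-in sketch of the LayerFFN dispatch after this change.
#include <iostream>
#include <string>

struct TransformerSketch {
  int dimFfn = 2048, depthFfn = 2;     // --transformer-dim-ffn / --transformer-ffn-depth
  int decDimFfn = 0, decDepthFfn = 0;  // new decoder-only overrides (0 = inherit)

  void LayerFFN(const std::string& prefix, bool isDecoder = false) const {
    int dim = dimFfn, depth = depthFfn;
    if(isDecoder) {
      if(decDimFfn != 0)   dim = decDimFfn;
      if(decDepthFfn != 0) depth = decDepthFfn;
    }
    std::cout << prefix << ": dimFfn=" << dim << ", depthFfn=" << depth << "\n";
  }
};

int main() {
  TransformerSketch t;
  t.decDimFfn = 1024;                               // narrow only the decoder FFN
  t.LayerFFN("encoder_l1_ffn");                     // 2048 x 2, unchanged
  t.LayerFFN("decoder_l1_ffn", /*isDecoder=*/true); // 1024 x 2, override applied
}

Keeping 0 as "inherit" makes the change backward compatible: configs and
saved models that never set the decoder options resolve to the old behavior.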