diff options
author | Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com> | 2021-09-28 20:19:07 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com> | 2021-09-28 20:19:07 +0300 |
commit | 03fe1758763c99dd55bcf6c1c5e0e1dd60ae4e1a (patch) | |
tree | 24044b77c9e6d320dfb25d9eecb576367d42e41e | |
parent | d796a3c3b7779993660e672f2a47f5cdd685a174 (diff) |
Merged PR 20879: Adjustable ffn width and depth in transformer decoder
-rw-r--r-- | src/common/config_parser.cpp | 8 | ||||
-rw-r--r-- | src/models/encoder_decoder.cpp | 2 | ||||
-rw-r--r-- | src/models/transformer.h | 21 |
3 files changed, 24 insertions, 7 deletions
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index d7818afb..b3e8950b 100644 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -255,10 +255,16 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) { "Pool encoder states instead of using cross attention (selects first encoder state, best used with special token)"); cli.add<int>("--transformer-dim-ffn", "Size of position-wise feed-forward network (transformer)", - 2048); + 2048); + cli.add<int>("--transformer-decoder-dim-ffn", + "Size of position-wise feed-forward network in decoder (transformer). Uses --transformer-dim-ffn if 0.", + 0); cli.add<int>("--transformer-ffn-depth", "Depth of filters (transformer)", 2); + cli.add<int>("--transformer-decoder-ffn-depth", + "Depth of filters in decoder (transformer). Uses --transformer-ffn-depth if 0", + 0); cli.add<std::string>("--transformer-ffn-activation", "Activation between filters: swish or relu (transformer)", "swish"); diff --git a/src/models/encoder_decoder.cpp b/src/models/encoder_decoder.cpp index 8fc9321a..a7a398e7 100644 --- a/src/models/encoder_decoder.cpp +++ b/src/models/encoder_decoder.cpp @@ -38,7 +38,9 @@ EncoderDecoder::EncoderDecoder(Ptr<ExpressionGraph> graph, Ptr<Options> options) modelFeatures_.insert("transformer-heads"); modelFeatures_.insert("transformer-no-projection"); modelFeatures_.insert("transformer-dim-ffn"); + modelFeatures_.insert("transformer-decoder-dim-ffn"); modelFeatures_.insert("transformer-ffn-depth"); + modelFeatures_.insert("transformer-decoder-ffn-depth"); modelFeatures_.insert("transformer-ffn-activation"); modelFeatures_.insert("transformer-dim-aan"); modelFeatures_.insert("transformer-aan-depth"); diff --git a/src/models/transformer.h b/src/models/transformer.h index a792de8b..2393ad73 100644 --- a/src/models/transformer.h +++ b/src/models/transformer.h @@ -400,7 +400,7 @@ public: opt<int>("transformer-heads"), /*cache=*/false); } - Expr LayerFFN(std::string prefix, Expr input) const { + Expr LayerFFN(std::string prefix, Expr input, bool isDecoder=false) const { int dimModel = input->shape()[-1]; float dropProb = inference_ ? 0 : opt<float>("transformer-dropout"); @@ -408,13 +408,22 @@ public: auto output = preProcess(prefix + "_ffn", opsPre, input, dropProb); auto actName = opt<std::string>("transformer-ffn-activation"); + int dimFfn = opt<int>("transformer-dim-ffn"); int depthFfn = opt<int>("transformer-ffn-depth"); - float ffnDropProb - = inference_ ? 0 : opt<float>("transformer-dropout-ffn"); - + if(isDecoder) { + int decDimFfn = opt<int>("transformer-decoder-dim-ffn", 0); + if(decDimFfn != 0) + dimFfn = decDimFfn; + + int decDepthFfn = opt<int>("transformer-decoder-ffn-depth", 0); + if(decDepthFfn != 0) + depthFfn = decDepthFfn; + } + ABORT_IF(depthFfn < 1, "Filter depth {} is smaller than 1", depthFfn); - + + float ffnDropProb = inference_ ? 0 : opt<float>("transformer-dropout-ffn"); auto initFn = inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f); // the stack of FF layers @@ -861,7 +870,7 @@ public: // remember decoder state decoderStates.push_back(decoderState); - query = LayerFFN(prefix_ + "_l" + layerNo + "_ffn", query); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim] + query = LayerFFN(prefix_ + "_l" + layerNo + "_ffn", query, /*isDecoder=*/true); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim] checkpoint(query); } |