From 03fe1758763c99dd55bcf6c1c5e0e1dd60ae4e1a Mon Sep 17 00:00:00 2001
From: Marcin Junczys-Dowmunt
Date: Tue, 28 Sep 2021 17:19:07 +0000
Subject: Merged PR 20879: Adjustable ffn width and depth in transformer
 decoder

---
 src/common/config_parser.cpp   |  8 +++++++-
 src/models/encoder_decoder.cpp |  2 ++
 src/models/transformer.h       | 21 +++++++++++++++------
 3 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index d7818afb..b3e8950b 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -255,10 +255,16 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
       "Pool encoder states instead of using cross attention (selects first encoder state, best used with special token)");
   cli.add<int>("--transformer-dim-ffn",
       "Size of position-wise feed-forward network (transformer)",
-      2048); 
+      2048);
+  cli.add<int>("--transformer-decoder-dim-ffn",
+      "Size of position-wise feed-forward network in decoder (transformer). Uses --transformer-dim-ffn if 0.",
+      0);
   cli.add<int>("--transformer-ffn-depth",
       "Depth of filters (transformer)",
       2);
+  cli.add<int>("--transformer-decoder-ffn-depth",
+      "Depth of filters in decoder (transformer). Uses --transformer-ffn-depth if 0",
+      0);
   cli.add<std::string>("--transformer-ffn-activation",
       "Activation between filters: swish or relu (transformer)",
       "swish");
diff --git a/src/models/encoder_decoder.cpp b/src/models/encoder_decoder.cpp
index 8fc9321a..a7a398e7 100644
--- a/src/models/encoder_decoder.cpp
+++ b/src/models/encoder_decoder.cpp
@@ -38,7 +38,9 @@ EncoderDecoder::EncoderDecoder(Ptr<ExpressionGraph> graph, Ptr<Options> options)
   modelFeatures_.insert("transformer-heads");
   modelFeatures_.insert("transformer-no-projection");
   modelFeatures_.insert("transformer-dim-ffn");
+  modelFeatures_.insert("transformer-decoder-dim-ffn");
   modelFeatures_.insert("transformer-ffn-depth");
+  modelFeatures_.insert("transformer-decoder-ffn-depth");
   modelFeatures_.insert("transformer-ffn-activation");
   modelFeatures_.insert("transformer-dim-aan");
   modelFeatures_.insert("transformer-aan-depth");
diff --git a/src/models/transformer.h b/src/models/transformer.h
index a792de8b..2393ad73 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -400,7 +400,7 @@ public:
                opt<int>("transformer-heads"), /*cache=*/false);
   }

-  Expr LayerFFN(std::string prefix, Expr input) const {
+  Expr LayerFFN(std::string prefix, Expr input, bool isDecoder=false) const {
     int dimModel = input->shape()[-1];

     float dropProb = inference_ ? 0 : opt<float>("transformer-dropout");
@@ -408,13 +408,22 @@ public:
     auto output = preProcess(prefix + "_ffn", opsPre, input, dropProb);

     auto actName = opt<std::string>("transformer-ffn-activation");
+
     int dimFfn = opt<int>("transformer-dim-ffn");
     int depthFfn = opt<int>("transformer-ffn-depth");
-    float ffnDropProb
-    = inference_ ? 0 : opt<float>("transformer-dropout-ffn");
-
+    if(isDecoder) {
+      int decDimFfn = opt<int>("transformer-decoder-dim-ffn", 0);
+      if(decDimFfn != 0)
+        dimFfn = decDimFfn;
+
+      int decDepthFfn = opt<int>("transformer-decoder-ffn-depth", 0);
+      if(decDepthFfn != 0)
+        depthFfn = decDepthFfn;
+    }
+
     ABORT_IF(depthFfn < 1, "Filter depth {} is smaller than 1", depthFfn);
-
+
+    float ffnDropProb = inference_ ? 0 : opt<float>("transformer-dropout-ffn");
     auto initFn = inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f);

     // the stack of FF layers
@@ -861,7 +870,7 @@ public:
       // remember decoder state
      decoderStates.push_back(decoderState);

-      query = LayerFFN(prefix_ + "_l" + layerNo + "_ffn", query); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
+      query = LayerFFN(prefix_ + "_l" + layerNo + "_ffn", query, /*isDecoder=*/true); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]

       checkpoint(query);
     }
--
cgit v1.2.3