From 79368b121f29a0c12beada95473655220ead8827 Mon Sep 17 00:00:00 2001
From: Marcin Junczys-Dowmunt
Date: Sat, 26 Jan 2019 14:14:57 -0800
Subject: add gelu activation

---
 src/graph/expression_operators.cpp |  9 +++++++++
 src/graph/expression_operators.h   |  3 +++
 src/graph/node_operators_unary.h   | 35 ++++++++++++++++++++++++++++-------
 src/layers/loss.h                  |  1 +
 src/models/bert.h                  | 26 +++++++++++++-------------
 src/models/transformer.h           |  2 ++
 src/tensors/gpu/add.inc            |  1 +
 src/tensors/gpu/element.inc        |  2 ++
 8 files changed, 59 insertions(+), 20 deletions(-)

diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index c0640edb..e558ffd0 100644
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -51,6 +51,10 @@ Expr swish(Expr a) {
   return Expression<SwishNodeOp>(a);
 }
 
+Expr gelu(Expr a) {
+  return Expression<SwishNodeOp>(a, 1.702f);
+}
+
 Expr operator-(Expr a) {
   return Expression<NegNodeOp>(a);
 };
@@ -529,6 +533,11 @@ Expr swish(const std::vector<Expr>& nodes) {
   return swish(nodes[0]);
 }
 
+Expr gelu(const std::vector<Expr>& nodes) {
+  ABORT_IF(nodes.size() > 1, "Not implemented");
+  return gelu(nodes[0]);
+}
+
 Expr tanh(const std::vector<Expr>& nodes) {
   return Expression<TanhNodeOp>(nodes);
 }
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 78aed834..0205fe86 100755
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -17,6 +17,9 @@ Expr sigmoid(const std::vector<Expr>&);
 Expr swish(Expr a);
 Expr swish(const std::vector<Expr>&);
 
+Expr gelu(Expr a);
+Expr gelu(const std::vector<Expr>&);
+
 Expr tanh(const std::vector<Expr>&);
 
 template <typename ...Args>
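Note (editor's addition, not part of the patch): the constant 1.702 comes from Hendrycks & Gimpel, "Gaussian Error Linear Units (GELUs)". The exact GELU is x * Phi(x), with Phi the standard normal CDF, and x * sigmoid(1.702 * x) is that paper's sigmoid approximation. Since swish is x * sigmoid(b * x), gelu(a) can simply reuse SwishNodeOp with b = 1.702. A self-contained sketch of how close the two are, independent of Marian:

  #include <cmath>
  #include <cstdio>

  // Exact GELU: x * Phi(x), with Phi the CDF of the standard normal.
  float geluExact(float x) {
    return 0.5f * x * (1.f + std::erf(x / std::sqrt(2.f)));
  }

  // What the patch computes: swish with b = 1.702, i.e. x * sigmoid(1.702 * x).
  float geluSigmoid(float x) {
    return x / (1.f + std::exp(-1.702f * x));
  }

  int main() {
    for(float x = -4.f; x <= 4.f; x += 0.5f)
      std::printf("x=%5.1f  exact=%9.5f  approx=%9.5f\n",
                  x, geluExact(x), geluSigmoid(x));
    return 0;
  }

The two curves differ by at most a few hundredths over this range, which is why the sigmoid approximation is generally considered safe for training.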
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 63756da3..190fa947 100755
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -343,31 +343,52 @@ private:
  * in an expression graph.
  *
  * This node implements the activation function
- * \f$ f(x) = x \cdot \sigma(x) \f$
+ * \f$ f(x) = x \cdot \sigma(bx) \f$
  * and its derivative
- * \f$ f^\prime(x) = f(x) + \sigma(x)(1 - f(x)) \f$ .
+ * \f$ f^\prime(x) = bf(x) + \sigma(bx)(1 - bf(x)) \f$ .
  *
  */
 struct SwishNodeOp : public UnaryNodeOp {
-  SwishNodeOp(Expr a) : UnaryNodeOp(a) {}
+  SwishNodeOp(Expr a, float b = 1.f) : UnaryNodeOp(a), b_{b} {}
 
   NodeOps forwardOps() override {
     using namespace functional;
-    return {NodeOp(Element(_1 = _2 * sigmoid(_2), val_, child(0)->val()))};
+    return {NodeOp(Element(_1 = _2 * sigmoid(b_ * _2), val_, child(0)->val()))};
   }
 
   NodeOps backwardOps() override {
     using namespace functional;
-    // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) )
-    return {NodeOp(Add(_1 * (_3 + sigmoid(_2) * (1.f - _3)),
+    // dJ/dx += dJ/df * (b*f(x) + sigmoid(b*x) * (1 - b*f(x)))
+    return {NodeOp(Add(_1 * (b_ * _3 + sigmoid(b_ * _2) * (1.f - (b_ * _3))),
                        child(0)->grad(), // dJ/dx
                        adj_,             // _1 := dJ/df
                        child(0)->val(),  // _2 := x
-                       val_              // _3 := f(x) = x*sigma(x)
+                       val_              // _3 := f(x) = x*sigmoid(b*x)
                        ))};
   }
 
   const std::string type() override { return "swish"; }
+
+  virtual size_t hash() override {
+    if(!hash_) {
+      hash_ = NaryNodeOp::hash();
+      util::hash_combine(hash_, b_);
+    }
+    return hash_;
+  }
+
+  virtual bool equal(Expr node) override {
+    if(!NaryNodeOp::equal(node))
+      return false;
+    Ptr<SwishNodeOp> cnode = std::dynamic_pointer_cast<SwishNodeOp>(node);
+    if(!cnode)
+      return false;
+    if(b_ != cnode->b_)
+      return false;
+    return true;
+  }
+
+  float b_;
 };
 
 struct SoftmaxNodeOp : public UnaryNodeOp {
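Note (editor's addition, not part of the patch): the gradient comment in backwardOps checks out. With f(x) = x \sigma(bx), the product rule and \sigma'(u) = \sigma(u)(1 - \sigma(u)) give

  f'(x) = \sigma(bx) + b x \, \sigma(bx)(1 - \sigma(bx))
        = \sigma(bx) + b f(x)(1 - \sigma(bx))
        = b f(x) + \sigma(bx)(1 - b f(x)),

which is exactly the functor passed to Add, with _2 := x and _3 := f(x); setting b = 1 recovers the previous swish gradient. The new hash()/equal() overrides fold b_ into the node's identity, presumably so that the graph's node reuse cannot merge swish(x) (b_ = 1.0) and gelu(x) (b_ = 1.702) of the same child into a single node.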
diff --git a/src/layers/loss.h b/src/layers/loss.h
index 4a28de7b..2ac4ae78 100644
--- a/src/layers/loss.h
+++ b/src/layers/loss.h
@@ -331,6 +331,7 @@ protected:
   virtual Expr compute(Expr logits, Expr labelIndices,
                        Expr mask = nullptr, Expr labelWeights = nullptr) override {
+    logits = atleast_3d(logits); // safeguard against 2d classifier output; adds a 1 on the left, a no-op otherwise.
     Expr ce = cross_entropy(logits, labelIndices);
 
     if(labelSmoothing_ > 0) {
diff --git a/src/models/bert.h b/src/models/bert.h
index c642abac..b4d2a34c 100644
--- a/src/models/bert.h
+++ b/src/models/bert.h
@@ -144,7 +144,7 @@ public:
     int dimBatch = subBatch->batchSize();
     int dimWords = subBatch->batchWidth();
 
-    int maxSentPos = 2; // Currently only two sentences allowed A at [0] and B at [1] and padding at [2]
+    int maxSentPos = 1; // Currently only two sentences allowed, A at [0] and B at [1]
     // If another separator is seen do not increase position index beyond 2 but use padding.
     // @TODO: make this configurable, see below for NextSentencePredictions task where we also restrict to 2.
@@ -231,7 +231,7 @@ public:
     if(learnedPosEmbeddings) {
       auto sentenceEmbeddings = embedding()
                                ("prefix", "Wsent")
-                               ("dimVocab", 3) // sentence A or sentence B plus padding, @TODO: should rather be a parameter
+                               ("dimVocab", 2) // sentence A or sentence B, @TODO: should rather be a parameter
                                ("dimEmb", dimEmb)
                                .construct(graph_);
       signal = sentenceEmbeddings->apply(bertBatch->bertSentenceIndices(), {dimWords, dimBatch, dimEmb});
@@ -327,24 +327,24 @@ public:
 
     int dimVoc = opt<std::vector<int>>("dim-vocabs")[batchIndex_];
 
-    std::string activationType = opt<std::string>("transformer-ffn-activation");
-    mlp::act activation;
-    if(activationType == "relu")
-      activation = mlp::act::ReLU;
-    else if(activationType == "swish")
-      activation = mlp::act::swish;
-    else
-      ABORT("Activation function {} not supported in BERT masked LM", activationType);
-
     auto layer1 = mlp::mlp()
       .push_back(mlp::dense()
                  ("prefix", prefix_ + "_ff_logit_l1")
-                 ("dim", dimModel)
-                 ("activation", activation))
+                 ("dim", dimModel))
       .construct(graph);
 
     auto intermediate = layer1->apply(maskedContext);
 
+    std::string activationType = opt<std::string>("transformer-ffn-activation");
+    if(activationType == "relu")
+      intermediate = relu(intermediate);
+    else if(activationType == "swish")
+      intermediate = swish(intermediate);
+    else if(activationType == "gelu")
+      intermediate = gelu(intermediate);
+    else
+      ABORT("Activation function {} not supported in BERT masked LM", activationType);
+
     auto gamma = graph->param(prefix_ + "_ff_ln_scale", {1, dimModel}, inits::ones);
     auto beta = graph->param(prefix_ + "_ff_ln_bias", {1, dimModel}, inits::zeros);
     intermediate = layerNorm(intermediate, gamma, beta);
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 01201df9..02ce2f6e 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -375,6 +375,8 @@ public:
       return (ActivationFunction*)relu;
     else if (actName == "swish")
      return (ActivationFunction*)swish;
+    else if (actName == "gelu")
+      return (ActivationFunction*)gelu;
     ABORT("Invalid activation name '{}'", actName);
   }
 
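Note (editor's addition, not part of the patch): relu, swish, and now gelu are overload sets (each also has a std::vector<Expr> overload in expression_operators.h), so the cast to ActivationFunction* above is what selects the unary Expr(Expr) overload. A compilable toy version of the idiom follows; Expr and the function bodies are stand-ins, only the cast trick mirrors the real code:

  #include <cstdio>
  #include <string>
  #include <vector>

  struct Expr { float v; };               // stand-in for Marian's Expr handle
  typedef Expr ActivationFunction(Expr);  // unary activation: Expr -> Expr

  // Each activation name is an overload set, as in expression_operators.h.
  Expr relu(Expr a)                      { return {a.v > 0.f ? a.v : 0.f}; }
  Expr relu(const std::vector<Expr>& n)  { return relu(n[0]); }
  Expr swish(Expr a)                     { return a; }  // toy body
  Expr swish(const std::vector<Expr>& n) { return swish(n[0]); }
  Expr gelu(Expr a)                      { return a; }  // toy body
  Expr gelu(const std::vector<Expr>& n)  { return gelu(n[0]); }

  ActivationFunction* activationByName(const std::string& actName) {
    if(actName == "relu")
      return (ActivationFunction*)relu;   // cast resolves the overload
    else if(actName == "swish")
      return (ActivationFunction*)swish;
    else if(actName == "gelu")
      return (ActivationFunction*)gelu;
    return nullptr;                       // the real code ABORTs here
  }

  int main() {
    ActivationFunction* act = activationByName("relu");
    std::printf("%f\n", act({-2.f}).v);   // prints 0.000000
  }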
diff --git a/src/tensors/gpu/add.inc b/src/tensors/gpu/add.inc
index 69244dce..b70c59d0 100755
--- a/src/tensors/gpu/add.inc
+++ b/src/tensors/gpu/add.inc
@@ -32,3 +32,4 @@
 template void marian::gpu::Add, marian::functional::UnaryFunctor, marian::functional::Assignee<3> > > >, std::shared_ptr, std::shared_ptr, std::shared_ptr >(marian::functional::BinaryFunctor, marian::functional::UnaryFunctor, marian::functional::Assignee<3> > > >, float, std::shared_ptr, std::shared_ptr, std::shared_ptr, std::shared_ptr);
 template void marian::gpu::Add, marian::functional::Assignee<2> >, marian::functional::Assignee<3> >, std::shared_ptr, std::shared_ptr, std::shared_ptr >(marian::functional::BinaryFunctor, marian::functional::Assignee<2> >, marian::functional::Assignee<3> >, float, std::shared_ptr, std::shared_ptr, std::shared_ptr, std::shared_ptr);
 template void marian::gpu::Add, marian::functional::Capture>, marian::functional::Assignee<2> >, std::shared_ptr, std::shared_ptr >(marian::functional::BinaryFunctor, marian::functional::Capture>, marian::functional::Assignee<2> >, float, std::shared_ptr, std::shared_ptr, std::shared_ptr);
+template void marian::gpu::Add, marian::functional::BinaryFunctor >, marian::functional::BinaryFunctor > >, marian::functional::BinaryFunctor > > > > >, std::shared_ptr, std::shared_ptr, std::shared_ptr >(marian::functional::BinaryFunctor, marian::functional::BinaryFunctor >, marian::functional::BinaryFunctor > >, marian::functional::BinaryFunctor > > > > >, float, std::shared_ptr, std::shared_ptr, std::shared_ptr, std::shared_ptr);
\ No newline at end of file
diff --git a/src/tensors/gpu/element.inc b/src/tensors/gpu/element.inc
index f3cdea28..364866fa 100755
--- a/src/tensors/gpu/element.inc
+++ b/src/tensors/gpu/element.inc
@@ -56,6 +56,8 @@ template void Element, BinaryFunctor, Bin
 template void Element, BinaryFunctor, BinaryFunctor, Capture>, BinaryFunctor, Capture> >, Capture> >, BinaryFunctor > > > > >, std::shared_ptr, std::shared_ptr, std::shared_ptr, std::shared_ptr >(Assign, BinaryFunctor, BinaryFunctor, Capture>, BinaryFunctor, Capture> >, Capture> >, BinaryFunctor > > > > >, std::shared_ptr, std::shared_ptr, std::shared_ptr, std::shared_ptr, std::shared_ptr);
 template void marian::gpu::Element, marian::functional::BinaryFunctor, marian::functional::Assignee<3> >, marian::functional::BinaryFunctor, marian::functional::Assignee<3> > >, marian::functional::Capture>, marian::functional::Capture> >, std::shared_ptr, std::shared_ptr >(marian::functional::Assign, marian::functional::BinaryFunctor, marian::functional::Assignee<3> >, marian::functional::BinaryFunctor, marian::functional::Assignee<3> > >, marian::functional::Capture>, marian::functional::Capture> >, std::shared_ptr, std::shared_ptr, std::shared_ptr);
 template void marian::gpu::Element, marian::functional::UnaryFunctor > >>(marian::functional::Assign, marian::functional::UnaryFunctor > >, std::shared_ptr);
+template void marian::gpu::Element, marian::functional::BinaryFunctor, marian::functional::UnaryFunctor > > > >, std::shared_ptr >(marian::functional::Assign, marian::functional::BinaryFunctor, marian::functional::UnaryFunctor > > > >, std::shared_ptr, std::shared_ptr);
+
 // How to add new specializations:
 // When you use a new specialization, it will cause a link error of this form (example):
 // .../src/tensors/tensor_operators.h:41: undefined reference to `void marian::gpu::Element ( ... )'
--
cgit v1.2.3
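Note (editor's addition, not part of the patch; the angle-bracket template arguments in the two .inc hunks above were lost in rendering and are left as found): Element and Add are templates whose definitions live in .cu files, so every functor expression used from the rest of the code base needs an explicit instantiation line in the matching .inc file. The sigmoid(b_ * _2) expressions introduced by SwishNodeOp are why one line is added to each. A toy, single-file version of the scheme, with hypothetical names:

  #include <cmath>
  #include <cstdio>

  // element.h: declaration only; the definition lives in element.cu, so other
  // translation units cannot instantiate the template themselves.
  template <class Functor> void Element(Functor f, float* data, int n);

  // element.cu: the definition (on the GPU this would launch a kernel).
  template <class Functor>
  void Element(Functor f, float* data, int n) {
    for(int i = 0; i < n; ++i)
      data[i] = f(data[i]);
  }

  // element.inc, included by element.cu: one explicit instantiation per functor
  // used elsewhere; a missing line surfaces as the "undefined reference to
  // void marian::gpu::Element ( ... )" link error quoted in the comment above.
  struct ScaledSwish {
    float b;
    float operator()(float x) const { return x / (1.f + std::exp(-b * x)); }
  };
  template void Element<ScaledSwish>(ScaledSwish, float*, int);

  int main() {  // client code; in Marian this sits in a different object file
    float v[3] = {-1.f, 0.f, 1.f};
    Element(ScaledSwish{1.702f}, v, 3);
    std::printf("%f %f %f\n", v[0], v[1], v[2]);
    return 0;
  }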