diff options
author | Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com> | 2019-01-27 01:14:57 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com> | 2019-01-27 01:14:57 +0300 |
commit | 79368b121f29a0c12beada95473655220ead8827 (patch) | |
tree | 17b39372649b08ed2c2943a58562295a346f4d27 | |
parent | 50d64de62cb3e47b912fdd3438ffacb566187263 (diff) |
add gelu activation
-rw-r--r-- | src/graph/expression_operators.cpp | 9 | ||||
-rwxr-xr-x | src/graph/expression_operators.h | 3 | ||||
-rwxr-xr-x | src/graph/node_operators_unary.h | 35 | ||||
-rw-r--r-- | src/layers/loss.h | 1 | ||||
-rw-r--r-- | src/models/bert.h | 26 | ||||
-rw-r--r-- | src/models/transformer.h | 2 | ||||
-rwxr-xr-x | src/tensors/gpu/add.inc | 1 | ||||
-rwxr-xr-x | src/tensors/gpu/element.inc | 2 |
8 files changed, 59 insertions, 20 deletions
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp index c0640edb..e558ffd0 100644 --- a/src/graph/expression_operators.cpp +++ b/src/graph/expression_operators.cpp @@ -51,6 +51,10 @@ Expr swish(Expr a) { return Expression<SwishNodeOp>(a); } +Expr gelu(Expr a) { + return Expression<SwishNodeOp>(a, 1.702f); +} + Expr operator-(Expr a) { return Expression<NegNodeOp>(a); }; @@ -529,6 +533,11 @@ Expr swish(const std::vector<Expr>& nodes) { return swish(nodes[0]); } +Expr gelu(const std::vector<Expr>& nodes) { + ABORT_IF(nodes.size() > 1, "Not implemented"); + return gelu(nodes[0]); +} + Expr tanh(const std::vector<Expr>& nodes) { return Expression<TanhNodeOp>(nodes); } diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h index 78aed834..0205fe86 100755 --- a/src/graph/expression_operators.h +++ b/src/graph/expression_operators.h @@ -17,6 +17,9 @@ Expr sigmoid(const std::vector<Expr>&); Expr swish(Expr a); Expr swish(const std::vector<Expr>&); +Expr gelu(Expr a); +Expr gelu(const std::vector<Expr>&); + Expr tanh(const std::vector<Expr>&); template <typename... Args> diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 63756da3..190fa947 100755 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -343,31 +343,52 @@ private: * in an expression graph. * * This node implements the activation function - * \f$ f(x) = x \cdot \sigma(x) \f$ + * \f$ f(x) = x \cdot \sigma(bx) \f$ * and its derivative - * \f$ f^\prime(x) = f(x) + \sigma(x)(1 - f(x)) \f$ . + * \f$ f^\prime(x) = bf(x) + \sigma(bx)(1 - bf(x)) \f$ . * */ struct SwishNodeOp : public UnaryNodeOp { - SwishNodeOp(Expr a) : UnaryNodeOp(a) {} + SwishNodeOp(Expr a, float b = 1.f) : UnaryNodeOp(a), b_{b} {} NodeOps forwardOps() override { using namespace functional; - return {NodeOp(Element(_1 = _2 * sigmoid(_2), val_, child(0)->val()))}; + return {NodeOp(Element(_1 = _2 * sigmoid(b_ * _2), val_, child(0)->val()))}; } NodeOps backwardOps() override { using namespace functional; - // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) ) - return {NodeOp(Add(_1 * (_3 + sigmoid(_2) * (1.f - _3)), + // dJ/dx += dJ/df * (b*f(x) + sigmoid(b*x) * (1 - b*f(x))) + return {NodeOp(Add(_1 * (b_ * _3 + sigmoid(b_ * _2) * (1.f - (b_ * _3))), child(0)->grad(), // dJ/dx adj_, // _1 := dJ/df child(0)->val(), // _2 := x - val_ // _3 := f(x) = x*sigma(x) + val_ // _3 := f(x) = x*sigmoid(b*x) ))}; } const std::string type() override { return "swish"; } + + virtual size_t hash() override { + if(!hash_) { + hash_ = NaryNodeOp::hash(); + util::hash_combine(hash_, b_); + } + return hash_; + } + + virtual bool equal(Expr node) override { + if(!NaryNodeOp::equal(node)) + return false; + Ptr<SwishNodeOp> cnode = std::dynamic_pointer_cast<SwishNodeOp>(node); + if(!cnode) + return false; + if(b_ != cnode->b_) + return false; + return true; + } + + float b_; }; struct SoftmaxNodeOp : public UnaryNodeOp { diff --git a/src/layers/loss.h b/src/layers/loss.h index 4a28de7b..2ac4ae78 100644 --- a/src/layers/loss.h +++ b/src/layers/loss.h @@ -331,6 +331,7 @@ protected: virtual Expr compute(Expr logits, Expr labelIndices, Expr mask = nullptr, Expr labelWeights = nullptr) override { + logits = atleast_3d(logits); // safeguard against 2d classifier output, adds 1 on the left, non-op. Expr ce = cross_entropy(logits, labelIndices); if(labelSmoothing_ > 0) { diff --git a/src/models/bert.h b/src/models/bert.h index c642abac..b4d2a34c 100644 --- a/src/models/bert.h +++ b/src/models/bert.h @@ -144,7 +144,7 @@ public: int dimBatch = subBatch->batchSize(); int dimWords = subBatch->batchWidth(); - int maxSentPos = 2; // Currently only two sentences allowed A at [0] and B at [1] and padding at [2] + int maxSentPos = 1; // Currently only two sentences allowed A at [0] and B at [1] and padding at [2] // If another separator is seen do not increase position index beyond 2 but use padding. // @TODO: make this configurable, see below for NextSentencePredictions task where we also restrict to 2. @@ -231,7 +231,7 @@ public: if(learnedPosEmbeddings) { auto sentenceEmbeddings = embedding() ("prefix", "Wsent") - ("dimVocab", 3) // sentence A or sentence B plus padding, @TODO: should rather be a parameter + ("dimVocab", 2) // sentence A or sentence B plus padding, @TODO: should rather be a parameter ("dimEmb", dimEmb) .construct(graph_); signal = sentenceEmbeddings->apply(bertBatch->bertSentenceIndices(), {dimWords, dimBatch, dimEmb}); @@ -327,24 +327,24 @@ public: int dimVoc = opt<std::vector<int>>("dim-vocabs")[batchIndex_]; - std::string activationType = opt<std::string>("transformer-ffn-activation"); - mlp::act activation; - if(activationType == "relu") - activation = mlp::act::ReLU; - else if(activationType == "swish") - activation = mlp::act::swish; - else - ABORT("Activation function {} not supported in BERT masked LM", activationType); - auto layer1 = mlp::mlp() .push_back(mlp::dense() ("prefix", prefix_ + "_ff_logit_l1") - ("dim", dimModel) - ("activation", activation)) + ("dim", dimModel)) .construct(graph); auto intermediate = layer1->apply(maskedContext); + std::string activationType = opt<std::string>("transformer-ffn-activation"); + if(activationType == "relu") + intermediate = relu(intermediate); + else if(activationType == "swish") + intermediate = swish(intermediate); + else if(activationType == "gelu") + intermediate = gelu(intermediate); + else + ABORT("Activation function {} not supported in BERT masked LM", activationType); + auto gamma = graph->param(prefix_ + "_ff_ln_scale", {1, dimModel}, inits::ones); auto beta = graph->param(prefix_ + "_ff_ln_bias", {1, dimModel}, inits::zeros); intermediate = layerNorm(intermediate, gamma, beta); diff --git a/src/models/transformer.h b/src/models/transformer.h index 01201df9..02ce2f6e 100644 --- a/src/models/transformer.h +++ b/src/models/transformer.h @@ -375,6 +375,8 @@ public: return (ActivationFunction*)relu; else if (actName == "swish") return (ActivationFunction*)swish; + else if (actName == "gelu") + return (ActivationFunction*)gelu; ABORT("Invalid activation name '{}'", actName); } diff --git a/src/tensors/gpu/add.inc b/src/tensors/gpu/add.inc index 69244dce..b70c59d0 100755 --- a/src/tensors/gpu/add.inc +++ b/src/tensors/gpu/add.inc @@ -32,3 +32,4 @@ template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functio template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Exp, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<2>, marian::functional::Assignee<3> > > >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase> >(marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Exp, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<2>, marian::functional::Assignee<3> > > >, float, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>); template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, marian::functional::Assignee<3> >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase> >(marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, marian::functional::Assignee<3> >, float, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>); template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Capture>, marian::functional::Assignee<2> >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase> >(marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Capture>, marian::functional::Assignee<2> >, float, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>); +template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<3> >, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sigmoid, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<3> > > > > >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase> >(marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<3> >, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sigmoid, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<3> > > > > >, float, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>);
\ No newline at end of file diff --git a/src/tensors/gpu/element.inc b/src/tensors/gpu/element.inc index f3cdea28..364866fa 100755 --- a/src/tensors/gpu/element.inc +++ b/src/tensors/gpu/element.inc @@ -56,6 +56,8 @@ template void Element<Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, Bin template void Element<Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, BinaryFunctor<elem::Div, Assignee<2>, Capture>, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Sqrt, BinaryFunctor<elem::Div, Assignee<3>, Capture> >, Capture> >, BinaryFunctor<elem::Mult, Capture, Assignee<1> > > > > >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase> >(Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, BinaryFunctor<elem::Div, Assignee<2>, Capture>, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Sqrt, BinaryFunctor<elem::Div, Assignee<3>, Capture> >, Capture> >, BinaryFunctor<elem::Mult, Capture, Assignee<1> > > > > >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>); template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::NEq, marian::functional::BinaryFunctor<marian::functional::elem::Eq, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::BinaryFunctor<marian::functional::elem::Gt, marian::functional::Assignee<2>, marian::functional::Assignee<3> >, marian::functional::BinaryFunctor<marian::functional::elem::Lt, marian::functional::Assignee<2>, marian::functional::Assignee<3> > >, marian::functional::Capture>, marian::functional::Capture> >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::NEq, marian::functional::BinaryFunctor<marian::functional::elem::Eq, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::BinaryFunctor<marian::functional::elem::Gt, marian::functional::Assignee<2>, marian::functional::Assignee<3> >, marian::functional::BinaryFunctor<marian::functional::elem::Lt, marian::functional::Assignee<2>, marian::functional::Assignee<3> > >, marian::functional::Capture>, marian::functional::Capture> >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>); template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqrt, marian::functional::Assignee<1> > >>(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqrt, marian::functional::Assignee<1> > >, std::shared_ptr<marian::TensorBase>); +template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<2>, marian::functional::UnaryFunctor<marian::functional::elem::Sigmoid, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<2> > > > >, std::shared_ptr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<2>, marian::functional::UnaryFunctor<marian::functional::elem::Sigmoid, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<2> > > > >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>); + // How to add new specializations: // When you use a new specialization, it will cause a link error of this form (example): // .../src/tensors/tensor_operators.h:41: undefined reference to `void marian::gpu::Element<marian::functional::Assign< ... > ( ... )' |