github.com/marian-nmt/marian.git

author    Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>  2019-01-27 01:14:57 +0300
committer Marcin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>  2019-01-27 01:14:57 +0300
commit    79368b121f29a0c12beada95473655220ead8827
tree      17b39372649b08ed2c2943a58562295a346f4d27
parent    50d64de62cb3e47b912fdd3438ffacb566187263

add gelu activation
-rw-r--r--  src/graph/expression_operators.cpp   9
-rwxr-xr-x  src/graph/expression_operators.h     3
-rwxr-xr-x  src/graph/node_operators_unary.h    35
-rw-r--r--  src/layers/loss.h                    1
-rw-r--r--  src/models/bert.h                   26
-rw-r--r--  src/models/transformer.h             2
-rwxr-xr-x  src/tensors/gpu/add.inc              1
-rwxr-xr-x  src/tensors/gpu/element.inc          2

8 files changed, 59 insertions(+), 20 deletions(-)
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index c0640edb..e558ffd0 100644
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -51,6 +51,10 @@ Expr swish(Expr a) {
return Expression<SwishNodeOp>(a);
}
+Expr gelu(Expr a) {
+ return Expression<SwishNodeOp>(a, 1.702f);
+}
+
Expr operator-(Expr a) {
return Expression<NegNodeOp>(a);
};
@@ -529,6 +533,11 @@ Expr swish(const std::vector<Expr>& nodes) {
return swish(nodes[0]);
}
+Expr gelu(const std::vector<Expr>& nodes) {
+ ABORT_IF(nodes.size() > 1, "Not implemented");
+ return gelu(nodes[0]);
+}
+
Expr tanh(const std::vector<Expr>& nodes) {
return Expression<TanhNodeOp>(nodes);
}
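
The gelu() operator above does not introduce a new node type: it reuses SwishNodeOp with b = 1.702, relying on the sigmoid approximation GELU(x) = x * Phi(x) ≈ x * sigmoid(1.702 * x), where Phi is the standard normal CDF. A minimal scalar sketch of that approximation (plain C++ for illustration, not Marian code; sigmoid and geluApprox are hypothetical names):

#include <cmath>
#include <cstdio>
#include <initializer_list>

// Logistic sigmoid.
static float sigmoid(float x) { return 1.f / (1.f + std::exp(-x)); }

// GELU via the swish form x * sigmoid(1.702 * x), matching gelu() above.
static float geluApprox(float x) { return x * sigmoid(1.702f * x); }

int main() {
  for (float x : {-2.f, -1.f, 0.f, 1.f, 2.f})
    std::printf("x = %5.2f   gelu(x) ~ %8.5f\n", x, geluApprox(x));
  return 0;
}
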
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 78aed834..0205fe86 100755
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -17,6 +17,9 @@ Expr sigmoid(const std::vector<Expr>&);
Expr swish(Expr a);
Expr swish(const std::vector<Expr>&);
+Expr gelu(Expr a);
+Expr gelu(const std::vector<Expr>&);
+
Expr tanh(const std::vector<Expr>&);
template <typename... Args>
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 63756da3..190fa947 100755
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -343,31 +343,52 @@ private:
* in an expression graph.
*
* This node implements the activation function
- * \f$ f(x) = x \cdot \sigma(x) \f$
+ * \f$ f(x) = x \cdot \sigma(bx) \f$
* and its derivative
- * \f$ f^\prime(x) = f(x) + \sigma(x)(1 - f(x)) \f$ .
+ * \f$ f^\prime(x) = bf(x) + \sigma(bx)(1 - bf(x)) \f$ .
*
*/
struct SwishNodeOp : public UnaryNodeOp {
- SwishNodeOp(Expr a) : UnaryNodeOp(a) {}
+ SwishNodeOp(Expr a, float b = 1.f) : UnaryNodeOp(a), b_{b} {}
NodeOps forwardOps() override {
using namespace functional;
- return {NodeOp(Element(_1 = _2 * sigmoid(_2), val_, child(0)->val()))};
+ return {NodeOp(Element(_1 = _2 * sigmoid(b_ * _2), val_, child(0)->val()))};
}
NodeOps backwardOps() override {
using namespace functional;
- // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) )
- return {NodeOp(Add(_1 * (_3 + sigmoid(_2) * (1.f - _3)),
+ // dJ/dx += dJ/df * (b*f(x) + sigmoid(b*x) * (1 - b*f(x)))
+ return {NodeOp(Add(_1 * (b_ * _3 + sigmoid(b_ * _2) * (1.f - (b_ * _3))),
child(0)->grad(), // dJ/dx
adj_, // _1 := dJ/df
child(0)->val(), // _2 := x
- val_ // _3 := f(x) = x*sigma(x)
+ val_ // _3 := f(x) = x*sigmoid(b*x)
))};
}
const std::string type() override { return "swish"; }
+
+ virtual size_t hash() override {
+ if(!hash_) {
+ hash_ = NaryNodeOp::hash();
+ util::hash_combine(hash_, b_);
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) override {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ Ptr<SwishNodeOp> cnode = std::dynamic_pointer_cast<SwishNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(b_ != cnode->b_)
+ return false;
+ return true;
+ }
+
+ float b_;
};
struct SoftmaxNodeOp : public UnaryNodeOp {
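
The generalized SwishNodeOp above computes f(x) = x * sigmoid(b*x), and its backward pass uses f'(x) = b*f(x) + sigmoid(b*x) * (1 - b*f(x)), which follows from the product and chain rules since d/dx sigmoid(b*x) = b * sigmoid(b*x) * (1 - sigmoid(b*x)). A standalone scalar check of that formula against a central finite difference (plain C++ for illustration, not Marian code; swish and swishGrad are hypothetical names):

#include <cmath>
#include <cstdio>
#include <initializer_list>

static float sigmoid(float x) { return 1.f / (1.f + std::exp(-x)); }

// f(x) = x * sigmoid(b * x); b = 1 is swish, b = 1.702 is the GELU approximation.
static float swish(float x, float b) { return x * sigmoid(b * x); }

// Analytic derivative as used in SwishNodeOp::backwardOps():
//   f'(x) = b * f(x) + sigmoid(b * x) * (1 - b * f(x))
static float swishGrad(float x, float b) {
  float f = swish(x, b);
  return b * f + sigmoid(b * x) * (1.f - b * f);
}

int main() {
  const float b = 1.702f, eps = 1e-3f;
  for (float x : {-1.5f, 0.3f, 2.0f}) {
    float numeric = (swish(x + eps, b) - swish(x - eps, b)) / (2.f * eps);
    std::printf("x = %5.2f   analytic = %8.5f   numeric = %8.5f\n",
                x, swishGrad(x, b), numeric);
  }
  return 0;
}
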
diff --git a/src/layers/loss.h b/src/layers/loss.h
index 4a28de7b..2ac4ae78 100644
--- a/src/layers/loss.h
+++ b/src/layers/loss.h
@@ -331,6 +331,7 @@ protected:
virtual Expr compute(Expr logits, Expr labelIndices,
Expr mask = nullptr, Expr labelWeights = nullptr) override {
logits = atleast_3d(logits); // safeguard against 2d classifier output, adds 1 on the left, no-op.
Expr ce = cross_entropy(logits, labelIndices);
if(labelSmoothing_ > 0) {
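
Per the inline comment in the hunk above, atleast_3d() is a safeguard for 2-D classifier output: it prepends a singleton dimension on the left (a no-op for tensors that are already 3-D or higher), so cross_entropy() and the label-smoothing branch below receive logits of the rank they expect.
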
diff --git a/src/models/bert.h b/src/models/bert.h
index c642abac..b4d2a34c 100644
--- a/src/models/bert.h
+++ b/src/models/bert.h
@@ -144,7 +144,7 @@ public:
int dimBatch = subBatch->batchSize();
int dimWords = subBatch->batchWidth();
- int maxSentPos = 2; // Currently only two sentences allowed A at [0] and B at [1] and padding at [2]
+ int maxSentPos = 1; // Currently only two sentences allowed A at [0] and B at [1] and padding at [2]
// If another separator is seen do not increase position index beyond 2 but use padding.
// @TODO: make this configurable, see below for NextSentencePredictions task where we also restrict to 2.
@@ -231,7 +231,7 @@ public:
if(learnedPosEmbeddings) {
auto sentenceEmbeddings = embedding()
("prefix", "Wsent")
- ("dimVocab", 3) // sentence A or sentence B plus padding, @TODO: should rather be a parameter
+ ("dimVocab", 2) // sentence A or sentence B plus padding, @TODO: should rather be a parameter
("dimEmb", dimEmb)
.construct(graph_);
signal = sentenceEmbeddings->apply(bertBatch->bertSentenceIndices(), {dimWords, dimBatch, dimEmb});
@@ -327,24 +327,24 @@ public:
int dimVoc = opt<std::vector<int>>("dim-vocabs")[batchIndex_];
- std::string activationType = opt<std::string>("transformer-ffn-activation");
- mlp::act activation;
- if(activationType == "relu")
- activation = mlp::act::ReLU;
- else if(activationType == "swish")
- activation = mlp::act::swish;
- else
- ABORT("Activation function {} not supported in BERT masked LM", activationType);
-
auto layer1 = mlp::mlp()
.push_back(mlp::dense()
("prefix", prefix_ + "_ff_logit_l1")
- ("dim", dimModel)
- ("activation", activation))
+ ("dim", dimModel))
.construct(graph);
auto intermediate = layer1->apply(maskedContext);
+ std::string activationType = opt<std::string>("transformer-ffn-activation");
+ if(activationType == "relu")
+ intermediate = relu(intermediate);
+ else if(activationType == "swish")
+ intermediate = swish(intermediate);
+ else if(activationType == "gelu")
+ intermediate = gelu(intermediate);
+ else
+ ABORT("Activation function {} not supported in BERT masked LM", activationType);
+
auto gamma = graph->param(prefix_ + "_ff_ln_scale", {1, dimModel}, inits::ones);
auto beta = graph->param(prefix_ + "_ff_ln_bias", {1, dimModel}, inits::zeros);
intermediate = layerNorm(intermediate, gamma, beta);
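
In the bert.h hunk above, the activation is no longer passed into mlp::dense() through the mlp::act enum (which, judging by the removed branch, only covers ReLU and swish); the dense layer is built without an activation and relu/swish/gelu is applied directly to its output. This is what lets the masked-LM head honor the new "gelu" value of the transformer-ffn-activation option without extending the enum.
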
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 01201df9..02ce2f6e 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -375,6 +375,8 @@ public:
return (ActivationFunction*)relu;
else if (actName == "swish")
return (ActivationFunction*)swish;
+ else if (actName == "gelu")
+ return (ActivationFunction*)gelu;
ABORT("Invalid activation name '{}'", actName);
}
diff --git a/src/tensors/gpu/add.inc b/src/tensors/gpu/add.inc
index 69244dce..b70c59d0 100755
--- a/src/tensors/gpu/add.inc
+++ b/src/tensors/gpu/add.inc
@@ -32,3 +32,4 @@ template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functio
template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Exp, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<2>, marian::functional::Assignee<3> > > >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase> >(marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::Exp, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<2>, marian::functional::Assignee<3> > > >, float, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>);
template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, marian::functional::Assignee<3> >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase> >(marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Assignee<2> >, marian::functional::Assignee<3> >, float, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>);
template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Capture>, marian::functional::Assignee<2> >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase> >(marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Capture>, marian::functional::Assignee<2> >, float, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>);
+template void marian::gpu::Add<marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<3> >, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sigmoid, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<3> > > > > >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase> >(marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<3> >, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sigmoid, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<3> > > > > >, float, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>); \ No newline at end of file
diff --git a/src/tensors/gpu/element.inc b/src/tensors/gpu/element.inc
index f3cdea28..364866fa 100755
--- a/src/tensors/gpu/element.inc
+++ b/src/tensors/gpu/element.inc
@@ -56,6 +56,8 @@ template void Element<Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, Bin
template void Element<Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, BinaryFunctor<elem::Div, Assignee<2>, Capture>, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Sqrt, BinaryFunctor<elem::Div, Assignee<3>, Capture> >, Capture> >, BinaryFunctor<elem::Mult, Capture, Assignee<1> > > > > >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase> >(Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, BinaryFunctor<elem::Div, Assignee<2>, Capture>, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Sqrt, BinaryFunctor<elem::Div, Assignee<3>, Capture> >, Capture> >, BinaryFunctor<elem::Mult, Capture, Assignee<1> > > > > >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>);
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::NEq, marian::functional::BinaryFunctor<marian::functional::elem::Eq, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::BinaryFunctor<marian::functional::elem::Gt, marian::functional::Assignee<2>, marian::functional::Assignee<3> >, marian::functional::BinaryFunctor<marian::functional::elem::Lt, marian::functional::Assignee<2>, marian::functional::Assignee<3> > >, marian::functional::Capture>, marian::functional::Capture> >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::NEq, marian::functional::BinaryFunctor<marian::functional::elem::Eq, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::BinaryFunctor<marian::functional::elem::Gt, marian::functional::Assignee<2>, marian::functional::Assignee<3> >, marian::functional::BinaryFunctor<marian::functional::elem::Lt, marian::functional::Assignee<2>, marian::functional::Assignee<3> > >, marian::functional::Capture>, marian::functional::Capture> >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>);
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqrt, marian::functional::Assignee<1> > >>(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqrt, marian::functional::Assignee<1> > >, std::shared_ptr<marian::TensorBase>);
+template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<2>, marian::functional::UnaryFunctor<marian::functional::elem::Sigmoid, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<2> > > > >, std::shared_ptr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<2>, marian::functional::UnaryFunctor<marian::functional::elem::Sigmoid, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<2> > > > >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>);
+
// How to add new specializations:
// When you use a new specialization, it will cause a link error of this form (example):
// .../src/tensors/tensor_operators.h:41: undefined reference to `void marian::gpu::Element<marian::functional::Assign< ... > ( ... )'