github.com/marian-nmt/marian.git
 src/graph/expression_operators.cu   |  16
 src/graph/expression_operators.h    |   6
 src/graph/node_operators_unary.h    | 104
 src/kernels/tensor_operators.cu     |   7
 src/kernels/thrust_functions.h      |  38
 src/layers/generic.h                |   6
 src/tests/attention_tests.cpp       |   7
 src/training/validator.h            |   5
 src/translator/output_collector.cpp |   6
 src/translator/output_collector.h   |  30
 10 files changed, 196 insertions(+), 29 deletions(-)
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index deed4815..98570862 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -20,6 +20,14 @@ Expr relu(Expr a) {
return Expression<ReLUNodeOp>(a);
}
+Expr leakyrelu(Expr a) {
+ return Expression<PReLUNodeOp>(0.01f, a);
+}
+
+Expr prelu(Expr a, float alpha) {
+ return Expression<PReLUNodeOp>(alpha, a);
+}
+
Expr log(Expr a) {
return Expression<LogNodeOp>(a);
};
@@ -238,6 +246,14 @@ Expr relu(const std::vector<Expr>&) {
ABORT("Not implemented");
}
+Expr leakyrelu(const std::vector<Expr>&) {
+ ABORT("Not implemented");
+}
+
+Expr prelu(const std::vector<Expr>&, float alpha) {
+ ABORT("Not implemented");
+}
+
Expr sqrt(Expr a, float eps) {
return Expression<SqrtNodeOp>(a, eps);
}
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 8daf8ef9..f6721b52 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -24,6 +24,12 @@ Expr tanh(Args... args) {
Expr relu(Expr a);
Expr relu(const std::vector<Expr>&);
+Expr leakyrelu(Expr a);
+Expr leakyrelu(const std::vector<Expr>&);
+
+Expr prelu(Expr a, float alpha = 0.01);
+Expr prelu(const std::vector<Expr>&, float alpha = 0.01);
+
Expr log(Expr a);
Expr exp(Expr a);
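
Note: leakyrelu and prelu are thin wrappers over PReLUNodeOp, with leakyrelu fixed at alpha = 0.01. A minimal usage sketch follows; the helper name activate and the input node hidden are hypothetical, and it assumes src/ is on the include path:

#include "graph/expression_operators.h"

using namespace marian;

// Hypothetical helper: `hidden` is any node already in the expression graph.
Expr activate(Expr hidden) {
  Expr a = leakyrelu(hidden);    // equivalent to prelu(hidden, 0.01f)
  Expr b = prelu(hidden, 0.2f);  // caller-chosen slope on the negative side
  return a + b;                  // both are ordinary graph nodes
}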
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 4457ad2e..844f8905 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -186,38 +186,98 @@ struct TanhNodeOp : public NaryNodeOp {
/**
* Represents a <a
-href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">rectified
-linear</a> node
- * in an expression graph.
+ * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">rectified
+ * linear</a> node in an expression graph.
*
- * This node implements the <a
-href="https://en.wikipedia.org/wiki/Activation_function">activation function</a>
- * \f$f(x) = \max(0, x)\f$ and its derivative:
- *
- \f[
- f^\prime(x) =
- \begin{cases}
- 0 & \text{if } x \leq 0 \\
- 1 & \text{if } x > 0
- \end{cases}
-\f]
+ * This node implements the activation function \f$ f(x) = \max(0, x) \f$ and
+ * its derivative:
+ * \f[
+ * f^\prime(x) =
+ * \begin{cases}
+ * 0 & \text{if } x \leq 0 \\
+ * 1 & \text{if } x > 0
+ * \end{cases}
+ * \f]
*/
struct ReLUNodeOp : public UnaryNodeOp {
template <typename... Args>
ReLUNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
- return {NodeOp(Element(_1 = ReLU(_2), val_, child(0)->val()))};
+ // f(x) = max(0, x)
+ return {NodeOp(Element(_1 = ReLU(_2),
+ val_, // _1 := f(x) to be calculated
+ child(0)->val() // _2 := x
+ ))};
}
NodeOps backwardOps() {
- return {NodeOp(
- Add(_1 * ReLUback(_2), child(0)->grad(), adj_, child(0)->val()))};
+ // dJ/dx += dJ/df * binarystep(x)
+ return {NodeOp(Add(_1 * ReLUback(_2),
+ child(0)->grad(), // dJ/dx
+ adj_, // _1 := dJ/df
+ child(0)->val() // _2 := f(x) = max(0, x)
+ ))};
}
const std::string type() { return "ReLU"; }
};
+/**
+ * Represents a <a
+ * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">parametric
+ * rectified linear unit</a> node in an expression graph.
+ * For \f$ \alpha = 0.01 \f$ (the default value) it is equivalent to Leaky
+ * ReLU.
+ *
+ * This node implements the activation function:
+ * \f[
+ * f(x, \alpha) =
+ * \begin{cases}
+ * \alpha x & \text{if } x \leq 0 \\
+ * x & \text{if } x > 0
+ * \end{cases}
+ * \f]
+ *
+ * and its derivative:
+ * \f[
+ * f^\prime(x, \alpha) =
+ * \begin{cases}
+ * \alpha & \text{if } x \leq 0 \\
+ * 1 & \text{if } x > 0
+ * \end{cases}
+ * \f]
+ */
+struct PReLUNodeOp : public UnaryNodeOp {
+ template <typename... Args>
+ PReLUNodeOp(float alpha, Args... args)
+ : UnaryNodeOp(args...), alpha_(alpha) {}
+
+ NodeOps forwardOps() {
+ return {NodeOp(Element(_1 = PReLU(_2, alpha_), val_, child(0)->val()))};
+ }
+
+ NodeOps backwardOps() {
+ return {NodeOp(Add(
+ _1 * PReLUback(_2, alpha_), child(0)->grad(), adj_, child(0)->val()))};
+ }
+
+ const std::string type() { return "PReLU"; }
+
+private:
+ float alpha_{0.01};
+};
+
+/**
+ * Represents a <a href="https://arxiv.org/pdf/1710.05941.pdf">swish</a> node
+ * in an expression graph.
+ *
+ * This node implements the activation function
+ * \f$ f(x) = x \cdot \sigma(x) \f$
+ * and its derivative
+ * \f$ f^\prime(x) = f(x) + \sigma(x)(1 - f(x)) \f$ .
+ *
+ */
struct SwishNodeOp : public UnaryNodeOp {
template <typename... Args>
SwishNodeOp(Args... args) : UnaryNodeOp(args...) {}
@@ -227,11 +287,13 @@ struct SwishNodeOp : public UnaryNodeOp {
}
NodeOps backwardOps() {
+ // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) )
return {NodeOp(Add(_1 * (_3 + Sigma(_2) * (1.f - _3)),
- child(0)->grad(),
- adj_,
- child(0)->val(),
- val_))};
+ child(0)->grad(), // dJ/dx
+ adj_, // _1 := dJ/df
+ child(0)->val(), // _2 := x
+ val_ // _3 := f(x) = x*sigma(x)
+ ))};
}
const std::string type() { return "swish"; }
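
Note: as a sanity check on the documented math, the following self-contained scalar analogue (illustrative only, not Marian code) evaluates f(x, alpha) and accumulates the gradient the way PReLUNodeOp::backwardOps() does, i.e. dJ/dx += dJ/df * f'(x, alpha):

#include <cassert>

// f(x, alpha) = alpha*x for x <= 0, x for x > 0
float prelu_ref(float x, float alpha) { return x > 0.0f ? x : alpha * x; }
// f'(x, alpha) = alpha for x <= 0, 1 for x > 0
float prelu_ref_grad(float x, float alpha) { return x > 0.0f ? 1.0f : alpha; }

int main() {
  const float x = -2.0f, alpha = 0.01f;
  const float dJ_df = 0.5f;  // incoming gradient (adj_)
  float dJ_dx = 0.0f;        // gradient buffer (child(0)->grad())
  dJ_dx += dJ_df * prelu_ref_grad(x, alpha);  // the Add(...) node op
  assert(prelu_ref(x, alpha) == -0.02f);      // forward: alpha * x
  assert(dJ_dx == 0.5f * 0.01f);              // backward: dJ/df * alpha
  return 0;
}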
diff --git a/src/kernels/tensor_operators.cu b/src/kernels/tensor_operators.cu
index 1e4b4c92..fc38042c 100644
--- a/src/kernels/tensor_operators.cu
+++ b/src/kernels/tensor_operators.cu
@@ -88,10 +88,11 @@ void Concatenate1(Tensor out, const std::vector<Tensor>& inputs) {
int cols_out = out->shape().back();
for(auto in : inputs) {
- UTIL_THROW_IF2(rows != in->shape().elements() / in->shape().back(),
+ ABORT_IF(rows != in->shape().elements() / in->shape().back(),
"First dimension must be equal");
int cols_in = in->shape().back();
+
int blocks = std::min(MAX_BLOCKS, rows);
int threads = std::min(MAX_THREADS, cols_in);
@@ -116,8 +117,8 @@ void Split1(std::vector<Tensor>& outputs, const Tensor in) {
int rows = in->shape().elements() / in->shape().back();
int cols_in = in->shape().back();
for(auto out : outputs) {
- UTIL_THROW_IF2(rows != out->shape().elements() / out->shape().back(),
- "First dimension must be equal");
+ ABORT_IF(rows != out->shape().elements() / out->shape().back(),
+ "First dimension must be equal");
int cols_out = out->shape().back();
int blocks = std::min(MAX_BLOCKS, rows);
diff --git a/src/kernels/thrust_functions.h b/src/kernels/thrust_functions.h
index 67f37a13..1d91fc38 100644
--- a/src/kernels/thrust_functions.h
+++ b/src/kernels/thrust_functions.h
@@ -80,6 +80,8 @@ __host__ __device__
binary_operator<thrust::maximum>(), make_actor(_1), make_actor(_2));
}
+//*******************************************************************
+
template <typename T>
struct unary_relu : public thrust::unary_function<T, T> {
__host__ __device__ T operator()(const T &x) const {
@@ -107,6 +109,42 @@ __host__ __device__
return compose(unary_operator<unary_reluback>(), _1);
}
+//*******************************************************************
+
+template <typename T>
+struct binary_prelu : public thrust::binary_function<T, T, T> {
+ __host__ __device__ T operator()(const T &x, const T &alpha) const {
+ return x > 0.0f ? x : alpha * x;
+ }
+};
+
+template <typename T1, typename T2>
+__host__ __device__ actor<composite<binary_operator<binary_prelu>,
+ actor<T1>,
+ typename as_actor<T2>::type>>
+PReLU(const actor<T1> &_1, const T2 &_2) {
+ return compose(
+ binary_operator<binary_prelu>(), make_actor(_1), make_actor(_2));
+}
+
+template <typename T>
+struct binary_preluback : public thrust::binary_function<T, T, T> {
+ __host__ __device__ T operator()(const T &x, const T &alpha) const {
+ return x > 0.0f ? 1.0f : alpha;
+ }
+};
+
+template <typename T1, typename T2>
+__host__ __device__ actor<composite<binary_operator<binary_preluback>,
+ actor<T1>,
+ typename as_actor<T2>::type>>
+PReLUback(const actor<T1> &_1, const T2 &_2) {
+ return compose(
+ binary_operator<binary_preluback>(), make_actor(_1), make_actor(_2));
+}
+
+//*******************************************************************
+
template <typename T>
__host__ __device__ int sgn(T val) {
return (float(0) < val) - (val < float(0));
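
Note: binary_preluback is intended to return exactly the slope of binary_prelu on either side of zero. A quick host-side finite-difference check of that pairing (plain C++, without the thrust actor machinery):

#include <cassert>
#include <cmath>

float prelu_fwd(float x, float alpha)  { return x > 0.0f ? x : alpha * x; }
float prelu_back(float x, float alpha) { return x > 0.0f ? 1.0f : alpha; }

int main() {
  const float alpha = 0.3f, h = 1e-3f;
  for(float x : {-2.0f, -0.5f, 0.5f, 2.0f}) {
    // central difference approximates the analytic slope away from x = 0
    float numeric = (prelu_fwd(x + h, alpha) - prelu_fwd(x - h, alpha)) / (2.0f * h);
    assert(std::fabs(numeric - prelu_back(x, alpha)) < 1e-3f);
  }
  return 0;
}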
diff --git a/src/layers/generic.h b/src/layers/generic.h
index 0cca4a38..5f465bd8 100644
--- a/src/layers/generic.h
+++ b/src/layers/generic.h
@@ -9,7 +9,7 @@
namespace marian {
namespace mlp {
-enum struct act : int { linear, tanh, logit, ReLU, swish };
+enum struct act : int { linear, tanh, logit, ReLU, LeakyReLU, PReLU, swish };
}
}
@@ -128,6 +128,8 @@ public:
case act::tanh: return tanh(outputs);
case act::logit: return logit(outputs);
case act::ReLU: return relu(outputs);
+ case act::LeakyReLU: return leakyrelu(outputs);
+ case act::PReLU: return prelu(outputs);
case act::swish: return swish(outputs);
default: return plus(outputs);
}
@@ -186,6 +188,8 @@ public:
case act::tanh: return tanh(out);
case act::logit: return logit(out);
case act::ReLU: return relu(out);
+ case act::LeakyReLU: return leakyrelu(out);
+ case act::PReLU: return prelu(out);
case act::swish: return swish(out);
default: return out;
}
diff --git a/src/tests/attention_tests.cpp b/src/tests/attention_tests.cpp
index d6eeae6e..722b82a7 100644
--- a/src/tests/attention_tests.cpp
+++ b/src/tests/attention_tests.cpp
@@ -94,6 +94,13 @@ TEST_CASE("Model components, Attention", "[attention]") {
});
aligned->val()->get(values);
+
+ //for(int i = 0; i < values.size(); ++i) {
+ // if(i && i % 4 == 0)
+ // std::cout << std::endl;
+ // std::cout << values[i] << ", ";
+ //}
+
CHECK( std::equal(values.begin(), values.end(),
vAligned.begin(), floatApprox) );
}
diff --git a/src/training/validator.h b/src/training/validator.h
index ec803d5d..9e4fdeec 100644
--- a/src/training/validator.h
+++ b/src/training/validator.h
@@ -232,7 +232,7 @@ public:
fileName = tempFile->getFileName();
}
- LOG_VALID(info, "Translating validation set...");
+ LOG(info, "Translating validation set...");
graph->setInference(true);
boost::timer::cpu_timer timer;
@@ -241,6 +241,7 @@ public:
auto collector = options_->has("trans-output")
? New<OutputCollector>(fileName)
: New<OutputCollector>(*tempFile);
+ collector->setPrintingStrategy(New<GeometricPrinting>());
size_t sentenceId = 0;
@@ -266,7 +267,7 @@ public:
}
}
- LOG_VALID(info, "Total translation time: {}", timer.format(5, "%ws"));
+ LOG(info, "Total translation time: {}", timer.format(5, "%ws"));
graph->setInference(false);
float val = 0.0f;
diff --git a/src/translator/output_collector.cpp b/src/translator/output_collector.cpp
index 9ab17668..86fc4fad 100644
--- a/src/translator/output_collector.cpp
+++ b/src/translator/output_collector.cpp
@@ -14,7 +14,8 @@ void OutputCollector::Write(long sourceId,
bool nbest) {
boost::mutex::scoped_lock lock(mutex_);
if(sourceId == nextId_) {
- LOG(info, "Best translation {} : {}", sourceId, best1);
+ if(printing_ && printing_->shouldBePrinted(sourceId))
+ LOG(info, "Best translation {} : {}", sourceId, best1);
if(nbest)
((std::ostream&)*outStrm_) << bestn << std::endl;
@@ -31,7 +32,8 @@ void OutputCollector::Write(long sourceId,
if(currId == nextId_) {
// 1st element in the map is the next
const auto& currOutput = iter->second;
- LOG(info, "Best translation {} : {}", currId, currOutput.first);
+ if(printing_ && printing_->shouldBePrinted(sourceId))
+ LOG(info, "Best translation {} : {}", currId, currOutput.first);
if(nbest)
((std::ostream&)*outStrm_) << currOutput.second << std::endl;
else
diff --git a/src/translator/output_collector.h b/src/translator/output_collector.h
index 479cfe6e..91904297 100644
--- a/src/translator/output_collector.h
+++ b/src/translator/output_collector.h
@@ -10,6 +10,30 @@
namespace marian {
+class PrintingStrategy {
+public:
+ virtual bool shouldBePrinted(long) = 0;
+};
+
+class GeometricPrinting : public PrintingStrategy {
+public:
+ bool shouldBePrinted(long id) {
+ if(id == 0)
+ next_ = start_;
+ if(id <= 5)
+ return true;
+ if(next_ == id) {
+ next_ += next_;
+ return true;
+ }
+ return false;
+ }
+
+private:
+ size_t start_{10};
+ long next_{10};
+};
+
class OutputCollector {
public:
OutputCollector();
@@ -24,6 +48,10 @@ public:
const std::string& bestn,
bool nbest);
+ void setPrintingStrategy(Ptr<PrintingStrategy> strategy) {
+ printing_ = strategy;
+ }
+
protected:
UPtr<OutputFileStream> outStrm_;
boost::mutex mutex_;
@@ -31,6 +59,8 @@ protected:
typedef std::map<long, std::pair<std::string, std::string>> Outputs;
Outputs outputs_;
+
+ Ptr<PrintingStrategy> printing_;
};
class StringCollector {
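
Note: GeometricPrinting logs every translation up to id 5 and then only ids 10, 20, 40, 80, ... (the threshold doubles each time), which keeps validation-translation logs short on large dev sets. A standalone re-implementation of the same schedule (names are illustrative, not the Marian class):

#include <cstdio>

// Same printing schedule as GeometricPrinting::shouldBePrinted().
struct GeometricSchedule {
  long next = 10;
  bool shouldPrint(long id) {
    if(id <= 5)
      return true;    // the first few sentences are always logged
    if(id == next) {
      next += next;   // geometric progression: 10, 20, 40, 80, ...
      return true;
    }
    return false;
  }
};

int main() {
  GeometricSchedule s;
  for(long id = 0; id < 100; ++id)
    if(s.shouldPrint(id))
      std::printf("%ld ", id);  // prints: 0 1 2 3 4 5 10 20 40 80
  std::printf("\n");
  return 0;
}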