-rw-r--r--  src/graph/expression_operators.cu   |  16
-rw-r--r--  src/graph/expression_operators.h    |   6
-rw-r--r--  src/graph/node_operators_unary.h    | 104
-rw-r--r--  src/kernels/tensor_operators.cu     |   7
-rw-r--r--  src/kernels/thrust_functions.h      |  38
-rw-r--r--  src/layers/generic.h                |   6
-rw-r--r--  src/tests/attention_tests.cpp       |   7
-rw-r--r--  src/training/validator.h            |   5
-rw-r--r--  src/translator/output_collector.cpp |   6
-rw-r--r--  src/translator/output_collector.h   |  30
10 files changed, 196 insertions(+), 29 deletions(-)
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index deed4815..98570862 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -20,6 +20,14 @@ Expr relu(Expr a) {
   return Expression<ReLUNodeOp>(a);
 }
 
+Expr leakyrelu(Expr a) {
+  return Expression<PReLUNodeOp>(0.01f, a);
+}
+
+Expr prelu(Expr a, float alpha) {
+  return Expression<PReLUNodeOp>(alpha, a);
+}
+
 Expr log(Expr a) {
   return Expression<LogNodeOp>(a);
 };
@@ -238,6 +246,14 @@ Expr relu(const std::vector<Expr>&) {
   ABORT("Not implemented");
 }
 
+Expr leakyrelu(const std::vector<Expr>&) {
+  ABORT("Not implemented");
+}
+
+Expr prelu(const std::vector<Expr>&, float alpha) {
+  ABORT("Not implemented");
+}
+
 Expr sqrt(Expr a, float eps) {
   return Expression<SqrtNodeOp>(a, eps);
 }
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 8daf8ef9..f6721b52 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -24,6 +24,12 @@ Expr tanh(Args... args) {
 Expr relu(Expr a);
 Expr relu(const std::vector<Expr>&);
 
+Expr leakyrelu(Expr a);
+Expr leakyrelu(const std::vector<Expr>&);
+
+Expr prelu(Expr a, float alpha = 0.01);
+Expr prelu(const std::vector<Expr>&, float alpha = 0.01);
+
 Expr log(Expr a);
 
 Expr exp(Expr a);
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 4457ad2e..844f8905 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -186,38 +186,98 @@ struct TanhNodeOp : public NaryNodeOp {
 
 /**
  * Represents a <a
-href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">rectified
-linear</a> node
- * in an expression graph.
+ * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">rectified
+ * linear</a> node in an expression graph.
  *
- * This node implements the <a
-href="https://en.wikipedia.org/wiki/Activation_function">activation function</a>
- * \f$f(x) = \max(0, x)\f$ and its derivative:
- *
- \f[
-   f^\prime(x) =
-   \begin{cases}
-     0 & \text{if } x \leq 0 \\
-     1 & \text{if } x > 0
-   \end{cases}
-\f]
+ * This node implements the activation function \f$ f(x) = \max(0, x) \f$ and
+ * its derivative:
+ * \f[
+ *   f^\prime(x) =
+ *   \begin{cases}
+ *     0 & \text{if } x \leq 0 \\
+ *     1 & \text{if } x > 0
+ *   \end{cases}
+ * \f]
 */
 struct ReLUNodeOp : public UnaryNodeOp {
   template <typename... Args>
   ReLUNodeOp(Args... args) : UnaryNodeOp(args...) {}
 
   NodeOps forwardOps() {
-    return {NodeOp(Element(_1 = ReLU(_2), val_, child(0)->val()))};
+    // f(x) = max(0, x)
+    return {NodeOp(Element(_1 = ReLU(_2),
+                           val_,            // _1 := f(x) to be calculated
+                           child(0)->val()  // _2 := x
+                           ))};
   }
 
   NodeOps backwardOps() {
-    return {NodeOp(
-        Add(_1 * ReLUback(_2), child(0)->grad(), adj_, child(0)->val()))};
+    // dJ/dx += dJ/df * binarystep(x)
+    return {NodeOp(Add(_1 * ReLUback(_2),
+                       child(0)->grad(),  // dJ/dx
+                       adj_,              // _1 := dJ/df
+                       child(0)->val()    // _2 := f(x) = max(0, x)
+                       ))};
   }
 
   const std::string type() { return "ReLU"; }
 };
 
+/**
+ * Represents a <a
+ * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">parametric
+ * rectified linear unit</a> node in an expression graph.
+ * For \f$ \alpha = 0.01 \f$ (the default value) it is equivalent to Leaky
+ * ReLU.
+ *
+ * This node implements the activation function:
+ * \f[
+ *   f(x, \alpha) =
+ *   \begin{cases}
+ *     \alpha x & \text{if } x \leq 0 \\
+ *     x        & \text{if } x > 0
+ *   \end{cases}
+ * \f]
+ *
+ * and its derivative:
+ * \f[
+ *   f^\prime(x, \alpha) =
+ *   \begin{cases}
+ *     \alpha & \text{if } x \leq 0 \\
+ *     1      & \text{if } x > 0
+ *   \end{cases}
+ * \f]
+ */
+struct PReLUNodeOp : public UnaryNodeOp {
+  template <typename... Args>
+  PReLUNodeOp(float alpha, Args... args)
+      : UnaryNodeOp(args...), alpha_(alpha) {}
+
+  NodeOps forwardOps() {
+    return {NodeOp(Element(_1 = PReLU(_2, alpha_), val_, child(0)->val()))};
+  }
+
+  NodeOps backwardOps() {
+    return {NodeOp(Add(
+        _1 * PReLUback(_2, alpha_), child(0)->grad(), adj_, child(0)->val()))};
+  }
+
+  const std::string type() { return "PReLU"; }
+
+private:
+  float alpha_{0.01};
+};
+
+/**
+ * Represents a <a href="https://arxiv.org/pdf/1710.05941.pdf">swish</a> node
+ * in an expression graph.
+ *
+ * This node implements the activation function
+ * \f$ f(x) = x \cdot \sigma(x) \f$
+ * and its derivative
+ * \f$ f^\prime(x) = f(x) + \sigma(x)(1 - f(x)) \f$.
+ */
 struct SwishNodeOp : public UnaryNodeOp {
   template <typename... Args>
   SwishNodeOp(Args... args) : UnaryNodeOp(args...) {}
@@ -227,11 +287,13 @@ struct SwishNodeOp : public UnaryNodeOp {
   }
 
   NodeOps backwardOps() {
+    // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) )
     return {NodeOp(Add(_1 * (_3 + Sigma(_2) * (1.f - _3)),
-                       child(0)->grad(),
-                       adj_,
-                       child(0)->val(),
-                       val_))};
+                       child(0)->grad(),  // dJ/dx
+                       adj_,              // _1 := dJ/df
+                       child(0)->val(),   // _2 := x
+                       val_               // _3 := f(x) = x*sigma(x)
+                       ))};
   }
 
   const std::string type() { return "swish"; }
diff --git a/src/kernels/tensor_operators.cu b/src/kernels/tensor_operators.cu
index 1e4b4c92..fc38042c 100644
--- a/src/kernels/tensor_operators.cu
+++ b/src/kernels/tensor_operators.cu
@@ -88,10 +88,11 @@ void Concatenate1(Tensor out, const std::vector<Tensor>& inputs) {
   int cols_out = out->shape().back();
 
   for(auto in : inputs) {
-    UTIL_THROW_IF2(rows != in->shape().elements() / in->shape().back(),
+    ABORT_IF(rows != in->shape().elements() / in->shape().back(),
              "First dimension must be equal");
     int cols_in = in->shape().back();
+
     int blocks = std::min(MAX_BLOCKS, rows);
     int threads = std::min(MAX_THREADS, cols_in);
@@ -116,8 +117,8 @@ void Split1(std::vector<Tensor>& outputs, const Tensor in) {
   int rows = in->shape().elements() / in->shape().back();
   int cols_in = in->shape().back();
   for(auto out : outputs) {
-    UTIL_THROW_IF2(rows != out->shape().elements() / out->shape().back(),
-                   "First dimension must be equal");
+    ABORT_IF(rows != out->shape().elements() / out->shape().back(),
+             "First dimension must be equal");
     int cols_out = out->shape().back();
 
     int blocks = std::min(MAX_BLOCKS, rows);
diff --git a/src/kernels/thrust_functions.h b/src/kernels/thrust_functions.h
index 67f37a13..1d91fc38 100644
--- a/src/kernels/thrust_functions.h
+++ b/src/kernels/thrust_functions.h
@@ -80,6 +80,8 @@ __host__ __device__
       binary_operator<thrust::maximum>(), make_actor(_1), make_actor(_2));
 }
 
+//*******************************************************************
+
 template <typename T>
 struct unary_relu : public thrust::unary_function<T, T> {
   __host__ __device__ T operator()(const T &x) const {
@@ -107,6 +109,42 @@ __host__ __device__
   return compose(unary_operator<unary_reluback>(), _1);
 }
 
+//*******************************************************************
+
+template <typename T>
+struct binary_prelu : public thrust::binary_function<T, T, T> {
+  __host__ __device__ T operator()(const T &x, const T &alpha) const {
+    return x > 0.0f ? x : alpha * x;
+  }
+};
+
+template <typename T1, typename T2>
+__host__ __device__ actor<composite<binary_operator<binary_prelu>,
+                                    actor<T1>,
+                                    typename as_actor<T2>::type>>
+PReLU(const actor<T1> &_1, const T2 &_2) {
+  return compose(
+      binary_operator<binary_prelu>(), make_actor(_1), make_actor(_2));
+}
+
+template <typename T>
+struct binary_preluback : public thrust::binary_function<T, T, T> {
+  __host__ __device__ T operator()(const T &x, const T &alpha) const {
+    return x > 0.0f ? 1.0f : alpha;
+  }
+};
+
+template <typename T1, typename T2>
+__host__ __device__ actor<composite<binary_operator<binary_preluback>,
+                                    actor<T1>,
+                                    typename as_actor<T2>::type>>
+PReLUback(const actor<T1> &_1, const T2 &_2) {
+  return compose(
+      binary_operator<binary_preluback>(), make_actor(_1), make_actor(_2));
+}
+
+//*******************************************************************
+
 template <typename T>
 __host__ __device__ int sgn(T val) {
   return (float(0) < val) - (val < float(0));
diff --git a/src/layers/generic.h b/src/layers/generic.h
index 0cca4a38..5f465bd8 100644
--- a/src/layers/generic.h
+++ b/src/layers/generic.h
@@ -9,7 +9,7 @@
 
 namespace marian {
 namespace mlp {
-enum struct act : int { linear, tanh, logit, ReLU, swish };
+enum struct act : int { linear, tanh, logit, ReLU, LeakyReLU, PReLU, swish };
 }
 }
 
@@ -128,6 +128,8 @@ public:
       case act::tanh: return tanh(outputs);
      case act::logit: return logit(outputs);
       case act::ReLU: return relu(outputs);
+      case act::LeakyReLU: return leakyrelu(outputs);
+      case act::PReLU: return prelu(outputs);
       case act::swish: return swish(outputs);
       default: return plus(outputs);
     }
@@ -186,6 +188,8 @@ public:
       case act::tanh: return tanh(out);
       case act::logit: return logit(out);
       case act::ReLU: return relu(out);
+      case act::LeakyReLU: return leakyrelu(out);
+      case act::PReLU: return prelu(out);
       case act::swish: return swish(out);
       default: return out;
     }
diff --git a/src/tests/attention_tests.cpp b/src/tests/attention_tests.cpp
index d6eeae6e..722b82a7 100644
--- a/src/tests/attention_tests.cpp
+++ b/src/tests/attention_tests.cpp
@@ -94,6 +94,13 @@ TEST_CASE("Model components, Attention", "[attention]") {
     });
 
     aligned->val()->get(values);
+
+    //for(int i = 0; i < values.size(); ++i) {
+    //  if(i && i % 4 == 0)
+    //    std::cout << std::endl;
+    //  std::cout << values[i] << ", ";
+    //}
+
     CHECK( std::equal(values.begin(), values.end(), vAligned.begin(), floatApprox) );
   }
diff --git a/src/training/validator.h b/src/training/validator.h
index ec803d5d..9e4fdeec 100644
--- a/src/training/validator.h
+++ b/src/training/validator.h
@@ -232,7 +232,7 @@ public:
       fileName = tempFile->getFileName();
     }
 
-    LOG_VALID(info, "Translating validation set...");
+    LOG(info, "Translating validation set...");
 
     graph->setInference(true);
     boost::timer::cpu_timer timer;
@@ -241,6 +241,7 @@ public:
     auto collector = options_->has("trans-output")
                          ? New<OutputCollector>(fileName)
                          : New<OutputCollector>(*tempFile);
+    collector->setPrintingStrategy(New<GeometricPrinting>());
 
     size_t sentenceId = 0;
@@ -266,7 +267,7 @@ public:
      }
     }
 
-    LOG_VALID(info, "Total translation time: {}", timer.format(5, "%ws"));
+    LOG(info, "Total translation time: {}", timer.format(5, "%ws"));
     graph->setInference(false);
 
     float val = 0.0f;
diff --git a/src/translator/output_collector.cpp b/src/translator/output_collector.cpp
index 9ab17668..86fc4fad 100644
--- a/src/translator/output_collector.cpp
+++ b/src/translator/output_collector.cpp
@@ -14,7 +14,8 @@ void OutputCollector::Write(long sourceId,
                             bool nbest) {
   boost::mutex::scoped_lock lock(mutex_);
   if(sourceId == nextId_) {
-    LOG(info, "Best translation {} : {}", sourceId, best1);
+    if(printing_ && printing_->shouldBePrinted(sourceId))
+      LOG(info, "Best translation {} : {}", sourceId, best1);
 
     if(nbest)
       ((std::ostream&)*outStrm_) << bestn << std::endl;
@@ -31,7 +32,8 @@ void OutputCollector::Write(long sourceId,
       if(currId == nextId_) {
         // 1st element in the map is the next
         const auto& currOutput = iter->second;
-        LOG(info, "Best translation {} : {}", currId, currOutput.first);
+        if(printing_ && printing_->shouldBePrinted(sourceId))
+          LOG(info, "Best translation {} : {}", currId, currOutput.first);
 
         if(nbest)
           ((std::ostream&)*outStrm_) << currOutput.second << std::endl;
         else
diff --git a/src/translator/output_collector.h b/src/translator/output_collector.h
index 479cfe6e..91904297 100644
--- a/src/translator/output_collector.h
+++ b/src/translator/output_collector.h
@@ -10,6 +10,30 @@
 
 namespace marian {
 
+class PrintingStrategy {
+public:
+  virtual bool shouldBePrinted(long) = 0;
+};
+
+class GeometricPrinting : public PrintingStrategy {
+public:
+  bool shouldBePrinted(long id) {
+    if(id == 0)
+      next_ = start_;
+    if(id <= 5)
+      return true;
+    if(next_ == id) {
+      next_ += next_;
+      return true;
+    }
+    return false;
+  }
+
+private:
+  size_t start_{10};
+  long next_{10};
+};
+
 class OutputCollector {
 public:
   OutputCollector();
@@ -24,6 +48,10 @@ public:
                      const std::string& bestn,
                      bool nbest);
 
+  void setPrintingStrategy(Ptr<PrintingStrategy> strategy) {
+    printing_ = strategy;
+  }
+
 protected:
   UPtr<OutputFileStream> outStrm_;
   boost::mutex mutex_;
@@ -31,6 +59,8 @@ protected:
   typedef std::map<long, std::pair<std::string, std::string>> Outputs;
   Outputs outputs_;
+
+  Ptr<PrintingStrategy> printing_;
 };
 
 class StringCollector {
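
A note on what the new code computes. The element-wise rules added in src/kernels/thrust_functions.h (binary_prelu for the forward pass, binary_preluback for the gradient) can be sanity-checked in isolation with plain host code. The sketch below is illustrative only and not part of the commit; the names prelu_fwd and prelu_bwd are made up for this example:

#include <iostream>

// Forward rule, as in binary_prelu: identity for x > 0, slope alpha otherwise.
float prelu_fwd(float x, float alpha) {
  return x > 0.0f ? x : alpha * x;
}

// Backward rule, as in binary_preluback: gradient 1 for x > 0, alpha otherwise.
float prelu_bwd(float x, float alpha) {
  return x > 0.0f ? 1.0f : alpha;
}

int main() {
  // alpha = 0.01 reproduces the leakyrelu() wrapper, which constructs a
  // PReLUNodeOp with that fixed slope; any other alpha gives general PReLU.
  const float xs[] = {-2.0f, 0.0f, 3.0f};
  for(float x : xs)
    std::cout << "x=" << x << "  f=" << prelu_fwd(x, 0.01f)
              << "  f'=" << prelu_bwd(x, 0.01f) << std::endl;
  return 0;
}

On the logging side, once the validator installs GeometricPrinting, "Best translation" lines are printed for sentence ids 0 through 5 and afterwards only when the id reaches the doubling sequence 10, 20, 40, 80, ..., so translating a large validation set no longer floods the log.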