Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2016-05-04 13:07:44 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2016-05-04 13:07:44 +0300
commit28ba2628f9e8a684113a53e1eabb05d0a190fffb (patch)
treebf7fa439d2d20525480d1ff701d211741f66eaba /src
parentb9e26509dd4ce1d2e29ef0bb400d2adb984e8673 (diff)
more operators
Diffstat (limited to 'src')
-rw-r--r--src/marian.h80
-rw-r--r--src/test.cpp2
2 files changed, 81 insertions, 1 deletions
diff --git a/src/marian.h b/src/marian.h
index 4a14e340..18d561c4 100644
--- a/src/marian.h
+++ b/src/marian.h
@@ -89,6 +89,8 @@ class Var {
VimplPtr vimpl_;
};
+///////////////////////////////////////////////////
+
struct OpVimpl : public Vimpl {
OpVimpl(const Tensor& t, VimplPtr a) : Vimpl(t), a_(a) { }
@@ -108,6 +110,44 @@ inline Var log(const Var& a) {
return Var(VimplPtr(new LogVimpl(a.vimpl())));
}
+struct ExpVimpl : public OpVimpl {
+ ExpVimpl(VimplPtr a) : OpVimpl(std::exp(a->val()), a) { }
+
+ void chain() {
+ a_->grad() += adj_ * std::exp(a_->val());
+ }
+};
+
+inline Var exp(const Var& a) {
+ return Var(VimplPtr(new ExpVimpl(a.vimpl())));
+}
+
+struct NegVimpl : public OpVimpl {
+ NegVimpl(VimplPtr a) : OpVimpl(-a->val(), a) { }
+
+ void chain() {
+ a_->grad() -= adj_;
+ }
+};
+
+inline Var operator-(const Var& a) {
+ return Var(VimplPtr(new NegVimpl(a.vimpl())));
+}
+
+// @TODO: take care of large exponents
+struct SigmaVimpl : public OpVimpl {
+ SigmaVimpl(VimplPtr a) : OpVimpl(1.f / (1.f + std::exp(-a->val())), a) { }
+
+ void chain() {
+ Tensor l = 1.f / (1.f + std::exp(-a_->val()));
+ a_->grad() += adj_ * l * (1 - l);
+ }
+};
+
+inline Var sigma(const Var& a) {
+ return Var(VimplPtr(new SigmaVimpl(a.vimpl())));
+}
+
///////////////////////////////////////////////////
@@ -132,4 +172,44 @@ inline Var operator+(const Var& a, const Var& b) {
return Var(VimplPtr(new PlusVimplVV(a.vimpl(), b.vimpl())));
}
+struct MinusVimplVV : public OpVimplVV {
+ MinusVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() - b->val(), a, b) { }
+
+ void chain() {
+ a_->grad() -= adj_;
+ b_->grad() -= adj_;
+ }
+};
+
+inline Var operator-(const Var& a, const Var& b) {
+ return Var(VimplPtr(new MinusVimplVV(a.vimpl(), b.vimpl())));
+}
+
+struct MultVimplVV : public OpVimplVV {
+ MultVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() * b->val(), a, b) { }
+
+ void chain() {
+ a_->grad() += adj_ * b_->val();
+ b_->grad() += adj_ * a_->val();
+ }
+};
+
+inline Var operator*(const Var& a, const Var& b) {
+ return Var(VimplPtr(new MultVimplVV(a.vimpl(), b.vimpl())));
+}
+
+struct DivVimplVV : public OpVimplVV {
+ DivVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() / b->val(), a, b) { }
+
+ void chain() {
+ a_->grad() += adj_ / b_->val();
+ b_->grad() += adj_ * (a_->val() / (b_->val() * b_->val()));
+ }
+};
+
+inline Var operator/(const Var& a, const Var& b) {
+ return Var(VimplPtr(new DivVimplVV(a.vimpl(), b.vimpl())));
+}
+
+
} \ No newline at end of file
diff --git a/src/test.cpp b/src/test.cpp
index 728fb3ef..795d845d 100644
--- a/src/test.cpp
+++ b/src/test.cpp
@@ -29,7 +29,7 @@ int main(int argc, char** argv) {
Var y1 = layer(10, x1);
Var y2 = layer(rand() % 20 + 1, x2);
- Var y = y1 + log(y2);
+ Var y = sigma(log(y1) / log(y2));
set_zero_all_adjoints();
y.calc_gradients();