Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2016-05-04 00:57:28 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2016-05-04 00:57:28 +0300
commit6a7c9316fcc0e752ff992a1770a430c4b9687b59 (patch)
tree282fd4187c10fb5088492ed608177c77f5547faf /src
parent8a5f319bfb6f0c8dbbc47dac056e19f1c23148fe (diff)
very cool
Diffstat (limited to 'src')
-rwxr-xr-xsrc/a.outbin0 -> 183120 bytes
-rw-r--r--src/mad.h134
-rw-r--r--src/test.cpp23
3 files changed, 157 insertions, 0 deletions
diff --git a/src/a.out b/src/a.out
new file mode 100755
index 00000000..765a2aa9
--- /dev/null
+++ b/src/a.out
Binary files differ
diff --git a/src/mad.h b/src/mad.h
new file mode 100644
index 00000000..0c4b079e
--- /dev/null
+++ b/src/mad.h
@@ -0,0 +1,134 @@
+#pragma once
+
+#include <memory>
+#include <functional>
+#include <vector>
+#include <cmath>
+
+#include <boost/pool/pool.hpp>
+
+namespace mad {
+
+typedef float Tensor;
+
+boost::pool<> p(sizeof(char));
+
+struct Chainable {
+ Chainable() { }
+ virtual ~Chainable() { }
+
+ virtual void chain() { }
+ virtual void init_dependent() { }
+ virtual void set_zero_adjoint() { }
+
+ static inline void* operator new(size_t nbytes) {
+ return p.ordered_malloc(nbytes);
+ }
+};
+
+std::vector<Chainable*> stack;
+
+class Vimpl : public Chainable {
+ public:
+ Vimpl(const Tensor& t) : val_{std::move(t)}, adj_{0} {
+ stack.push_back(this);
+ }
+
+ ~Vimpl() {};
+
+ virtual void init_dependent() { adj_ = 1; }
+ virtual void set_zero_adjoint() { adj_ = 0; }
+
+ const Tensor& val() const { return val_; };
+ Tensor& adj() { return adj_; };
+
+ protected:
+ const Tensor val_;
+ Tensor adj_;
+};
+
+typedef Vimpl* VimplPtr;
+
+static void set_zero_all_adjoints() {
+ for(auto&& v : stack)
+ v->set_zero_adjoint();
+}
+
+static void grad(Chainable* v) {
+ typedef std::vector<Chainable*>::reverse_iterator It;
+ v->init_dependent();
+ for(It it = stack.rbegin(); it != stack.rend(); ++it) {
+ (*it)->chain();
+ }
+}
+
+class Var {
+ public:
+ Var() : vimpl_{nullptr} {}
+ Var(const Tensor& t) : vimpl_{new Vimpl{t}} {}
+ Var(const VimplPtr& vimpl) : vimpl_{vimpl} {}
+
+ const Tensor& val() const {
+ return vimpl_->val();
+ }
+
+ Tensor& adj() {
+ return vimpl_->adj();
+ }
+
+ VimplPtr vimpl() const {
+ return vimpl_;
+ }
+
+ void grad() {
+ mad::grad(vimpl_);
+ }
+
+ private:
+ VimplPtr vimpl_;
+};
+
+struct OpVimpl : public Vimpl {
+ OpVimpl(const Tensor& t, VimplPtr a) : Vimpl(t), a_(a) { }
+
+ VimplPtr a_;
+};
+
+
+struct LogVimpl : public OpVimpl {
+ LogVimpl(VimplPtr a) : OpVimpl(std::log(a->val()), a) { }
+
+ void chain() {
+ a_->adj() += adj_ / a_->val();
+ }
+};
+
+inline Var log(const Var& a) {
+ return Var(VimplPtr(new LogVimpl(a.vimpl())));
+}
+
+///////////////////////////////////////////////////
+
+
+struct OpVimplVV : public Vimpl {
+ VimplPtr a_;
+ VimplPtr b_;
+
+ OpVimplVV(Tensor t, VimplPtr a, VimplPtr b)
+ : Vimpl(t), a_(a), b_(b) { }
+};
+
+struct PlusVimplVV : public OpVimplVV {
+ PlusVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() + b->val(), a, b) { }
+
+ void chain() {
+ a_->adj() += adj_;
+ b_->adj() += adj_;
+ }
+};
+
+inline Var operator+(const Var& a, const Var& b) {
+ return Var(VimplPtr(new PlusVimplVV(a.vimpl(), b.vimpl())));
+}
+
+} \ No newline at end of file
diff --git a/src/test.cpp b/src/test.cpp
new file mode 100644
index 00000000..6f29902c
--- /dev/null
+++ b/src/test.cpp
@@ -0,0 +1,23 @@
+#include <iostream>
+
+#include "mad.h"
+
+int main(int argc, char** argv) {
+
+ using namespace mad;
+ {
+ Var x0 = 1, x1 = 2, x2 = 3;
+
+ auto y = x0 + x0 + log(x2) + x1;
+
+ std::vector<Var> x = { x0, x1, x2 };
+
+
+ set_zero_all_adjoints();
+ y.grad();
+
+ std::cerr << "y = " << y.val() << std::endl;
+ for(int i = 0; i < x.size(); ++i)
+ std::cerr << "dy/dx_" << i << " = " << x[i].adj() << std::endl;
+ }
+} \ No newline at end of file