diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2016-05-09 12:55:24 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2016-05-09 12:55:24 +0300 |
commit | 4b79f3e72adc61b0d99a6b76c5dcad8210b8ceae (patch) | |
tree | 7441a308bb5ad135a8a622bff9cd477f3d13244e /src/marian.h | |
parent | 3db0f3031214a217bd5d7235f6692178998a2355 (diff) |
towards lazy allocationi
Diffstat (limited to 'src/marian.h')
-rw-r--r-- | src/marian.h | 120 |
1 files changed, 7 insertions, 113 deletions
diff --git a/src/marian.h b/src/marian.h index b8320d91..18510943 100644 --- a/src/marian.h +++ b/src/marian.h @@ -1,115 +1,9 @@ #pragma once -#include <memory> -#include <functional> -#include <vector> -#include <cmath> - -#include "exception.h" -#include "cudnn_tensor.h" - -namespace marian { - -template <class DataType> -struct Chainable : public std::enable_shared_from_this<Chainable<DataType>> { - Chainable() { } - virtual ~Chainable() { } - virtual void forward() { } - virtual void backward() { } - virtual void init_dependent() { } - virtual void set_zero_adjoint() { } - - virtual DataType val() = 0; - virtual DataType grad() = 0; -}; - -typedef std::vector<Chainable<Tensor>*> ChainableStack; -typedef std::shared_ptr<Chainable<Tensor>> ChainPtr; - -ChainableStack stack; - -class Node : public Chainable<Tensor> { - public: - Node(const Tensor t) : val_(t) { - //std::cerr << "Putting node with tensor " << t.id() << " on stack" << std::endl; - stack.push_back(this); - } - - virtual ~Node() {}; - - virtual void init_dependent() { - if(adj_) { - adj_.set(1); - } - else { - adj_ = Tensor(val_.shape(), 1); - } - } - - virtual void set_zero_adjoint() { - if(adj_) { - adj_.set(0); - } - else { - adj_ = Tensor(val_.shape(), 0); - } - } - - virtual Tensor val() { return val_; }; - virtual Tensor grad() { return adj_; }; - - protected: - Tensor val_; - Tensor adj_; -}; - -class Var { - public: - Var() : pimpl_(nullptr) {} - Var(const Tensor t) : pimpl_(new Node(t)) {} - Var(const Tensor::value_type v) : pimpl_(new Node(Tensor(v))) {} - Var(const ChainPtr chainable) : pimpl_(chainable) {} - Var(Chainable<Tensor>* chainable) : pimpl_(chainable) {} - - Tensor val() { - return pimpl_->val(); - } - - Tensor grad() { - return pimpl_->grad(); - } - - ChainPtr pimpl() { - return pimpl_; - } - - void forward() { - UTIL_THROW_IF2(pimpl_.get() != stack.back(), - "Trying to call forward on non-root of computation graph"); - - for(auto&& v : stack) - v->forward(); - } - - void backward() { - UTIL_THROW_IF2(pimpl_.get() != stack.back(), - "Trying to call backward on non-root of computation graph"); - - for(auto&& v : stack) - v->set_zero_adjoint(); - - typedef ChainableStack::reverse_iterator It; - pimpl_->init_dependent(); - for(It it = stack.rbegin(); it != stack.rend(); ++it) - (*it)->backward(); - } - - operator ChainPtr() { - return pimpl_; - } - - private: - ChainPtr pimpl_; -}; - -}
\ No newline at end of file +#include "definitions.h" +#include "graph.h" +#include "graph_operators.h" +#include "expressions.h" +#include "expression_operators.h" +//#include "tensor.h" +//#include "tensor_operators.h" |