Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2016-05-09 12:55:24 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2016-05-09 12:55:24 +0300
commit4b79f3e72adc61b0d99a6b76c5dcad8210b8ceae (patch)
tree7441a308bb5ad135a8a622bff9cd477f3d13244e /src/marian.h
parent3db0f3031214a217bd5d7235f6692178998a2355 (diff)
towards lazy allocationi
Diffstat (limited to 'src/marian.h')
-rw-r--r--src/marian.h120
1 files changed, 7 insertions, 113 deletions
diff --git a/src/marian.h b/src/marian.h
index b8320d91..18510943 100644
--- a/src/marian.h
+++ b/src/marian.h
@@ -1,115 +1,9 @@
#pragma once
-#include <memory>
-#include <functional>
-#include <vector>
-#include <cmath>
-
-#include "exception.h"
-#include "cudnn_tensor.h"
-
-namespace marian {
-
-template <class DataType>
-struct Chainable : public std::enable_shared_from_this<Chainable<DataType>> {
- Chainable() { }
- virtual ~Chainable() { }
- virtual void forward() { }
- virtual void backward() { }
- virtual void init_dependent() { }
- virtual void set_zero_adjoint() { }
-
- virtual DataType val() = 0;
- virtual DataType grad() = 0;
-};
-
-typedef std::vector<Chainable<Tensor>*> ChainableStack;
-typedef std::shared_ptr<Chainable<Tensor>> ChainPtr;
-
-ChainableStack stack;
-
-class Node : public Chainable<Tensor> {
- public:
- Node(const Tensor t) : val_(t) {
- //std::cerr << "Putting node with tensor " << t.id() << " on stack" << std::endl;
- stack.push_back(this);
- }
-
- virtual ~Node() {};
-
- virtual void init_dependent() {
- if(adj_) {
- adj_.set(1);
- }
- else {
- adj_ = Tensor(val_.shape(), 1);
- }
- }
-
- virtual void set_zero_adjoint() {
- if(adj_) {
- adj_.set(0);
- }
- else {
- adj_ = Tensor(val_.shape(), 0);
- }
- }
-
- virtual Tensor val() { return val_; };
- virtual Tensor grad() { return adj_; };
-
- protected:
- Tensor val_;
- Tensor adj_;
-};
-
-class Var {
- public:
- Var() : pimpl_(nullptr) {}
- Var(const Tensor t) : pimpl_(new Node(t)) {}
- Var(const Tensor::value_type v) : pimpl_(new Node(Tensor(v))) {}
- Var(const ChainPtr chainable) : pimpl_(chainable) {}
- Var(Chainable<Tensor>* chainable) : pimpl_(chainable) {}
-
- Tensor val() {
- return pimpl_->val();
- }
-
- Tensor grad() {
- return pimpl_->grad();
- }
-
- ChainPtr pimpl() {
- return pimpl_;
- }
-
- void forward() {
- UTIL_THROW_IF2(pimpl_.get() != stack.back(),
- "Trying to call forward on non-root of computation graph");
-
- for(auto&& v : stack)
- v->forward();
- }
-
- void backward() {
- UTIL_THROW_IF2(pimpl_.get() != stack.back(),
- "Trying to call backward on non-root of computation graph");
-
- for(auto&& v : stack)
- v->set_zero_adjoint();
-
- typedef ChainableStack::reverse_iterator It;
- pimpl_->init_dependent();
- for(It it = stack.rbegin(); it != stack.rend(); ++it)
- (*it)->backward();
- }
-
- operator ChainPtr() {
- return pimpl_;
- }
-
- private:
- ChainPtr pimpl_;
-};
-
-} \ No newline at end of file
+#include "definitions.h"
+#include "graph.h"
+#include "graph_operators.h"
+#include "expressions.h"
+#include "expression_operators.h"
+//#include "tensor.h"
+//#include "tensor_operators.h"