diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-06-12 18:58:11 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-06-12 18:58:11 +0300 |
commit | b1586590d4fb6bcb113794529adc145dede38ac7 (patch) | |
tree | 70075096938af8cd11dccbe297427757b33f0923 /src/graph/node_operators_unary.h | |
parent | b6f34cd6612f213721ef089c8d9b82aa9ab1b746 (diff) |
integrate memory allocator
Diffstat (limited to 'src/graph/node_operators_unary.h')
-rw-r--r-- | src/graph/node_operators_unary.h | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 81c58b7e..95227269 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -536,14 +536,14 @@ public: Tensor& val() { auto childVal = reshapee_->val(); val_.reset( - new TensorBase(childVal->data(), shape(), childVal->getDevice())); + new TensorBase(childVal->memory(), shape(), childVal->getDevice())); return val_; }; Tensor& grad() { auto childGrad = reshapee_->grad(); adj_.reset( - new TensorBase(childGrad->data(), shape(), childGrad->getDevice())); + new TensorBase(childGrad->memory(), shape(), childGrad->getDevice())); return adj_; }; @@ -592,17 +592,19 @@ public: Tensor& val() { auto childVal = stepNode_->val(); - size_t offset = step_ * shape().elements(); - val_.reset(new TensorBase( - childVal->data() + offset, shape(), childVal->getDevice())); + size_t offset = step_ * shape().elements() * sizeof(float); + auto mem = New<MemoryPiece>(childVal->memory()->data() + offset, + childVal->memory()->size()); + val_.reset(new TensorBase(mem, shape(), childVal->getDevice())); return val_; }; Tensor& grad() { auto childGrad = stepNode_->grad(); - size_t offset = step_ * shape().elements(); - adj_.reset(new TensorBase( - childGrad->data() + offset, shape(), childGrad->getDevice())); + size_t offset = step_ * shape().elements() * sizeof(float); + auto mem = New<MemoryPiece>(childGrad->memory()->data() + offset, + childGrad->memory()->size()); + adj_.reset(new TensorBase(mem, shape(), childGrad->getDevice())); return adj_; }; |