github.com/marian-nmt/marian.git
author     Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2017-10-29 22:04:32 +0300
committer  Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2017-10-29 22:04:32 +0300
commit     5ffe895d4c6f3561aa1eed0156a06f3333a10bea (patch)
tree       7c70acafaf3ed8e309a0256de367b2286fbaca93
parent     fe4a804d6692fa3bacd636e8b753f53911c59fbe (diff)
parent     b3765f61bdbc30cb5f2a74ac7f882b8a0b9055ba (diff)
adjust rnn to work with new shape
-rw-r--r--  src/common/shape.h  69
-rw-r--r--  src/graph/expression_operators.cu  4
-rw-r--r--  src/graph/expression_operators.h  2
-rw-r--r--  src/graph/node_operators_binary.h  37
-rw-r--r--  src/graph/node_operators_unary.h  72
-rw-r--r--  src/kernels/tensor_operators.h  30
-rw-r--r--  src/rnn/cells.h  44
-rw-r--r--  src/rnn/rnn.h  22
-rw-r--r--  src/rnn/types.h  4
-rw-r--r--  src/tests/operator_tests.cpp  8
-rw-r--r--  src/tests/rnn_tests.cpp  5
11 files changed, 156 insertions(+), 141 deletions(-)
diff --git a/src/common/shape.h b/src/common/shape.h
index b0b88120..d9e4ef9f 100644
--- a/src/common/shape.h
+++ b/src/common/shape.h
@@ -58,7 +58,12 @@ struct Shape {
updateStrides();
}
- inline int& dim(int i) { return shape_[i]; }
+ inline int& dim(int i) {
+ if(i >= 0)
+ return shape_[i];
+ else
+ return shape_[size() + i];
+ }
inline const int& dim(int i) const { return const_cast<Shape&>(*this).dim(i); }
@@ -69,11 +74,17 @@ struct Shape {
inline int& back() { return shape_.back(); }
inline int stride(int i) const {
- return stride_[i];
+ if(i >= 0)
+ return stride_[i];
+ else
+ return stride_[size() + i];
}
inline int bstride(int i) const {
- return bstride_[i];
+ if(i >= 0)
+ return bstride_[i];
+ else
+ return bstride_[size() + i];
}
inline size_t size() const { return shape_.size(); }
@@ -125,6 +136,58 @@ struct Shape {
<< shape.elements() * sizeof(float) << "B)";
return strm;
}
+
+ int axis(int ax) {
+ if(ax < 0)
+ return size() + ax;
+ else
+ return ax;
+ }
+
+ static Shape broadcast(const std::vector<Shape>& shapes) {
+ int maxDims = 0;
+ for(auto& s : shapes)
+ if(s.size() > maxDims)
+ maxDims = s.size();
+
+ Shape shape;
+ shape.resize(maxDims);
+
+ for(auto& s : shapes) {
+ for(int i = 0; i < s.size(); ++i) {
+ ABORT_IF(shape[-i] != s[-i] && shape[-i] != 1 && s[-i] != 1,
+ "Shapes cannot be broadcasted");
+ shape.set(-i, std::max(shape[-i], s[-i]));
+ }
+ }
+ return shape;
+ }
+
+ template <typename T>
+ static Shape broadcast(const std::initializer_list<T>& il) {
+ return broadcast(std::vector<T>(il));
+ }
+
+ template <typename T>
+ static Shape broadcast(const std::vector<T>& nodes) {
+ int maxDims = 0;
+ for(auto& n : nodes)
+ if(n->shape().size() > maxDims)
+ maxDims = n->shape().size();
+
+ Shape shape;
+ shape.resize(maxDims);
+
+ for(auto& node : nodes) {
+ Shape shapen = node->shape();
+ for(int i = 0; i < shapen.size(); ++i) {
+ ABORT_IF(shape[-i] != shapen[-i] && shape[-i] != 1 && shapen[-i] != 1,
+ "Shapes cannot be broadcasted");
+ shape.set(-i, std::max(shape[-i], shapen[-i]));
+ }
+ }
+ return shape;
+ }
};
}
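The Shape changes above add Python-style negative indexing to dim()/stride()/bstride(), an axis() helper that normalizes negative axes, and static broadcast() overloads that compute a NumPy-style broadcast shape by aligning trailing axes (each pair of dimensions must match or be 1). A minimal standalone sketch of the same logic, with illustrative names rather than marian's actual Shape API:

// mini_shape.cpp -- illustrative sketch only, not marian's Shape.
#include <algorithm>
#include <cassert>
#include <iostream>
#include <vector>

struct MiniShape {
  std::vector<int> dims;

  int size() const { return (int)dims.size(); }

  // Negative indices count from the back, as in the patched dim()/stride().
  int& dim(int i) { return dims[i >= 0 ? i : size() + i]; }
  int dim(int i) const { return dims[i >= 0 ? i : size() + i]; }

  // Normalize a possibly negative axis, as in the new Shape::axis().
  int axis(int ax) const { return ax < 0 ? size() + ax : ax; }

  // NumPy-style broadcasting: align shapes on their trailing axes;
  // each pair of dimensions must be equal or one of them must be 1.
  static MiniShape broadcast(const std::vector<MiniShape>& shapes) {
    int maxDims = 0;
    for(auto& s : shapes)
      maxDims = std::max(maxDims, s.size());
    MiniShape out{std::vector<int>(maxDims, 1)};
    for(auto& s : shapes)
      for(int i = 1; i <= s.size(); ++i) {
        assert(out.dim(-i) == s.dim(-i) || out.dim(-i) == 1 || s.dim(-i) == 1);
        out.dim(-i) = std::max(out.dim(-i), s.dim(-i));
      }
    return out;
  }
};

int main() {
  MiniShape a{{4, 2, 3}}, b{{2, 1}};
  MiniShape c = MiniShape::broadcast({a, b});  // -> {4, 2, 3}
  std::cout << c.dim(0) << "x" << c.dim(-2) << "x" << c.dim(-1) << "\n";  // 4x2x3
  assert(c.axis(-1) == 2);                     // last axis, normalized
}

The second broadcast() overload in the patch does the same over node->shape() for a list of expression nodes.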
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index 10861c14..90a2047c 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -175,8 +175,8 @@ Expr transpose(Expr a, Shape permute) {
return Expression<TransposeNodeOp>(a, permute);
}
-Expr step(Expr a, size_t step) {
- return Expression<TimestepNodeOp>(a, step);
+Expr step(Expr a, int step, int axis) {
+ return Expression<StepNodeOp>(a, step, axis);
}
Expr cross_entropy(Expr a, Expr b) {
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index c99af41d..2eb150b9 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -95,7 +95,7 @@ Expr scalar_product(Expr a, Expr b, keywords::axis_k ax = 0);
Expr weighted_average(Expr in, Expr weights, keywords::axis_k ax = 0);
-Expr step(Expr a, size_t step);
+Expr step(Expr a, int step, int axis);
Expr sqrt(Expr a, float eps = 0.f);
Expr square(Expr a);
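step() previously always took a fixed time slice on hard-coded axes; with StepNodeOp it now accepts a signed step index and an explicit, possibly negative, axis (and, as the unary-node diff below shows, also collapses all leading dimensions up to that axis to 1). Conceptually it selects one index along the given axis of a row-major tensor; a standalone CPU sketch of such a slice, illustrative only and not marian's kernel:

#include <cassert>
#include <vector>

// Slice index `step` along `axis` of a row-major tensor with dimensions `dims`.
std::vector<float> sliceAxis(const std::vector<float>& data,
                             const std::vector<int>& dims, int step, int axis) {
  if(axis < 0)
    axis += (int)dims.size();                  // negative axes, as in step(a, step, axis)
  int outer = 1, inner = 1;
  for(int i = 0; i < axis; ++i)
    outer *= dims[i];
  for(int i = axis + 1; i < (int)dims.size(); ++i)
    inner *= dims[i];
  std::vector<float> out((size_t)outer * inner);
  for(int o = 0; o < outer; ++o)
    for(int j = 0; j < inner; ++j)
      out[o * inner + j] = data[(o * dims[axis] + step) * inner + j];
  return out;
}

int main() {
  std::vector<float> m{1, 2, 3, 4, 5, 6};      // 2x3 matrix {1 2 3; 4 5 6}
  auto col = sliceAxis(m, {2, 3}, /*step=*/1, /*axis=*/-1);
  assert(col.size() == 2 && col[0] == 2 && col[1] == 5);
}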
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 2445a6be..98d31640 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -397,18 +397,10 @@ struct ScalarProductNodeOp : public NaryNodeOp {
Shape newShape(Expr a, Expr b, Args... args) {
int ax = keywords::Get(keywords::axis, -1, args...);
- Shape full = a->shape();
- for(int i = 0; i < b->shape().size(); ++i)
- full.set(i, std::max(full[i], b->shape()[i]));
-
- if(ax != -1) {
- full.set(ax, 1);
- } else {
- full.set(0, 1);
- full.set(1, 1);
- full.set(2, 1);
- full.set(3, 1);
- }
+ Shape full = Shape::broadcast({a, b});
+ ax = full.axis(ax);
+
+ full.set(ax, 1);
return full;
}
@@ -432,14 +424,7 @@ struct ElementBinaryNodeOp : public NaryNodeOp {
: NaryNodeOp({a, b}, keywords::shape = newShape(a, b), args...) {}
Shape newShape(Expr a, Expr b) {
- Shape shape1 = a->shape();
- Shape shape2 = b->shape();
- for(int i = 0; i < shape1.size(); ++i) {
- ABORT_IF(shape1[i] != shape2[i] && shape1[i] != 1 && shape2[i] != 1,
- "Shapes cannot be broadcasted");
- shape1.set(i, std::max(shape1[i], shape2[i]));
- }
- return shape1;
+ return Shape::broadcast({a, b});
}
const std::string color() { return "yellow"; }
@@ -572,15 +557,17 @@ struct ConcatenateNodeOp : public NaryNodeOp {
: NaryNodeOp(nodes,
keywords::shape
= newShape(nodes, keywords::Get(keywords::axis, 0, args...)),
- args...),
- ax_(keywords::Get(keywords::axis, 0, args...)) {}
+ args...) {}
Shape newShape(const std::vector<Expr>& nodes, int ax) {
Shape shape = nodes.back()->shape();
- shape.set(ax, 0);
+ ax_ = shape.axis(ax);
+
+ int sum = 0;
for(auto child : nodes)
- shape.set(ax, shape[ax] + child->shape()[ax]);
- // std::cerr << ax << " : " << shape[0] << " " << shape[1] << std::endl;
+ sum += child->shape()[ax_];
+ shape.set(ax_, sum);
+
return shape;
}
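With Shape::broadcast available, ScalarProductNodeOp and ElementBinaryNodeOp no longer hand-roll the per-axis max loop, and ConcatenateNodeOp now normalizes the concat axis and sums the child sizes along it. A small standalone sketch of that output-shape rule, illustrative rather than marian's node code:

#include <cassert>
#include <vector>

// Output dimensions of a concatenation: the common shape with the sizes
// summed along the (possibly negative) concat axis.
std::vector<int> concatShape(const std::vector<std::vector<int>>& shapes, int ax) {
  std::vector<int> out = shapes.back();
  if(ax < 0)
    ax += (int)out.size();          // normalize, as Shape::axis() does
  int sum = 0;
  for(auto& s : shapes)
    sum += s[ax];
  out[ax] = sum;
  return out;
}

int main() {
  // Concatenating {1,2,3} and {1,2,5} along the last axis gives {1,2,8}.
  auto out = concatShape({{1, 2, 3}, {1, 2, 5}}, -1);
  assert((out == std::vector<int>{1, 2, 8}));
}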
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index a3f60366..08ecde46 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -141,17 +141,7 @@ struct TanhNodeOp : public NaryNodeOp {
: NaryNodeOp(nodes, keywords::shape = newShape(nodes)) {}
Shape newShape(const std::vector<Expr>& nodes) {
- Shape shape = nodes[0]->shape();
-
- for(int n = 1; n < nodes.size(); ++n) {
- Shape shapen = nodes[n]->shape();
- for(int i = 0; i < shapen.size(); ++i) {
- ABORT_IF(shape[i] != shapen[i] && shape[i] != 1 && shapen[i] != 1,
- "Shapes cannot be broadcasted");
- shape.set(i, std::max(shape[i], shapen[i]));
- }
- }
- return shape;
+ return Shape::broadcast(nodes);
}
NodeOps forwardOps() {
@@ -325,8 +315,7 @@ struct SumNodeOp : public UnaryNodeOp {
template <typename... Args>
SumNodeOp(Expr a, Args... args)
- : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...),
- ax_(keywords::Get(keywords::axis, -1, args...)) {}
+ : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {}
NodeOps forwardOps() { return {NodeOp(Reduce(_1, val_, child(0)->val()))}; }
@@ -334,15 +323,10 @@ struct SumNodeOp : public UnaryNodeOp {
template <class... Args>
Shape newShape(Expr a, Args... args) {
- int ax = keywords::Get(keywords::axis, -1, args...);
Shape shape = a->shape();
- if(ax != -1) {
- shape.set(ax, 1);
- } else {
- for(int i = 0; i < shape.size(); ++i) {
- shape.set(i, 1);
- }
- }
+ ax_ = shape.axis(keywords::Get(keywords::axis, -1, args...));
+
+ shape.set(ax_, 1);
return shape;
}
@@ -375,8 +359,7 @@ struct MeanNodeOp : public UnaryNodeOp {
template <typename... Args>
MeanNodeOp(Expr a, Args... args)
- : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...),
- ax_(keywords::Get(keywords::axis, -1, args...)) {}
+ : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {}
NodeOps forwardOps() {
int left = child(0)->shape().elements() / val_->shape().elements();
@@ -394,15 +377,9 @@ struct MeanNodeOp : public UnaryNodeOp {
template <class... Args>
Shape newShape(Expr a, Args... args) {
- int ax = keywords::Get(keywords::axis, -1, args...);
Shape shape = a->shape();
- if(ax != -1) {
- shape.set(ax, 1);
- } else {
- for(int i = 0; i < shape.size(); ++i) {
- shape.set(i, 1);
- }
- }
+ ax_ = shape.axis(keywords::Get(keywords::axis, -1, args...));
+ shape.set(ax_, 1);
return shape;
}
@@ -637,8 +614,7 @@ struct ColsNodeOp : public UnaryNodeOp {
struct SelectNodeOp : public UnaryNodeOp {
SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces)
: UnaryNodeOp(a, keywords::shape = newShape(a, axis, indeces)),
- indeces_(indeces),
- axis_(axis) {}
+ indeces_(indeces) {}
NodeOps forwardOps() {
return {NodeOp(
@@ -652,7 +628,8 @@ struct SelectNodeOp : public UnaryNodeOp {
Shape newShape(Expr a, int axis, const std::vector<size_t>& indeces) {
Shape shape = a->shape();
- shape.set(axis, indeces.size());
+ axis_ = shape.axis(axis);
+ shape.set(axis_, indeces.size());
return shape;
}
@@ -707,8 +684,8 @@ struct TransposeNodeOp : public UnaryNodeOp {
Shape newShape(Expr a, Shape permute) {
Shape shape = a->shape();
- UTIL_THROW_IF2(shape.size() != permute.size(),
- "Shape and transpose axis have different number of dimensions");
+ ABORT_IF(shape.size() != permute.size(),
+ "Shape and transpose axis have different number of dimensions");
for(int i = 0; i < shape.size(); ++i)
shape.set(i, a->shape()[permute[i]]);
@@ -806,23 +783,27 @@ public:
}
};
-class TimestepNodeOp : public UnaryNodeOp {
+class StepNodeOp : public UnaryNodeOp {
private:
Expr stepNode_;
- size_t step_;
+ int step_;
+ int axis_;
public:
- TimestepNodeOp(Expr a, size_t step)
- : UnaryNodeOp(a, keywords::shape = newShape(a)),
+ StepNodeOp(Expr a, int step, int axis)
+ : UnaryNodeOp(a, keywords::shape = newShape(a, axis)),
stepNode_(a),
step_(step) {
Node::destroy_ = false;
}
- Shape newShape(Expr a) {
+ Shape newShape(Expr a, int axis) {
Shape outShape = a->shape();
- outShape.set(2, 1);
- outShape.set(3, 1);
+
+ axis_ = outShape.axis(axis);
+ for(int i = 0; i <= axis_; ++i)
+ outShape.set(i, 1);
+
return outShape;
}
@@ -862,6 +843,7 @@ public:
if(!hash_) {
hash_ = NaryNodeOp::hash();
boost::hash_combine(hash_, step_);
+ boost::hash_combine(hash_, axis_);
}
return hash_;
}
@@ -869,11 +851,13 @@ public:
virtual bool equal(Expr node) {
if(!NaryNodeOp::equal(node))
return false;
- Ptr<TimestepNodeOp> cnode = std::dynamic_pointer_cast<TimestepNodeOp>(node);
+ Ptr<StepNodeOp> cnode = std::dynamic_pointer_cast<StepNodeOp>(node);
if(!cnode)
return false;
if(step_ != cnode->step_)
return false;
+ if(axis_ != cnode->axis_)
+ return false;
return true;
}
};
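SumNodeOp and MeanNodeOp now resolve the reduction axis inside newShape() (normalizing negative values) and set only that one dimension to 1, instead of treating ax == -1 as "reduce everything"; StepNodeOp additionally folds its axis into hash() and equal(), so slices along different axes are not merged as identical graph nodes. A sketch of the shared reduction-shape rule, illustrative only:

#include <cassert>
#include <vector>

// Shape after reducing along a possibly negative axis.
std::vector<int> reduceShape(std::vector<int> dims, int ax) {
  if(ax < 0)
    ax += (int)dims.size();
  dims[ax] = 1;
  return dims;
}

int main() {
  assert((reduceShape({4, 2, 3}, -1) == std::vector<int>{4, 2, 1}));
  assert((reduceShape({4, 2, 3},  0) == std::vector<int>{1, 2, 3}));
}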
diff --git a/src/kernels/tensor_operators.h b/src/kernels/tensor_operators.h
index e178eeac..707d5b95 100644
--- a/src/kernels/tensor_operators.h
+++ b/src/kernels/tensor_operators.h
@@ -185,9 +185,7 @@ void Add(Functor functor, Tensor out, Tensor in, float scale = 1.0) {
UTIL_THROW_IF2(out->shape().size() != in->shape().size(),
"Number of dimensions does not match");
- auto full = out->shape();
- for(int i = 0; i < in->shape().size(); ++i)
- full.set(i, std::max(full[i], in->shape()[i]));
+ auto full = Shape::broadcast({out, in});
int length = out->shape().elements();
@@ -396,16 +394,7 @@ void Add(Functor functor,
float scale = 1.0) {
cudaSetDevice(out->getDevice());
- UTIL_THROW_IF2(out->shape().size() != in1->shape().size(),
- "Number of dimensions does not match");
- UTIL_THROW_IF2(out->shape().size() != in2->shape().size(),
- "Number of dimensions does not match");
-
- auto full = out->shape();
- for(int i = 0; i < in1->shape().size(); ++i)
- full.set(i, std::max(full[i], in1->shape()[i]));
- for(int i = 0; i < in2->shape().size(); ++i)
- full.set(i, std::max(full[i], in2->shape()[i]));
+ Shape full = Shape::broadcast({out, in1, in2});
int length = out->shape().elements();
@@ -628,20 +617,7 @@ template <class Functor>
void Add(Functor functor, Tensor out, Tensor in1, Tensor in2, Tensor in3) {
cudaSetDevice(out->getDevice());
- UTIL_THROW_IF2(out->shape().size() != in1->shape().size(),
- "Number of dimensions does not match");
- UTIL_THROW_IF2(out->shape().size() != in2->shape().size(),
- "Number of dimensions does not match");
- UTIL_THROW_IF2(out->shape().size() != in3->shape().size(),
- "Number of dimensions does not match");
-
- auto full = out->shape();
- for(int i = 0; i < in1->shape().size(); ++i)
- full.set(i, std::max(full[i], in1->shape()[i]));
- for(int i = 0; i < in2->shape().size(); ++i)
- full.set(i, std::max(full[i], in2->shape()[i]));
- for(int i = 0; i < in3->shape().size(); ++i)
- full.set(i, std::max(full[i], in3->shape()[i]));
+ Shape full = Shape::broadcast({out, in1, in2, in3});
int length = out->shape().elements();
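The Add() kernel wrappers above now derive their working shape from Shape::broadcast({out, in...}) instead of same-rank checks plus per-axis max loops, so lower-rank inputs broadcast against the output. A CPU-side sketch of the most common pattern (a bias broadcast over rows); this is illustrative and not the actual CUDA kernel:

#include <cassert>
#include <vector>

// Add a {cols} bias to a {rows, cols} matrix: the bias is read as if it had
// a zero stride along the row axis, which is what the broadcast shape encodes.
void addBroadcast(std::vector<float>& out, const std::vector<float>& bias,
                  int rows, int cols) {
  for(int r = 0; r < rows; ++r)
    for(int c = 0; c < cols; ++c)
      out[r * cols + c] += bias[c];
}

int main() {
  std::vector<float> m{1, 2, 3, 4, 5, 6};  // 2x3
  addBroadcast(m, {10, 20, 30}, 2, 3);
  assert(m[0] == 11 && m[5] == 36);
}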
diff --git a/src/rnn/cells.h b/src/rnn/cells.h
index 057a7c3c..cea020fa 100644
--- a/src/rnn/cells.h
+++ b/src/rnn/cells.h
@@ -75,7 +75,7 @@ public:
if(inputs.size() == 0)
return {};
else if(inputs.size() > 1)
- input = concatenate(inputs, keywords::axis = 1);
+ input = concatenate(inputs, keywords::axis = -1);
else
input = inputs.front();
@@ -150,7 +150,7 @@ public:
auto Ux = graph->param(prefix + "_Ux",
{dimState, dimState},
keywords::init = inits::glorot_uniform);
- U_ = concatenate({U, Ux}, keywords::axis = 1);
+ U_ = concatenate({U, Ux}, keywords::axis = -1);
if(dimInput > 0) {
auto W = graph->param(prefix + "_W",
@@ -159,14 +159,14 @@ public:
auto Wx = graph->param(prefix + "_Wx",
{dimInput, dimState},
keywords::init = inits::glorot_uniform);
- W_ = concatenate({W, Wx}, keywords::axis = 1);
+ W_ = concatenate({W, Wx}, keywords::axis = -1);
}
auto b = graph->param(
prefix + "_b", {1, 2 * dimState}, keywords::init = inits::zeros);
auto bx = graph->param(
prefix + "_bx", {1, dimState}, keywords::init = inits::zeros);
- b_ = concatenate({b, bx}, keywords::axis = 1);
+ b_ = concatenate({b, bx}, keywords::axis = -1);
// @TODO use this and adjust Amun model type saving and loading
// U_ = graph->param(prefix + "_U", {dimState, 3 * dimState},
@@ -204,7 +204,7 @@ public:
if(inputs.size() == 0)
return {};
else if(inputs.size() > 1)
- input = concatenate(inputs, keywords::axis = 1);
+ input = concatenate(inputs, keywords::axis = -1);
else
input = inputs[0];
@@ -311,7 +311,7 @@ public:
U_ = U;
Ux_ = Ux;
} else {
- UUx_ = concatenate({U, Ux}, keywords::axis = 1);
+ UUx_ = concatenate({U, Ux}, keywords::axis = -1);
}
if(dimInput > 0) {
@@ -325,7 +325,7 @@ public:
W_ = W;
Wx_ = Wx;
} else {
- WWx_ = concatenate({W, Wx}, keywords::axis = 1);
+ WWx_ = concatenate({W, Wx}, keywords::axis = -1);
}
}
@@ -342,13 +342,13 @@ public:
if(encoder_ && transition_) {
auto b0
= graph->constant({1, 2 * dimState}, keywords::init = inits::zeros);
- bbx_ = concatenate({b0, bx}, keywords::axis = 1);
+ bbx_ = concatenate({b0, bx}, keywords::axis = -1);
} else {
bbx_
= graph->constant({1, 3 * dimState}, keywords::init = inits::zeros);
}
} else {
- bbx_ = concatenate({b, bx}, keywords::axis = 1);
+ bbx_ = concatenate({b, bx}, keywords::axis = -1);
}
if(dropout_ > 0.0f) {
@@ -396,7 +396,7 @@ public:
if(inputs.size() == 0)
return {};
else if(inputs.size() > 1)
- input = concatenate(inputs, keywords::axis = 1);
+ input = concatenate(inputs, keywords::axis = -1);
else
input = inputs[0];
@@ -418,7 +418,7 @@ public:
W = layer_norm(W, W_lns_, W_lnb_, NEMATUS_LN_EPS);
Wx = layer_norm(Wx, Wx_lns_, Wx_lnb_, NEMATUS_LN_EPS);
- xW = concatenate({W, Wx}, keywords::axis = 1);
+ xW = concatenate({W, Wx}, keywords::axis = -1);
} else {
xW = dot(input, WWx_);
}
@@ -462,7 +462,7 @@ public:
Ux = layer_norm(Ux, Ux_lns_, Ux_lnb_, NEMATUS_LN_EPS);
}
- sU = concatenate({U, Ux}, keywords::axis = 1);
+ sU = concatenate({U, Ux}, keywords::axis = -1);
} else {
sU = dot(stateDropped, UUx_);
}
@@ -553,8 +553,9 @@ public:
Expr input;
if(inputs.size() == 0)
return {};
- else if(inputs.size() > 1)
- input = concatenate(inputs, keywords::axis = 1);
+ else if(inputs.size() > 1) {
+ input = concatenate(inputs, keywords::axis = -1);
+ }
else
input = inputs.front();
@@ -646,8 +647,9 @@ public:
ABORT_IF(inputs.empty(), "Multiplicative LSTM expects input");
Expr input;
- if(inputs.size() > 1)
- input = concatenate(inputs, keywords::axis = 1);
+ if(inputs.size() > 1) {
+ input = concatenate(inputs, keywords::axis = -1);
+ }
else
input = inputs.front();
@@ -742,7 +744,7 @@ public:
Expr input;
if(inputs.size() > 1)
- input = concatenate(inputs, keywords::axis = 1);
+ input = concatenate(inputs, keywords::axis = -1);
else
input = inputs.front();
@@ -826,9 +828,9 @@ public:
auto bo = graph->param(
prefix + "_bo", {1, dimState}, keywords::init = inits::zeros);
- U_ = concatenate({Uf, Ui, Uc, Uo}, keywords::axis = 1);
- W_ = concatenate({Wf, Wi, Wc, Wo}, keywords::axis = 1);
- b_ = concatenate({bf, bi, bc, bo}, keywords::axis = 1);
+ U_ = concatenate({Uf, Ui, Uc, Uo}, keywords::axis = -1);
+ W_ = concatenate({Wf, Wi, Wc, Wo}, keywords::axis = -1);
+ b_ = concatenate({bf, bi, bc, bo}, keywords::axis = -1);
}
State apply(std::vector<Expr> inputs, State state, Expr mask = nullptr) {
@@ -840,7 +842,7 @@ public:
Expr input;
if(inputs.size() > 1)
- input = concatenate(inputs, keywords::axis = 1);
+ input = concatenate(inputs, keywords::axis = -1);
else
input = inputs.front();
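Every weight, bias, and input concatenation in the cells switches from the hard-coded axis 1 to axis -1, i.e. the last (feature/state) dimension, so the same cell code works no matter how many leading time/batch axes the tensors carry. A quick illustration of why -1 is the safer spelling; the helper is illustrative, not marian code:

#include <cassert>

int normalizeAxis(int ax, int rank) { return ax < 0 ? rank + ax : ax; }

int main() {
  // For rank-2 weight matrices the old axis 1 and the new axis -1 coincide,
  // but for rank-4 activations the last axis is 3, not 1.
  assert(normalizeAxis(-1, 2) == 1);
  assert(normalizeAxis(-1, 4) == 3);
}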
diff --git a/src/rnn/rnn.h b/src/rnn/rnn.h
index 67947db6..62665800 100644
--- a/src/rnn/rnn.h
+++ b/src/rnn/rnn.h
@@ -68,7 +68,7 @@ private:
auto xWs = cell_->applyInput({input});
- size_t timeSteps = input->shape()[2];
+ size_t timeSteps = input->shape()[-3];
States outputs;
for(size_t i = 0; i < timeSteps; ++i) {
@@ -78,12 +78,13 @@ private:
j = timeSteps - i - 1;
std::vector<Expr> steps(xWs.size());
- std::transform(xWs.begin(), xWs.end(), steps.begin(), [j](Expr e) {
- return step(e, j);
- });
+ std::transform(xWs.begin(),
+ xWs.end(),
+ steps.begin(),
+ [j](Expr e) { return step(e, j, -3); });
if(mask)
- state = cell_->applyState(steps, state, step(mask, j));
+ state = cell_->applyState(steps, state, step(mask, j, -3));
else
state = cell_->applyState(steps, state);
@@ -100,10 +101,11 @@ private:
States apply(const Expr input, const Expr mask = nullptr) {
auto graph = input->graph();
- int dimBatch = input->shape()[0];
+
+ int dimBatch = input->shape()[-2];
int dimState = cell_->getOptions()->get<int>("dimState");
- auto output = graph->zeros(keywords::shape = {dimBatch, dimState});
+ auto output = graph->zeros(keywords::shape = {1, dimBatch, dimState});
Expr cell = output;
State startState{output, cell};
@@ -171,7 +173,7 @@ public:
auto lazyInputs = cell->getLazyInputs(shared_from_this());
if(!lazyInputs.empty()) {
lazyInputs.push_back(layerInput);
- lazyInput = concatenate(lazyInputs, keywords::axis = 1);
+ lazyInput = concatenate(lazyInputs, keywords::axis = -1);
}
auto layerOutput = rnns_[i]->transduce(lazyInput, mask);
@@ -197,7 +199,7 @@ public:
auto lazyInputs = cell->getLazyInputs(shared_from_this());
if(!lazyInputs.empty()) {
lazyInputs.push_back(layerInput);
- lazyInput = concatenate(lazyInputs, keywords::axis = 1);
+ lazyInput = concatenate(lazyInputs, keywords::axis = -1);
} else {
lazyInput = layerInput;
}
@@ -227,7 +229,7 @@ public:
auto lazyInputs = cell->getLazyInputs(shared_from_this());
if(!lazyInputs.empty()) {
lazyInputs.push_back(layerInput);
- lazyInput = concatenate(lazyInputs, keywords::axis = 1);
+ lazyInput = concatenate(lazyInputs, keywords::axis = -1);
}
auto layerOutput = rnns_[i]->transduce(lazyInput, States({state}), mask);
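rnn.h now addresses dimensions from the back: time is shape()[-3], batch is shape()[-2], the start state gets an explicit leading time axis ({1, dimBatch, dimState}), and per-timestep slicing calls step(e, j, -3). A small sketch of that indexing convention, illustrative only:

#include <cassert>
#include <vector>

// Read a dimension counted from the back (ax < 0), as the new shape()[-k] does.
int dimFromBack(const std::vector<int>& dims, int ax) {
  return dims[(int)dims.size() + ax];
}

int main() {
  std::vector<int> shape{8, 64, 512};      // {time, batch, dim}
  assert(dimFromBack(shape, -3) == 8);     // timeSteps, as in transduce()
  assert(dimFromBack(shape, -2) == 64);    // dimBatch, as in apply()
}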
diff --git a/src/rnn/types.h b/src/rnn/types.h
index 9e288d5a..58d4aff2 100644
--- a/src/rnn/types.h
+++ b/src/rnn/types.h
@@ -48,7 +48,7 @@ public:
for(auto s : states_)
outputs.push_back(s.output);
if(outputs.size() > 1)
- return concatenate(outputs, keywords::axis = 2);
+ return concatenate(outputs, keywords::axis = -3);
else
return outputs[0];
}
@@ -172,7 +172,7 @@ public:
outputs.push_back(input->apply(state));
if(outputs.size() > 1)
- return concatenate(outputs, keywords::axis = 1);
+ return concatenate(outputs, keywords::axis = -1);
else
return outputs[0];
}
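In types.h, stacking per-step state outputs now concatenates along axis -3 (the time axis under the new layout), while joining multiple inputs still concatenates along -1 (features). Illustrative shape arithmetic for equal-shaped operands, not marian code:

#include <cassert>
#include <vector>

// Shape of n equal-shaped tensors concatenated along a possibly negative axis.
std::vector<int> concatAlong(std::vector<int> dims, int ax, int n) {
  int a = ax < 0 ? (int)dims.size() + ax : ax;
  dims[a] *= n;
  return dims;
}

int main() {
  // Eight single-step outputs {1, batch, dim} stacked over time vs.
  // two feature vectors joined along the last axis.
  assert((concatAlong({1, 64, 512}, -3, 8) == std::vector<int>{8, 64, 512}));
  assert((concatAlong({1, 64, 512}, -1, 2) == std::vector<int>{1, 64, 1024}));
}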
diff --git a/src/tests/operator_tests.cpp b/src/tests/operator_tests.cpp
index 4e7dcdd0..ffec269b 100644
--- a/src/tests/operator_tests.cpp
+++ b/src/tests/operator_tests.cpp
@@ -90,7 +90,7 @@ TEST_CASE("Expression graph supports basic math operations", "[operator]") {
std::vector<float> vDiv({2.0f, -1.33333f, 6.0f, -2.66667f});
auto a = graph->constant({2, 2, 1}, keywords::init = inits::from_vector(vA));
- auto b = graph->constant({1, 2, 1}, keywords::init = inits::from_vector(vB));
+ auto b = graph->constant({2, 1}, keywords::init = inits::from_vector(vB));
auto add = a + b;
auto minus = b - a;
@@ -183,7 +183,7 @@ TEST_CASE("Expression graph supports basic math operations", "[operator]") {
auto sp = scalar_product(s2, s2, keywords::axis=0);
- auto wa = weighted_average(a, s1, keywords::axis=1);
+ auto wa = weighted_average(a, s1, keywords::axis=-1);
graph->forward();
@@ -237,8 +237,8 @@ TEST_CASE("Expression graph supports basic math operations", "[operator]") {
auto in4 = graph->constant({1, 2, 2, 3}, keywords::init=inits::from_value(4));
auto c1out1 = concatenate({in1, in2, in3, in4}, keywords::axis=2);
- auto c1out2 = concatenate({in1, in2, in3, in4}, keywords::axis=3);
- auto c1out3 = concatenate({in1, in2, in3, in4}, keywords::axis=1);
+ auto c1out2 = concatenate({in1, in2, in3, in4}, keywords::axis=-1);
+ auto c1out3 = concatenate({in1, in2, in3, in4}, keywords::axis=-3);
auto c1out4 = concatenate({in1, in2, in3, in4}, keywords::axis=0);
graph->forward();
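The updated operator tests exercise both changes: the rank-2 {2, 1} constant now broadcasts against the rank-3 {2, 2, 1} operand via trailing-axis alignment, and the concatenation test uses negative axes (-1 and -3) in place of the old 3 and 1. Expected shapes for the four {1, 2, 2, 3} inputs, written as a standalone check rather than Catch2 test code:

#include <cassert>
#include <vector>

// Four equal-shaped inputs concatenated along one axis.
std::vector<int> concat4(std::vector<int> dims, int ax) {
  int a = ax < 0 ? (int)dims.size() + ax : ax;
  dims[a] *= 4;
  return dims;
}

int main() {
  std::vector<int> in{1, 2, 2, 3};
  assert((concat4(in,  2) == std::vector<int>{1, 2, 8, 3}));   // c1out1
  assert((concat4(in, -1) == std::vector<int>{1, 2, 2, 12}));  // c1out2, old axis 3
  assert((concat4(in, -3) == std::vector<int>{1, 8, 2, 3}));   // c1out3, old axis 1
  assert((concat4(in,  0) == std::vector<int>{4, 2, 2, 3}));   // c1out4
}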
diff --git a/src/tests/rnn_tests.cpp b/src/tests/rnn_tests.cpp
index 5e763f29..14bd3805 100644
--- a/src/tests/rnn_tests.cpp
+++ b/src/tests/rnn_tests.cpp
@@ -86,7 +86,7 @@ TEST_CASE("Model components, RNN etc.", "[model]") {
bool layerNorm=false,
bool skip=false) {
- int dimEmb = input->shape()[1];
+ int dimEmb = input->shape()[-1];
int first, second;
if(type == "bidirectional" || type == "alternating") {
@@ -154,7 +154,7 @@ TEST_CASE("Model components, RNN etc.", "[model]") {
auto context = concatenate({rnnFw->transduce(input, mask),
rnnBw->transduce(input, mask)},
- axis = 1);
+ axis = input->shape().size() - 1);
if(second > 0) {
// add more layers (unidirectional) by transducing the output of the
@@ -227,6 +227,7 @@ TEST_CASE("Model components, RNN etc.", "[model]") {
});
contextSum1->val()->get(values);
+
CHECK( std::equal(values.begin(), values.end(),
vContextSum1.begin(), floatApprox) );