-rw-r--r--  src/common/shape.h                  | 69
-rw-r--r--  src/graph/expression_operators.cu   |  4
-rw-r--r--  src/graph/expression_operators.h    |  2
-rw-r--r--  src/graph/node_operators_binary.h   | 37
-rw-r--r--  src/graph/node_operators_unary.h    | 72
-rw-r--r--  src/kernels/tensor_operators.h      | 30
-rw-r--r--  src/rnn/cells.h                     | 44
-rw-r--r--  src/rnn/rnn.h                       | 22
-rw-r--r--  src/rnn/types.h                     |  4
-rw-r--r--  src/tests/operator_tests.cpp        |  8
-rw-r--r--  src/tests/rnn_tests.cpp             |  5
11 files changed, 156 insertions(+), 141 deletions(-)
diff --git a/src/common/shape.h b/src/common/shape.h
index b0b88120..d9e4ef9f 100644
--- a/src/common/shape.h
+++ b/src/common/shape.h
@@ -58,7 +58,12 @@ struct Shape {
     updateStrides();
   }
 
-  inline int& dim(int i) { return shape_[i]; }
+  inline int& dim(int i) {
+    if(i >= 0)
+      return shape_[i];
+    else
+      return shape_[size() + i];
+  }
 
   inline const int& dim(int i) const {
     return const_cast<Shape&>(*this).dim(i);
   }
@@ -69,11 +74,17 @@ struct Shape {
   inline int& back() { return shape_.back(); }
 
   inline int stride(int i) const {
-    return stride_[i];
+    if(i >= 0)
+      return stride_[i];
+    else
+      return stride_[size() + i];
   }
 
   inline int bstride(int i) const {
-    return bstride_[i];
+    if(i >= 0)
+      return bstride_[i];
+    else
+      return bstride_[size() + i];
   }
 
   inline size_t size() const { return shape_.size(); }
@@ -125,6 +136,58 @@ struct Shape {
         << shape.elements() * sizeof(float) << "B)";
     return strm;
   }
+
+  int axis(int ax) {
+    if(ax < 0)
+      return size() + ax;
+    else
+      return ax;
+  }
+
+  static Shape broadcast(const std::vector<Shape>& shapes) {
+    int maxDims = 0;
+    for(auto& s : shapes)
+      if(s.size() > maxDims)
+        maxDims = s.size();
+
+    Shape shape;
+    shape.resize(maxDims);
+
+    for(auto& s : shapes) {
+      for(int i = 0; i < s.size(); ++i) {
+        ABORT_IF(shape[-i] != s[-i] && shape[-i] != 1 && s[-i] != 1,
+                 "Shapes cannot be broadcasted");
+        shape.set(-i, std::max(shape[-i], s[-i]));
+      }
+    }
+    return shape;
+  }
+
+  template <typename T>
+  static Shape broadcast(const std::initializer_list<T>& il) {
+    return broadcast(std::vector<T>(il));
+  }
+
+  template <typename T>
+  static Shape broadcast(const std::vector<T>& nodes) {
+    int maxDims = 0;
+    for(auto& n : nodes)
+      if(n->shape().size() > maxDims)
+        maxDims = n->shape().size();
+
+    Shape shape;
+    shape.resize(maxDims);
+
+    for(auto& node : nodes) {
+      Shape shapen = node->shape();
+      for(int i = 0; i < shapen.size(); ++i) {
+        ABORT_IF(shape[-i] != shapen[-i] && shape[-i] != 1 && shapen[-i] != 1,
+                 "Shapes cannot be broadcasted");
+        shape.set(-i, std::max(shape[-i], shapen[-i]));
+      }
+    }
+    return shape;
+  }
 };
 }
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index 10861c14..90a2047c 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -175,8 +175,8 @@ Expr transpose(Expr a, Shape permute) {
   return Expression<TransposeNodeOp>(a, permute);
 }
 
-Expr step(Expr a, size_t step) {
-  return Expression<TimestepNodeOp>(a, step);
+Expr step(Expr a, int step, int axis) {
+  return Expression<StepNodeOp>(a, step, axis);
 }
 
 Expr cross_entropy(Expr a, Expr b) {
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index c99af41d..2eb150b9 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -95,7 +95,7 @@ Expr scalar_product(Expr a, Expr b, keywords::axis_k ax = 0);
 
 Expr weighted_average(Expr in, Expr weights, keywords::axis_k ax = 0);
 
-Expr step(Expr a, size_t step);
+Expr step(Expr a, int step, int axis);
 
 Expr sqrt(Expr a, float eps = 0.f);
 Expr square(Expr a);
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 2445a6be..98d31640 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -397,18 +397,10 @@ struct ScalarProductNodeOp : public NaryNodeOp {
   Shape newShape(Expr a, Expr b, Args... args) {
     int ax = keywords::Get(keywords::axis, -1, args...);
-    Shape full = a->shape();
-    for(int i = 0; i < b->shape().size(); ++i)
-      full.set(i, std::max(full[i], b->shape()[i]));
-
-    if(ax != -1) {
-      full.set(ax, 1);
-    } else {
-      full.set(0, 1);
-      full.set(1, 1);
-      full.set(2, 1);
-      full.set(3, 1);
-    }
+    Shape full = Shape::broadcast({a, b});
+    ax = full.axis(ax);
+
+    full.set(ax, 1);
     return full;
   }
@@ -432,14 +424,7 @@ struct ElementBinaryNodeOp : public NaryNodeOp {
       : NaryNodeOp({a, b}, keywords::shape = newShape(a, b), args...) {}
 
   Shape newShape(Expr a, Expr b) {
-    Shape shape1 = a->shape();
-    Shape shape2 = b->shape();
-    for(int i = 0; i < shape1.size(); ++i) {
-      ABORT_IF(shape1[i] != shape2[i] && shape1[i] != 1 && shape2[i] != 1,
-               "Shapes cannot be broadcasted");
-      shape1.set(i, std::max(shape1[i], shape2[i]));
-    }
-    return shape1;
+    return Shape::broadcast({a, b});
   }
 
   const std::string color() { return "yellow"; }
@@ -572,15 +557,17 @@ struct ConcatenateNodeOp : public NaryNodeOp {
       : NaryNodeOp(nodes,
                    keywords::shape
                    = newShape(nodes, keywords::Get(keywords::axis, 0, args...)),
-                   args...),
-        ax_(keywords::Get(keywords::axis, 0, args...)) {}
+                   args...) {}
 
   Shape newShape(const std::vector<Expr>& nodes, int ax) {
     Shape shape = nodes.back()->shape();
-    shape.set(ax, 0);
+    ax_ = shape.axis(ax);
+
+    int sum = 0;
     for(auto child : nodes)
-      shape.set(ax, shape[ax] + child->shape()[ax]);
-    // std::cerr << ax << " : " << shape[0] << " " << shape[1] << std::endl;
+      sum += child->shape()[ax_];
+    shape.set(ax_, sum);
+
     return shape;
   }
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index a3f60366..08ecde46 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -141,17 +141,7 @@ struct TanhNodeOp : public NaryNodeOp {
       : NaryNodeOp(nodes, keywords::shape = newShape(nodes)) {}
 
   Shape newShape(const std::vector<Expr>& nodes) {
-    Shape shape = nodes[0]->shape();
-
-    for(int n = 1; n < nodes.size(); ++n) {
-      Shape shapen = nodes[n]->shape();
-      for(int i = 0; i < shapen.size(); ++i) {
-        ABORT_IF(shape[i] != shapen[i] && shape[i] != 1 && shapen[i] != 1,
-                 "Shapes cannot be broadcasted");
-        shape.set(i, std::max(shape[i], shapen[i]));
-      }
-    }
-    return shape;
+    return Shape::broadcast(nodes);
   }
 
   NodeOps forwardOps() {
@@ -325,8 +315,7 @@ struct SumNodeOp : public UnaryNodeOp {
   template <typename... Args>
   SumNodeOp(Expr a, Args... args)
-      : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...),
-        ax_(keywords::Get(keywords::axis, -1, args...)) {}
+      : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {}
 
   NodeOps forwardOps() { return {NodeOp(Reduce(_1, val_, child(0)->val()))}; }
@@ -334,15 +323,10 @@
   template <class... Args>
   Shape newShape(Expr a, Args... args) {
-    int ax = keywords::Get(keywords::axis, -1, args...);
     Shape shape = a->shape();
-    if(ax != -1) {
-      shape.set(ax, 1);
-    } else {
-      for(int i = 0; i < shape.size(); ++i) {
-        shape.set(i, 1);
-      }
-    }
+    ax_ = shape.axis(keywords::Get(keywords::axis, -1, args...));
+
+    shape.set(ax_, 1);
     return shape;
   }
@@ -375,8 +359,7 @@ struct MeanNodeOp : public UnaryNodeOp {
   template <typename... Args>
   MeanNodeOp(Expr a, Args... args)
-      : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...),
-        ax_(keywords::Get(keywords::axis, -1, args...)) {}
+      : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {}
 
   NodeOps forwardOps() {
     int left = child(0)->shape().elements() / val_->shape().elements();
@@ -394,15 +377,9 @@
   template <class... Args>
   Shape newShape(Expr a, Args... args) {
-    int ax = keywords::Get(keywords::axis, -1, args...);
     Shape shape = a->shape();
-    if(ax != -1) {
-      shape.set(ax, 1);
-    } else {
-      for(int i = 0; i < shape.size(); ++i) {
-        shape.set(i, 1);
-      }
-    }
+    ax_ = shape.axis(keywords::Get(keywords::axis, -1, args...));
+    shape.set(ax_, 1);
     return shape;
   }
@@ -637,8 +614,7 @@ struct ColsNodeOp : public UnaryNodeOp {
 struct SelectNodeOp : public UnaryNodeOp {
   SelectNodeOp(Expr a, int axis, const std::vector<size_t>& indeces)
       : UnaryNodeOp(a, keywords::shape = newShape(a, axis, indeces)),
-        indeces_(indeces),
-        axis_(axis) {}
+        indeces_(indeces) {}
 
   NodeOps forwardOps() {
     return {NodeOp(
@@ -652,7 +628,8 @@
   Shape newShape(Expr a, int axis, const std::vector<size_t>& indeces) {
     Shape shape = a->shape();
-    shape.set(axis, indeces.size());
+    axis_ = shape.axis(axis);
+    shape.set(axis_, indeces.size());
     return shape;
   }
@@ -707,8 +684,8 @@ struct TransposeNodeOp : public UnaryNodeOp {
   Shape newShape(Expr a, Shape permute) {
     Shape shape = a->shape();
 
-    UTIL_THROW_IF2(shape.size() != permute.size(),
-                   "Shape and transpose axis have different number of dimensions");
+    ABORT_IF(shape.size() != permute.size(),
+             "Shape and transpose axis have different number of dimensions");
 
     for(int i = 0; i < shape.size(); ++i)
       shape.set(i, a->shape()[permute[i]]);
@@ -806,23 +783,27 @@ public:
   }
 };
 
-class TimestepNodeOp : public UnaryNodeOp {
+class StepNodeOp : public UnaryNodeOp {
 private:
   Expr stepNode_;
-  size_t step_;
+  int step_;
+  int axis_;
 
 public:
-  TimestepNodeOp(Expr a, size_t step)
-      : UnaryNodeOp(a, keywords::shape = newShape(a)),
+  StepNodeOp(Expr a, int step, int axis)
+      : UnaryNodeOp(a, keywords::shape = newShape(a, axis)),
        stepNode_(a),
        step_(step) {
     Node::destroy_ = false;
  }
 
-  Shape newShape(Expr a) {
+  Shape newShape(Expr a, int axis) {
     Shape outShape = a->shape();
-    outShape.set(2, 1);
-    outShape.set(3, 1);
+
+    axis_ = outShape.axis(axis);
+    for(int i = 0; i <= axis_; ++i)
+      outShape.set(i, 1);
+
     return outShape;
   }
@@ -862,6 +843,7 @@ public:
     if(!hash_) {
       hash_ = NaryNodeOp::hash();
       boost::hash_combine(hash_, step_);
+      boost::hash_combine(hash_, axis_);
     }
     return hash_;
   }
@@ -869,11 +851,13 @@ public:
   virtual bool equal(Expr node) {
     if(!NaryNodeOp::equal(node))
       return false;
-    Ptr<TimestepNodeOp> cnode = std::dynamic_pointer_cast<TimestepNodeOp>(node);
+    Ptr<StepNodeOp> cnode = std::dynamic_pointer_cast<StepNodeOp>(node);
     if(!cnode)
       return false;
     if(step_ != cnode->step_)
       return false;
+    if(axis_ != cnode->axis_)
+      return false;
     return true;
   }
 };
diff --git a/src/kernels/tensor_operators.h b/src/kernels/tensor_operators.h
index e178eeac..707d5b95 100644
--- a/src/kernels/tensor_operators.h
+++ b/src/kernels/tensor_operators.h
@@ -185,9 +185,7 @@ void Add(Functor functor, Tensor out, Tensor in, float scale = 1.0) {
   UTIL_THROW_IF2(out->shape().size() != in->shape().size(),
                  "Number of dimensions does not match");
 
-  auto full = out->shape();
-  for(int i = 0; i < in->shape().size(); ++i)
-    full.set(i, std::max(full[i], in->shape()[i]));
+  auto full = Shape::broadcast({out, in});
 
   int length = out->shape().elements();
@@ -396,16 +394,7 @@ void Add(Functor functor,
          float scale = 1.0) {
   cudaSetDevice(out->getDevice());
 
-  UTIL_THROW_IF2(out->shape().size() != in1->shape().size(),
-                 "Number of dimensions does not match");
-  UTIL_THROW_IF2(out->shape().size() != in2->shape().size(),
-                 "Number of dimensions does not match");
-
-  auto full = out->shape();
-  for(int i = 0; i < in1->shape().size(); ++i)
-    full.set(i, std::max(full[i], in1->shape()[i]));
-  for(int i = 0; i < in2->shape().size(); ++i)
-    full.set(i, std::max(full[i], in2->shape()[i]));
+  Shape full = Shape::broadcast({out, in1, in2});
 
   int length = out->shape().elements();
@@ -628,20 +617,7 @@
 template <class Functor>
 void Add(Functor functor, Tensor out, Tensor in1, Tensor in2, Tensor in3) {
   cudaSetDevice(out->getDevice());
 
-  UTIL_THROW_IF2(out->shape().size() != in1->shape().size(),
-                 "Number of dimensions does not match");
-  UTIL_THROW_IF2(out->shape().size() != in2->shape().size(),
-                 "Number of dimensions does not match");
-  UTIL_THROW_IF2(out->shape().size() != in3->shape().size(),
-                 "Number of dimensions does not match");
-
-  auto full = out->shape();
-  for(int i = 0; i < in1->shape().size(); ++i)
-    full.set(i, std::max(full[i], in1->shape()[i]));
-  for(int i = 0; i < in2->shape().size(); ++i)
-    full.set(i, std::max(full[i], in2->shape()[i]));
-  for(int i = 0; i < in3->shape().size(); ++i)
-    full.set(i, std::max(full[i], in3->shape()[i]));
+  Shape full = Shape::broadcast({out, in1, in2, in3});
 
   int length = out->shape().elements();
diff --git a/src/rnn/cells.h b/src/rnn/cells.h
index 057a7c3c..cea020fa 100644
--- a/src/rnn/cells.h
+++ b/src/rnn/cells.h
@@ -75,7 +75,7 @@ public:
     if(inputs.size() == 0)
       return {};
     else if(inputs.size() > 1)
-      input = concatenate(inputs, keywords::axis = 1);
+      input = concatenate(inputs, keywords::axis = -1);
     else
       input = inputs.front();
@@ -150,7 +150,7 @@ public:
     auto Ux = graph->param(prefix + "_Ux",
                            {dimState, dimState},
                            keywords::init = inits::glorot_uniform);
-    U_ = concatenate({U, Ux}, keywords::axis = 1);
+    U_ = concatenate({U, Ux}, keywords::axis = -1);
 
     if(dimInput > 0) {
       auto W = graph->param(prefix + "_W",
@@ -159,14 +159,14 @@ public:
       auto Wx = graph->param(prefix + "_Wx",
                              {dimInput, dimState},
                              keywords::init = inits::glorot_uniform);
-      W_ = concatenate({W, Wx}, keywords::axis = 1);
+      W_ = concatenate({W, Wx}, keywords::axis = -1);
     }
 
     auto b = graph->param(
         prefix + "_b", {1, 2 * dimState}, keywords::init = inits::zeros);
     auto bx = graph->param(
         prefix + "_bx", {1, dimState}, keywords::init = inits::zeros);
-    b_ = concatenate({b, bx}, keywords::axis = 1);
+    b_ = concatenate({b, bx}, keywords::axis = -1);
 
     // @TODO use this and adjust Amun model type saving and loading
     // U_ = graph->param(prefix + "_U", {dimState, 3 * dimState},
@@ -204,7 +204,7 @@ public:
     if(inputs.size() == 0)
       return {};
     else if(inputs.size() > 1)
-      input = concatenate(inputs, keywords::axis = 1);
+      input = concatenate(inputs, keywords::axis = -1);
     else
       input = inputs[0];
@@ -311,7 +311,7 @@ public:
       U_ = U;
       Ux_ = Ux;
     } else {
-      UUx_ = concatenate({U, Ux}, keywords::axis = 1);
+      UUx_ = concatenate({U, Ux}, keywords::axis = -1);
     }
 
     if(dimInput > 0) {
@@ -325,7 +325,7 @@ public:
        W_ = W;
        Wx_ = Wx;
      } else {
-        WWx_ = concatenate({W, Wx}, keywords::axis = 1);
+        WWx_ = concatenate({W, Wx}, keywords::axis = -1);
      }
    }
@@ -342,13 +342,13 @@ public:
      if(encoder_ && transition_) {
        auto b0 = graph->constant({1, 2 * dimState},
                                  keywords::init = inits::zeros);
-        bbx_ = concatenate({b0, bx}, keywords::axis = 1);
+        bbx_ = concatenate({b0, bx}, keywords::axis = -1);
      } else {
        bbx_ = graph->constant({1, 3 * dimState}, keywords::init = inits::zeros);
      }
    } else {
-      bbx_ = concatenate({b, bx}, keywords::axis = 1);
+      bbx_ = concatenate({b, bx}, keywords::axis = -1);
    }
 
    if(dropout_ > 0.0f) {
@@ -396,7 +396,7 @@ public:
     if(inputs.size() == 0)
       return {};
     else if(inputs.size() > 1)
-      input = concatenate(inputs, keywords::axis = 1);
+      input = concatenate(inputs, keywords::axis = -1);
     else
       input = inputs[0];
@@ -418,7 +418,7 @@ public:
       W = layer_norm(W, W_lns_, W_lnb_, NEMATUS_LN_EPS);
       Wx = layer_norm(Wx, Wx_lns_, Wx_lnb_, NEMATUS_LN_EPS);
 
-      xW = concatenate({W, Wx}, keywords::axis = 1);
+      xW = concatenate({W, Wx}, keywords::axis = -1);
     } else {
       xW = dot(input, WWx_);
     }
@@ -462,7 +462,7 @@ public:
        Ux = layer_norm(Ux, Ux_lns_, Ux_lnb_, NEMATUS_LN_EPS);
      }
 
-      sU = concatenate({U, Ux}, keywords::axis = 1);
+      sU = concatenate({U, Ux}, keywords::axis = -1);
     } else {
       sU = dot(stateDropped, UUx_);
     }
@@ -553,8 +553,9 @@ public:
     Expr input;
     if(inputs.size() == 0)
       return {};
-    else if(inputs.size() > 1)
-      input = concatenate(inputs, keywords::axis = 1);
+    else if(inputs.size() > 1) {
+      input = concatenate(inputs, keywords::axis = -1);
+    }
     else
       input = inputs.front();
@@ -646,8 +647,9 @@ public:
     ABORT_IF(inputs.empty(), "Multiplicative LSTM expects input");
 
     Expr input;
-    if(inputs.size() > 1)
-      input = concatenate(inputs, keywords::axis = 1);
+    if(inputs.size() > 1) {
+      input = concatenate(inputs, keywords::axis = -1);
+    }
     else
       input = inputs.front();
@@ -742,7 +744,7 @@ public:
     Expr input;
     if(inputs.size() > 1)
-      input = concatenate(inputs, keywords::axis = 1);
+      input = concatenate(inputs, keywords::axis = -1);
     else
       input = inputs.front();
@@ -826,9 +828,9 @@ public:
     auto bo = graph->param(
         prefix + "_bo", {1, dimState}, keywords::init = inits::zeros);
 
-    U_ = concatenate({Uf, Ui, Uc, Uo}, keywords::axis = 1);
-    W_ = concatenate({Wf, Wi, Wc, Wo}, keywords::axis = 1);
-    b_ = concatenate({bf, bi, bc, bo}, keywords::axis = 1);
+    U_ = concatenate({Uf, Ui, Uc, Uo}, keywords::axis = -1);
+    W_ = concatenate({Wf, Wi, Wc, Wo}, keywords::axis = -1);
+    b_ = concatenate({bf, bi, bc, bo}, keywords::axis = -1);
   }
 
   State apply(std::vector<Expr> inputs, State state, Expr mask = nullptr) {
@@ -840,7 +842,7 @@ public:
     Expr input;
     if(inputs.size() > 1)
-      input = concatenate(inputs, keywords::axis = 1);
+      input = concatenate(inputs, keywords::axis = -1);
     else
       input = inputs.front();
diff --git a/src/rnn/rnn.h b/src/rnn/rnn.h
index 67947db6..62665800 100644
--- a/src/rnn/rnn.h
+++ b/src/rnn/rnn.h
@@ -68,7 +68,7 @@ private:
     auto xWs = cell_->applyInput({input});
 
-    size_t timeSteps = input->shape()[2];
+    size_t timeSteps = input->shape()[-3];
 
     States outputs;
     for(size_t i = 0; i < timeSteps; ++i) {
@@ -78,12 +78,13 @@ private:
        j = timeSteps - i - 1;
 
       std::vector<Expr> steps(xWs.size());
-      std::transform(xWs.begin(), xWs.end(), steps.begin(), [j](Expr e) {
-        return step(e, j);
-      });
+      std::transform(xWs.begin(),
+                     xWs.end(),
+                     steps.begin(),
+                     [j](Expr e) { return step(e, j, -3); });
 
       if(mask)
-        state = cell_->applyState(steps, state, step(mask, j));
+        state = cell_->applyState(steps, state, step(mask, j, -3));
       else
         state = cell_->applyState(steps, state);
@@ -100,10 +101,11 @@ private:
   States apply(const Expr input, const Expr mask = nullptr) {
     auto graph = input->graph();
-    int dimBatch = input->shape()[0];
+
+    int dimBatch = input->shape()[-2];
     int dimState = cell_->getOptions()->get<int>("dimState");
 
-    auto output = graph->zeros(keywords::shape = {dimBatch, dimState});
+    auto output = graph->zeros(keywords::shape = {1, dimBatch, dimState});
     Expr cell = output;
     State startState{output, cell};
@@ -171,7 +173,7 @@ public:
       auto lazyInputs = cell->getLazyInputs(shared_from_this());
       if(!lazyInputs.empty()) {
         lazyInputs.push_back(layerInput);
-        lazyInput = concatenate(lazyInputs, keywords::axis = 1);
+        lazyInput = concatenate(lazyInputs, keywords::axis = -1);
       }
 
       auto layerOutput = rnns_[i]->transduce(lazyInput, mask);
@@ -197,7 +199,7 @@ public:
       auto lazyInputs = cell->getLazyInputs(shared_from_this());
       if(!lazyInputs.empty()) {
         lazyInputs.push_back(layerInput);
-        lazyInput = concatenate(lazyInputs, keywords::axis = 1);
+        lazyInput = concatenate(lazyInputs, keywords::axis = -1);
       } else {
         lazyInput = layerInput;
       }
@@ -227,7 +229,7 @@ public:
       auto lazyInputs = cell->getLazyInputs(shared_from_this());
       if(!lazyInputs.empty()) {
         lazyInputs.push_back(layerInput);
-        lazyInput = concatenate(lazyInputs, keywords::axis = 1);
+        lazyInput = concatenate(lazyInputs, keywords::axis = -1);
       }
 
       auto layerOutput = rnns_[i]->transduce(lazyInput, States({state}), mask);
diff --git a/src/rnn/types.h b/src/rnn/types.h
index 9e288d5a..58d4aff2 100644
--- a/src/rnn/types.h
+++ b/src/rnn/types.h
@@ -48,7 +48,7 @@ public:
     for(auto s : states_)
       outputs.push_back(s.output);
     if(outputs.size() > 1)
-      return concatenate(outputs, keywords::axis = 2);
+      return concatenate(outputs, keywords::axis = -3);
     else
       return outputs[0];
   }
@@ -172,7 +172,7 @@ public:
       outputs.push_back(input->apply(state));
 
     if(outputs.size() > 1)
-      return concatenate(outputs, keywords::axis = 1);
+      return concatenate(outputs, keywords::axis = -1);
     else
       return outputs[0];
   }
diff --git a/src/tests/operator_tests.cpp b/src/tests/operator_tests.cpp
index 4e7dcdd0..ffec269b 100644
--- a/src/tests/operator_tests.cpp
+++ b/src/tests/operator_tests.cpp
@@ -90,7 +90,7 @@ TEST_CASE("Expression graph supports basic math operations", "[operator]") {
     std::vector<float> vDiv({2.0f, -1.33333f, 6.0f, -2.66667f});
 
     auto a = graph->constant({2, 2, 1}, keywords::init = inits::from_vector(vA));
-    auto b = graph->constant({1, 2, 1}, keywords::init = inits::from_vector(vB));
+    auto b = graph->constant({2, 1}, keywords::init = inits::from_vector(vB));
 
     auto add = a + b;
     auto minus = b - a;
@@ -183,7 +183,7 @@ TEST_CASE("Expression graph supports basic math operations", "[operator]") {
 
     auto sp = scalar_product(s2, s2, keywords::axis=0);
 
-    auto wa = weighted_average(a, s1, keywords::axis=1);
+    auto wa = weighted_average(a, s1, keywords::axis=-1);
 
     graph->forward();
@@ -237,8 +237,8 @@ TEST_CASE("Expression graph supports basic math operations", "[operator]") {
     auto in4 = graph->constant({1, 2, 2, 3}, keywords::init=inits::from_value(4));
 
     auto c1out1 = concatenate({in1, in2, in3, in4}, keywords::axis=2);
-    auto c1out2 = concatenate({in1, in2, in3, in4}, keywords::axis=3);
-    auto c1out3 = concatenate({in1, in2, in3, in4}, keywords::axis=1);
+    auto c1out2 = concatenate({in1, in2, in3, in4}, keywords::axis=-1);
+    auto c1out3 = concatenate({in1, in2, in3, in4}, keywords::axis=-3);
     auto c1out4 = concatenate({in1, in2, in3, in4}, keywords::axis=0);
 
     graph->forward();
diff --git a/src/tests/rnn_tests.cpp b/src/tests/rnn_tests.cpp
index 5e763f29..14bd3805 100644
--- a/src/tests/rnn_tests.cpp
+++ b/src/tests/rnn_tests.cpp
@@ -86,7 +86,7 @@ TEST_CASE("Model components, RNN etc.", "[model]") {
                         bool layerNorm=false,
                         bool skip=false) {
 
-    int dimEmb = input->shape()[1];
+    int dimEmb = input->shape()[-1];
 
     int first, second;
     if(type == "bidirectional" || type == "alternating") {
@@ -154,7 +154,7 @@ TEST_CASE("Model components, RNN etc.", "[model]") {
       auto context = concatenate({rnnFw->transduce(input, mask),
                                   rnnBw->transduce(input, mask)},
-                                 axis = 1);
+                                 axis = input->shape().size() - 1);
 
       if(second > 0) {
         // add more layers (unidirectional) by transducing the output of the
@@ -227,6 +227,7 @@ TEST_CASE("Model components, RNN etc.", "[model]") {
       });
 
     contextSum1->val()->get(values);
+
     CHECK( std::equal(values.begin(), values.end(),
                       vContextSum1.begin(), floatApprox) );
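
Taken together, the patch replaces hard-coded axis numbers with end-relative (negative) axes: dim(-1) is always the last axis, Shape::axis() normalizes a possibly negative axis, Shape::broadcast() centralizes the "equal or 1, take the max" rule used by the element-wise and reduction nodes, and the renamed StepNodeOp / step(a, step, axis) slices a single entry along an arbitrary axis. The following standalone sketch (plain C++17, not Marian code) illustrates those rules; MiniShape, stepShape and the example dimensions are illustrative stand-ins for the real Shape API, and the broadcast helper is simplified to equal-rank shapes.

// Minimal standalone sketch of the indexing rules introduced by this patch.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

struct MiniShape {
  std::vector<int> dims;

  // Negative indices count from the back, as in the new Shape::dim().
  int dim(int i) const {
    return i >= 0 ? dims[i] : dims[dims.size() + i];
  }

  // Normalize a possibly negative axis, as in the new Shape::axis().
  int axis(int ax) const {
    return ax < 0 ? static_cast<int>(dims.size()) + ax : ax;
  }

  // Broadcast of equal-rank shapes: along every axis the dimensions must
  // match or be 1, and the result takes the maximum, mirroring the
  // ABORT_IF check in the new Shape::broadcast().
  static MiniShape broadcast(const MiniShape& a, const MiniShape& b) {
    assert(a.dims.size() == b.dims.size());
    MiniShape out = a;
    for(std::size_t i = 0; i < a.dims.size(); ++i) {
      assert(a.dims[i] == b.dims[i] || a.dims[i] == 1 || b.dims[i] == 1);
      out.dims[i] = std::max(a.dims[i], b.dims[i]);
    }
    return out;
  }

  // Output shape of a single step along `ax`, as in StepNodeOp::newShape():
  // every axis up to and including the step axis collapses to 1.
  MiniShape stepShape(int ax) const {
    MiniShape out = *this;
    for(int i = 0; i <= axis(ax); ++i)
      out.dims[i] = 1;
    return out;
  }
};

int main() {
  MiniShape x{{8, 4, 512}};  // e.g. {time, batch, dim}

  assert(x.dim(-1) == 512);  // last axis, independent of the rank
  assert(x.axis(-3) == 0);   // -3 resolves to the time axis here

  // Broadcasting {2, 2, 1} against {1, 2, 1}, as in the element-wise ops.
  MiniShape a{{2, 2, 1}}, b{{1, 2, 1}};
  assert(MiniShape::broadcast(a, b).dims == std::vector<int>({2, 2, 1}));

  // step(e, j, -3) keeps one timestep: {8, 4, 512} -> {1, 4, 512}.
  assert(x.stepShape(-3).dims == std::vector<int>({1, 4, 512}));
  return 0;
}

This is why rnn.h can now read the number of timesteps from input->shape()[-3] and call step(e, j, -3) regardless of how many leading axes the tensor has, and why the cells and tests concatenate along keywords::axis = -1 instead of a hard-coded axis 1.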