From 0a62d1e2d051105cb25a83cc3ab348e4b19c50df Mon Sep 17 00:00:00 2001
From: Frank Seide
Date: Tue, 22 Jan 2019 15:08:32 -0800
Subject: changed index operations' parameter lists to match PyTorch parameter order (axis before arg)

---
 src/common/shape.h                 |  2 +-
 src/graph/expression_operators.cpp | 30 ++++++++++++++--------------
 src/graph/expression_operators.h   | 25 ++++++++++++------------
 src/graph/node_operators_binary.h  |  6 +++---
 src/graph/node_operators_unary.h   |  6 +++---
 src/models/transformer.h           |  4 ++--
 src/rnn/rnn.h                      |  4 ++--
 src/tests/operator_tests.cpp       | 40 +++++++++++++++++++-------------------
 8 files changed, 59 insertions(+), 58 deletions(-)

diff --git a/src/common/shape.h b/src/common/shape.h
index 20fb2607..c8a4bdd3 100755
--- a/src/common/shape.h
+++ b/src/common/shape.h
@@ -17,7 +17,7 @@ struct Slice // Python-like slice/index descriptor
   Slice(int b, int e, int s) : begin(b), end(e), stride(s) {}
   Slice(int b, int e) : Slice(b, e, 1) {}
   Slice() : Slice(0, END) {}
-  Slice(int i) : Slice(i, i + 1) {}
+  explicit Slice(int i) : Slice(i, i + 1) {}
   Slice(const Slice& other) : Slice(other.begin, other.end, other.stride) {}
   const Slice& operator=(const Slice& other) { begin = other.begin; end = other.end; stride = other.stride; return *this; }
   const Slice& operator=(int i) { begin = i; end = i + 1; stride = 1; return *this; }
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index 81310b43..db820062 100755
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -246,17 +246,17 @@ Expr constant_like(Expr a, const NodeInitializer& init) {
 }
 
 // gather() -- gather arbitrary elements along an axis; batched or non-batched
-Expr gather(Expr a, Expr indices, int axis) {
-  return Expression<GatherNodeOp>(a, indices, axis);
+Expr gather(Expr a, int axis, Expr indices) {
+  return Expression<GatherNodeOp>(a, axis, indices);
 }
 
 // index_select() -- gather arbitrary elements along an axis; unbatched (indices are specified as a 1D vector)
-Expr index_select(Expr a, Expr indices, int axis) {
+Expr index_select(Expr a, int axis, Expr indices) {
   ABORT_IF(indices->shape().size() != 1, "Indices must be a 1D tensor");
   // We have specialized kernels for non-batched indexing of first or last axis of a 2D tensor.
   auto rank = a->shape().size();
   if (rank == 2) {
-    if (axis == 0)
+    if (axis == 0 || axis == 2)
       return Expression<RowsNodeOp>(a, indices);
     else if (axis == -1 || axis == 1)
       return Expression<ColsNodeOp>(a, indices);
@@ -266,29 +266,29 @@ Expr index_select(Expr a, Expr indices, int axis) {
   shape.resize(a->shape().size());
   shape.set(axis, indices->shape()[0]);
   indices = reshape(indices, shape); // move index to axis
-  return gather(a, indices, axis);
+  return gather(a, axis, indices);
 }
 
-Expr index_select(Expr a, const std::vector<IndexType>& indices, int axis) {
+Expr index_select(Expr a, int axis, const std::vector<IndexType>& indices) {
   auto indexExpr = a->graph()->indices(indices);
-  return index_select(a, indexExpr, axis);
+  return index_select(a, axis, indexExpr);
 }
 
-static Expr sliceCopy(Expr a, const Slice& slice, int axis) { // copy a Slice via gather()
+static Expr sliceCopy(Expr a, int axis, const Slice& slice) { // copy a Slice via gather()
   ABORT_IF(slice.stride < 0, "Negative strides are not supported yet");
   ABORT_IF(slice.begin == slice.end, "Empty slices are not allowed"); // @TODO: Or are they?
   std::vector<IndexType> indices;
   indices.reserve((slice.end - slice.begin - 1) / slice.stride + 1);
   for (int i = slice.begin; i < slice.end; i += slice.stride)
     indices.push_back((IndexType)i);
-  return gather(a, a->graph()->indices(indices, a, axis), axis);
+  return gather(a, axis, a->graph()->indices(indices, a, axis));
 }
 
-static Expr sliceView(Expr a, const Slice& slice, int axis) { // view a slice (must be memory-consecutive)
-  return Expression<SliceViewNodeOp>(a, slice, axis);
+static Expr sliceView(Expr a, int axis, const Slice& slice) { // view a slice (must be memory-consecutive)
+  return Expression<SliceViewNodeOp>(a, axis, slice);
 }
 
 // slice() -- gather a slice along an axis (step size > 1 allowed)
-Expr slice(Expr a, Slice slice, int axis) { // numpy __getslice__ semantics, but with axis parameter
+Expr slice(Expr a, int axis, Slice slice) { // numpy __getslice__ semantics, but with axis parameter
   const auto& shape = a->shape();
   axis = shape.axis(axis);          // normalize negative axis
   slice = shape.slice(slice, axis); // normalize negative slice values
@@ -296,13 +296,13 @@ Expr slice(Expr a, Slice slice, int axis) { // numpy __getslice__ semantics, but
     return a; // it's a no-op
 #if 1 // until strided views are supported, non-consecutive slices are implemented via gather()
   if (slice.stride != 1)
-    return sliceCopy(a, slice, axis);
+    return sliceCopy(a, axis, slice);
   for (int i = 0; i < axis; ++i) {
     if (shape[i] != 1) // this makes it non-consecutive
-      return sliceCopy(a, slice, axis);
+      return sliceCopy(a, axis, slice);
   }
 #endif
-  return sliceView(a, slice, axis);
+  return sliceView(a, axis, slice);
 }
 
 Expr sum(Expr a, int ax) {
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 5902f2af..58149bde 100755
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -139,34 +139,35 @@ Expr flatten_2d(Expr a);
 
 Expr stopGradient(Expr a);
 
-Expr gather(Expr a, Expr indices, int axis);
+Expr gather(Expr a, int axis, Expr indices);
 
-Expr index_select(Expr a, Expr indices, int axis);
+// Warning: Don't try to pass a scalar literal 0 as indices; it will compile but pass nullptr...
+Expr index_select(Expr a, int axis, Expr indices);
 
 // convenience wrappers for index_select()
-Expr index_select(Expr a, const std::vector<IndexType>& indices, int axis);
+Expr index_select(Expr a, int axis, const std::vector<IndexType>& indices);
 static inline Expr rows(Expr a, Expr indices) {
-  return index_select(a, indices, 0);
+  return index_select(a, 0, indices);
 }
 static inline Expr rows(Expr a, const std::vector<IndexType>& indexVector) {
-  return index_select(a, indexVector, 0);
+  return index_select(a, 0, indexVector);
 }
 static inline Expr cols(Expr a, Expr indices) {
-  return index_select(a, indices, -1);
+  return index_select(a, -1, indices);
 }
 static inline Expr cols(Expr a, const std::vector<IndexType>& indexVector) {
-  return index_select(a, indexVector, -1);
+  return index_select(a, -1, indexVector);
 }
 
-Expr slice(Expr a, Slice slice, int axis);
+Expr slice(Expr a, int axis, Slice slice);
 
 // convenience wrappers for slice()
-static inline Expr step(Expr a, int step, int axis) { // @TODO: name is too narrow
-  return slice(a, Slice(step), axis);
+static inline Expr slice(Expr a, int axis, int index) { // single index @NOTE: This was formerly called step()
+  return slice(a, axis, Slice(index));
 }
-static inline Expr narrow(Expr a, size_t start, size_t length, int axis) { // PyTorch name
-  return slice(a, Slice((int)start, (int)(start + length)), axis);
+static inline Expr narrow(Expr a, int axis, size_t start, size_t length) { // PyTorch name
+  return slice(a, axis, Slice((int)start, (int)(start + length)));
 }
 
 /*********************************************************/
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 22109fbc..10e2ca76 100755
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -576,8 +576,8 @@ struct RowsNodeOp : public NaryNodeOp {
 // @TODO: The current implementation does not support batched indices (third scenario above).
 //        I.e. all axes of 'indices' except 'axis' must have dimension 1.
 struct GatherNodeOp : public NaryNodeOp {
-  GatherNodeOp(Expr a, Expr indices, int axis)
-      : NaryNodeOp({a, indices}, newShape(a, indices, axis), a->value_type()),
+  GatherNodeOp(Expr a, int axis, Expr indices)
+      : NaryNodeOp({a, indices}, newShape(a, axis, indices), a->value_type()),
         axis_(a->shape().axis(axis)) {
     matchOrAbort<IndexType>(indices->value_type());
   }
@@ -592,7 +592,7 @@ struct GatherNodeOp : public NaryNodeOp {
             Insert(child(0)->grad(), adj_, child(1)->val(), axis_))};
   }
 
-  Shape newShape(Expr a, Expr indices, int axis) {
+  Shape newShape(Expr a, int axis, Expr indices) {
     Shape shape = a->shape();
     axis = shape.axis(axis);
     auto rank = shape.size();
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 51433f93..7dbaec46 100755
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -745,15 +745,15 @@ private:
   size_t byteOffset_, byteSize_; // viewed segment in bytes (memory-consecutive)
 
 public:
-  SliceViewNodeOp(Expr a, Slice slice, int axis)
-      : UnaryNodeOp(a, newShape(a, slice, axis), a->value_type()), viewedNode_(a), slice_(slice), axis_(axis) {
+  SliceViewNodeOp(Expr a, int axis, Slice slice)
+      : UnaryNodeOp(a, newShape(a, axis, slice), a->value_type()), viewedNode_(a), slice_(slice), axis_(axis) {
     Node::destroy_ = false;
     auto byteStride = a->shape().stride(axis) * sizeOf(value_type());
     byteOffset_ = slice.begin * byteStride;
     byteSize_ = shape()[axis] * byteStride;
   }
 
-  static Shape newShape(Expr a, Slice& slice, int& axis) { // note: normalizes slice and axis in-place
+  static Shape newShape(Expr a, int& axis, Slice& slice) { // note: normalizes slice and axis in-place
     const auto& shape = a->shape();
     axis = shape.axis(axis);          // normalize negative axis
     slice = shape.slice(slice, axis); // normalize negative slice values
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 3c540c12..968d481b 100755
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -178,7 +178,7 @@ public:
 
   void collectOneHead(Expr weights, int dimBeam) {
     // select first head, this is arbitrary as the choice does not really matter
-    auto head0 = index_select(weights, 0, -3);
+    auto head0 = slice(weights, -3, 0);
 
     int dimBatchBeam = head0->shape()[-4];
     int srcWords = head0->shape()[-1];
@@ -194,7 +194,7 @@ public:
 
     // @TODO: make splitting obsolete
     alignments_.clear();
     for(int i = 0; i < trgWords; ++i) {
-      alignments_.push_back(marian::step(head0, i, -1)); // [tgt index][-4: beam depth, -3: max src length, -2: batch size, -1: 1]
+      alignments_.push_back(slice(head0, -1, i)); // [tgt index][-4: beam depth, -3: max src length, -2: batch size, -1: 1]
     }
   }
diff --git a/src/rnn/rnn.h b/src/rnn/rnn.h
index 74a535d3..c9131d87 100755
--- a/src/rnn/rnn.h
+++ b/src/rnn/rnn.h
@@ -75,11 +75,11 @@ private:
 
       std::vector<Expr> steps(xWs.size());
       std::transform(xWs.begin(), xWs.end(), steps.begin(), [j](Expr e) {
-        return step(e, j, -3);
+        return slice(e, -3, j);
       });
 
       if(mask)
-        state = cell_->applyState(steps, state, step(mask, j, -3));
+        state = cell_->applyState(steps, state, slice(mask, -3, j));
       else
        state = cell_->applyState(steps, state);
 
diff --git a/src/tests/operator_tests.cpp b/src/tests/operator_tests.cpp
index 34d8b03c..08fdd8de 100755
--- a/src/tests/operator_tests.cpp
+++ b/src/tests/operator_tests.cpp
@@ -643,13 +643,13 @@ void tests(DeviceType device) {
     std::vector<float> vS3({7, -8, 9, -10, 11, -12});
 
     auto A = graph->param("4x3", {4,3}, inits::from_vector(vA));
-    auto B1a = index_select(A, IndexVector({0}), 0); // always uses gather()
-    auto B1b = step(A, 0, 0); // memory-consecutive view
-    auto B2 = step(A, 0, 1); // not memory-consecutive
-    auto B3 = step(A, 1, -1);
-    auto B4a = index_select(A, IndexVector({0, 1}), 0);
-    auto B4b = slice(A, Slice(0, 2), 0); // this is memory-consecutive
-    auto B5 = slice(A, Slice(0, 4), 0); // this is a no-op
+    auto B1a = index_select(A, 0, IndexVector({0})); // always uses gather()
+    auto B1b = slice(A, 0, 0); // memory-consecutive view
+    auto B2 = slice(A, 1, 0); // not memory-consecutive
+    auto B3 = slice(A, -1, 1);
+    auto B4a = index_select(A, 0, IndexVector({0, 1}));
+    auto B4b = slice(A, 0, Slice(0, 2)); // this is memory-consecutive
+    auto B5 = slice(A, 0, Slice(0, 4)); // this is a no-op
     CHECK(B1a->type() == "rows"); // actually optimized to rows()
     CHECK(B1b->type() == "sliceView"); // must use view
     CHECK(B2->type() == "gather"); // cannot use view
@@ -658,21 +658,21 @@ void tests(DeviceType device) {
     CHECK(B5.get() == A.get()); // must be no-op
 
     auto C = graph->param("2x3x2", {2, 3, 2}, inits::from_vector(vC));
-    auto D1 = step(C, 0, 0);
-    auto D2 = step(C, 2, -2);
-    auto D3 = index_select(C, IndexVector({0, 2}), 1); // C[:,(0,2),:]
+    auto D1 = slice(C, 0, 0);
+    auto D2 = slice(C, -2, 2);
+    auto D3 = index_select(C, 1, IndexVector({0, 2})); // C[:,(0,2),:]
     CHECK(D1->type() == "sliceView");
     CHECK(D2->type() == "gather");
 
     // enable this once gather() supports batched indices:
-    //auto D4 = gather(C, graph->constant({2, 2, 1}, // [C[0,(2,1),:],C[1,(0,2),:]]
-    //                 inits::from_vector(std::vector<IndexType>{
-    //                   2, 1,
-    //                   0, 2 }),
-    //                 Type::uint32), 1);
+    //auto D4 = gather(C, 1, graph->constant({2, 2, 1}, // [C[0,(2,1),:],C[1,(0,2),:]]
+    //                 inits::from_vector(std::vector<IndexType>{
+    //                   2, 1,
+    //                   0, 2 }),
+    //                 Type::uint32));
 
-    auto S1 = step(A, 2, 0);
-    auto S2 = narrow(A, 1, 2, 0);
-    auto S3 = slice(A, Slice(-2, Slice::END), 0);
+    auto S1 = slice(A, 0, 2);
+    auto S2 = narrow(A, 0, 1, 2);
+    auto S3 = slice(A, 0, Slice(-2, Slice::END));
 
     graph->forward();
@@ -703,9 +703,9 @@ void tests(DeviceType device) {
     auto A = graph->param("4x3", {4, 3}, inits::from_vector(vA));
     auto B1 = rows(A, indices);
-    auto B2 = gather(A, graph->indices(indices, A, 0), 0);
+    auto B2 = gather(A, 0, graph->indices(indices, A, 0));
     auto C1 = cols(A, indices);
-    auto C2 = gather(A, graph->indices(indices, A, 1), 1);
+    auto C2 = gather(A, 1, graph->indices(indices, A, 1));
 
     graph->forward();
 
     CHECK(B1->shape() == B2->shape());
-- 
cgit v1.2.3


From 7ae9709043cdcc4f9bf38e9519f06e9eccaf58eb Mon Sep 17 00:00:00 2001
From: Frank Seide
Date: Tue, 22 Jan 2019 15:23:11 -0800
Subject: fixed a typo

---
 src/graph/expression_operators.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index db820062..826bd9f0 100755
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -256,7 +256,7 @@ Expr index_select(Expr a, int axis, Expr indices) {
   // We have specialized kernels for non-batched indexing of first or last axis of a 2D tensor.
   auto rank = a->shape().size();
   if (rank == 2) {
-    if (axis == 0 || axis == 2)
+    if (axis == 0 || axis == -2)
       return Expression<RowsNodeOp>(a, indices);
     else if (axis == -1 || axis == 1)
       return Expression<ColsNodeOp>(a, indices);
-- 
cgit v1.2.3
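
Usage note (illustrative only, not part of either commit above): a minimal sketch of the new axis-first calling convention, loosely following operator_tests.cpp. It assumes Marian's umbrella header marian.h, the IndexType typedef, and an already-constructed ExpressionGraph; the function name axisFirstUsageSketch is made up for this example.

#include "marian.h" // assumption: umbrella header exposing ExpressionGraph, Expr, Slice, and the operators changed above

using namespace marian;

// Every index operation now takes the axis first, then the index/slice argument (PyTorch order).
static void axisFirstUsageSketch(Ptr<ExpressionGraph> graph) {
  std::vector<float> vA({1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12});
  auto A = graph->param("4x3", {4, 3}, inits::from_vector(vA));

  // index_select(): axis, then a 1D index vector (old order: indices, then axis)
  auto top2rows = index_select(A, /*axis=*/0, std::vector<IndexType>({0, 1}));

  // slice() with a single index replaces the former step(); old step(A, 1, -1) becomes:
  auto col1 = slice(A, /*axis=*/-1, 1);

  // slice() with a Slice object: rows 0..1 (old order: slice(A, Slice(0, 2), 0))
  auto rows01 = slice(A, /*axis=*/0, Slice(0, 2));

  // narrow() keeps its PyTorch name and now also the PyTorch order: (axis, start, length)
  auto rows12 = narrow(A, /*axis=*/0, /*start=*/1, /*length=*/2);

  // gather(): axis first, then an index expression (built here via graph->indices())
  std::vector<IndexType> idx({0, 2});
  auto picked = gather(A, /*axis=*/0, graph->indices(idx, A, 0));

  graph->forward();
}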