diff options
author | Frank Seide <fseide@microsoft.com> | 2019-01-23 02:25:53 +0300 |
---|---|---|
committer | Frank Seide <fseide@microsoft.com> | 2019-01-23 02:25:53 +0300 |
commit | 49668f1587b2bf08b182d4753a3ff48f76f3403c (patch) | |
tree | 44dbbde4409859c9841fb314e8c110853b429a73 /src/graph | |
parent | c1c175f99522da1611c0847c6fc3152d423a24fa (diff) | |
parent | 7ae9709043cdcc4f9bf38e9519f06e9eccaf58eb (diff) |
Merge branch 'fseide/indexops' into fseide/factoredembeddings
Diffstat (limited to 'src/graph')
-rwxr-xr-x | src/graph/expression_operators.cpp | 30 | ||||
-rwxr-xr-x | src/graph/expression_operators.h | 25 | ||||
-rwxr-xr-x | src/graph/node_operators_binary.h | 6 | ||||
-rwxr-xr-x | src/graph/node_operators_unary.h | 6 |
4 files changed, 34 insertions, 33 deletions
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp index 41c86f74..6a07611d 100755 --- a/src/graph/expression_operators.cpp +++ b/src/graph/expression_operators.cpp @@ -246,17 +246,17 @@ Expr constant_like(Expr a, const NodeInitializer& init) { } // gather() -- gather arbitrary elements along an axis; batched or non-batched -Expr gather(Expr a, Expr indices, int axis) { - return Expression<GatherNodeOp>(a, indices, axis); +Expr gather(Expr a, int axis, Expr indices) { + return Expression<GatherNodeOp>(a, axis, indices); } // index_select() -- gather arbitrary elements along an axis; unbatched (indices are specified as a 1D vector) -Expr index_select(Expr a, Expr indices, int axis) { +Expr index_select(Expr a, int axis, Expr indices) { ABORT_IF(indices->shape().size() != 1, "Indices must be a 1D tensor"); // We have specialized kernels for non-batched indexing of first or last axis of a 2D tensor. auto rank = a->shape().size(); if (rank == 2) { - if (axis == 0) + if (axis == 0 || axis == -2) return Expression<RowsNodeOp>(a, indices); else if (axis == -1 || axis == 1) return Expression<ColsNodeOp>(a, indices); @@ -266,29 +266,29 @@ Expr index_select(Expr a, Expr indices, int axis) { shape.resize(a->shape().size()); shape.set(axis, indices->shape()[0]); indices = reshape(indices, shape); // move index to axis - return gather(a, indices, axis); + return gather(a, axis, indices); } -Expr index_select(Expr a, const std::vector<IndexType>& indices, int axis) { +Expr index_select(Expr a, int axis, const std::vector<IndexType>& indices) { auto indexExpr = a->graph()->indices(indices); - return index_select(a, indexExpr, axis); + return index_select(a, axis, indexExpr); } -static Expr sliceCopy(Expr a, const Slice& slice, int axis) { // copy a Slice via gather() +static Expr sliceCopy(Expr a, int axis, const Slice& slice) { // copy a Slice via gather() ABORT_IF(slice.stride < 0, "Negative strides are not supported yet"); ABORT_IF(slice.begin == slice.end, "Empty slices are not allowed"); // @TODO: Or are they? std::vector<IndexType> indices; indices.reserve((slice.end - slice.begin - 1) / slice.stride + 1); for (int i = slice.begin; i < slice.end; i += slice.stride) indices.push_back((IndexType)i); - return gather(a, a->graph()->indices(indices, a, axis), axis); + return gather(a, axis, a->graph()->indices(indices, a, axis)); } -static Expr sliceView(Expr a, const Slice& slice, int axis) { // view a slice (must be memory-consecutive) - return Expression<SliceViewNodeOp>(a, slice, axis); +static Expr sliceView(Expr a, int axis, const Slice& slice) { // view a slice (must be memory-consecutive) + return Expression<SliceViewNodeOp>(a, axis, slice); } // slice() -- gather a slice along an axis (step size > 1 allowed) -Expr slice(Expr a, Slice slice, int axis) { // numpy __getslice__ semantics, but with axis parameter +Expr slice(Expr a, int axis, Slice slice) { // numpy __getslice__ semantics, but with axis parameter const auto& shape = a->shape(); axis = shape.axis(axis); // normalize negative axis slice = shape.slice(slice, axis); // normalize negative slice values @@ -296,13 +296,13 @@ Expr slice(Expr a, Slice slice, int axis) { // numpy __getslice__ semantics, but return a; // it's a no-op #if 1 // until strided views are supported, non-consecutive slices are implemented via gather() if (slice.stride != 1) - return sliceCopy(a, slice, axis); + return sliceCopy(a, axis, slice); for (int i = 0; i < axis; ++i) { if (shape[i] != 1) // this makes it non-consecutive - return sliceCopy(a, slice, axis); + return sliceCopy(a, axis, slice); } #endif - return sliceView(a, slice, axis); + return sliceView(a, axis, slice); } Expr sum(Expr a, int ax) { diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h index e8c7d646..78aed834 100755 --- a/src/graph/expression_operators.h +++ b/src/graph/expression_operators.h @@ -139,34 +139,35 @@ Expr flatten_2d(Expr a); Expr stopGradient(Expr a); -Expr gather(Expr a, Expr indices, int axis); +Expr gather(Expr a, int axis, Expr indices); -Expr index_select(Expr a, Expr indices, int axis); +// Warning: Don't try to pass a scalar literal 0 as indices; it will compile but pass nullptr... +Expr index_select(Expr a, int axis, Expr indices); // convenience wrappers for index_select() -Expr index_select(Expr a, const std::vector<IndexType>& indices, int axis); +Expr index_select(Expr a, int axis, const std::vector<IndexType>& indices); static inline Expr rows(Expr a, Expr indices) { - return index_select(a, indices, 0); + return index_select(a, 0, indices); } static inline Expr rows(Expr a, const std::vector<IndexType>& indexVector) { - return index_select(a, indexVector, 0); + return index_select(a, 0, indexVector); } static inline Expr cols(Expr a, Expr indices) { - return index_select(a, indices, -1); + return index_select(a, -1, indices); } static inline Expr cols(Expr a, const std::vector<IndexType>& indexVector) { - return index_select(a, indexVector, -1); + return index_select(a, -1, indexVector); } -Expr slice(Expr a, Slice slice, int axis); +Expr slice(Expr a, int axis, Slice slice); // convenience wrappers for slice() -static inline Expr step(Expr a, int step, int axis) { // @TODO: name is too narrow - return slice(a, Slice(step), axis); +static inline Expr slice(Expr a, int axis, int index) { // single index @NOTE: This was formerlly called step() + return slice(a, axis, Slice(index)); } -static inline Expr narrow(Expr a, size_t start, size_t length, int axis) { // PyTorch name - return slice(a, Slice((int)start, (int)(start + length)), axis); +static inline Expr narrow(Expr a, int axis, size_t start, size_t length) { // PyTorch name + return slice(a, axis, Slice((int)start, (int)(start + length))); } /*********************************************************/ diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h index b61ff730..7d090823 100755 --- a/src/graph/node_operators_binary.h +++ b/src/graph/node_operators_binary.h @@ -576,8 +576,8 @@ struct RowsNodeOp : public NaryNodeOp { // @TODO: The current implementation does not support batched indices (third scenario above). // I.e. all axes of 'indices' except 'axis' must have dimension 1. struct GatherNodeOp : public NaryNodeOp { - GatherNodeOp(Expr a, Expr indices, int axis) - : NaryNodeOp({a, indices}, newShape(a, indices, axis), a->value_type()), + GatherNodeOp(Expr a, int axis, Expr indices) + : NaryNodeOp({a, indices}, newShape(a, axis, indices), a->value_type()), axis_(a->shape().axis(axis)) { matchOrAbort<IndexType>(indices->value_type()); } @@ -592,7 +592,7 @@ struct GatherNodeOp : public NaryNodeOp { Insert(child(0)->grad(), adj_, child(1)->val(), axis_))}; } - Shape newShape(Expr a, Expr indices, int axis) { + Shape newShape(Expr a, int axis, Expr indices) { Shape shape = a->shape(); axis = shape.axis(axis); auto rank = shape.size(); diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 3530d2bf..6dd90faf 100755 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -761,15 +761,15 @@ private: size_t byteOffset_, byteSize_; // viewed segment in bytes (memory-consecutive) public: - SliceViewNodeOp(Expr a, Slice slice, int axis) - : UnaryNodeOp(a, newShape(a, slice, axis), a->value_type()), viewedNode_(a), slice_(slice), axis_(axis) { + SliceViewNodeOp(Expr a, int axis, Slice slice) + : UnaryNodeOp(a, newShape(a, axis, slice), a->value_type()), viewedNode_(a), slice_(slice), axis_(axis) { Node::destroy_ = false; auto byteStride = a->shape().stride(axis) * sizeOf(value_type()); byteOffset_ = slice.begin * byteStride; byteSize_ = shape()[axis] * byteStride; } - static Shape newShape(Expr a, Slice& slice, int& axis) { // note: normalizes slice and axis in-place + static Shape newShape(Expr a, int& axis, Slice& slice) { // note: normalizes slice and axis in-place const auto& shape = a->shape(); axis = shape.axis(axis); // normalize negative axis slice = shape.slice(slice, axis); // normalize negative slice values |