Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src/graph
diff options
context:
space:
mode:
authorFrank Seide <fseide@microsoft.com>2019-01-23 02:25:53 +0300
committerFrank Seide <fseide@microsoft.com>2019-01-23 02:25:53 +0300
commit49668f1587b2bf08b182d4753a3ff48f76f3403c (patch)
tree44dbbde4409859c9841fb314e8c110853b429a73 /src/graph
parentc1c175f99522da1611c0847c6fc3152d423a24fa (diff)
parent7ae9709043cdcc4f9bf38e9519f06e9eccaf58eb (diff)
Merge branch 'fseide/indexops' into fseide/factoredembeddings
Diffstat (limited to 'src/graph')
-rwxr-xr-xsrc/graph/expression_operators.cpp30
-rwxr-xr-xsrc/graph/expression_operators.h25
-rwxr-xr-xsrc/graph/node_operators_binary.h6
-rwxr-xr-xsrc/graph/node_operators_unary.h6
4 files changed, 34 insertions, 33 deletions
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index 41c86f74..6a07611d 100755
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -246,17 +246,17 @@ Expr constant_like(Expr a, const NodeInitializer& init) {
}
// gather() -- gather arbitrary elements along an axis; batched or non-batched
-Expr gather(Expr a, Expr indices, int axis) {
- return Expression<GatherNodeOp>(a, indices, axis);
+Expr gather(Expr a, int axis, Expr indices) {
+ return Expression<GatherNodeOp>(a, axis, indices);
}
// index_select() -- gather arbitrary elements along an axis; unbatched (indices are specified as a 1D vector)
-Expr index_select(Expr a, Expr indices, int axis) {
+Expr index_select(Expr a, int axis, Expr indices) {
ABORT_IF(indices->shape().size() != 1, "Indices must be a 1D tensor");
// We have specialized kernels for non-batched indexing of first or last axis of a 2D tensor.
auto rank = a->shape().size();
if (rank == 2) {
- if (axis == 0)
+ if (axis == 0 || axis == -2)
return Expression<RowsNodeOp>(a, indices);
else if (axis == -1 || axis == 1)
return Expression<ColsNodeOp>(a, indices);
@@ -266,29 +266,29 @@ Expr index_select(Expr a, Expr indices, int axis) {
shape.resize(a->shape().size());
shape.set(axis, indices->shape()[0]);
indices = reshape(indices, shape); // move index to axis
- return gather(a, indices, axis);
+ return gather(a, axis, indices);
}
-Expr index_select(Expr a, const std::vector<IndexType>& indices, int axis) {
+Expr index_select(Expr a, int axis, const std::vector<IndexType>& indices) {
auto indexExpr = a->graph()->indices(indices);
- return index_select(a, indexExpr, axis);
+ return index_select(a, axis, indexExpr);
}
-static Expr sliceCopy(Expr a, const Slice& slice, int axis) { // copy a Slice via gather()
+static Expr sliceCopy(Expr a, int axis, const Slice& slice) { // copy a Slice via gather()
ABORT_IF(slice.stride < 0, "Negative strides are not supported yet");
ABORT_IF(slice.begin == slice.end, "Empty slices are not allowed"); // @TODO: Or are they?
std::vector<IndexType> indices;
indices.reserve((slice.end - slice.begin - 1) / slice.stride + 1);
for (int i = slice.begin; i < slice.end; i += slice.stride)
indices.push_back((IndexType)i);
- return gather(a, a->graph()->indices(indices, a, axis), axis);
+ return gather(a, axis, a->graph()->indices(indices, a, axis));
}
-static Expr sliceView(Expr a, const Slice& slice, int axis) { // view a slice (must be memory-consecutive)
- return Expression<SliceViewNodeOp>(a, slice, axis);
+static Expr sliceView(Expr a, int axis, const Slice& slice) { // view a slice (must be memory-consecutive)
+ return Expression<SliceViewNodeOp>(a, axis, slice);
}
// slice() -- gather a slice along an axis (step size > 1 allowed)
-Expr slice(Expr a, Slice slice, int axis) { // numpy __getslice__ semantics, but with axis parameter
+Expr slice(Expr a, int axis, Slice slice) { // numpy __getslice__ semantics, but with axis parameter
const auto& shape = a->shape();
axis = shape.axis(axis); // normalize negative axis
slice = shape.slice(slice, axis); // normalize negative slice values
@@ -296,13 +296,13 @@ Expr slice(Expr a, Slice slice, int axis) { // numpy __getslice__ semantics, but
return a; // it's a no-op
#if 1 // until strided views are supported, non-consecutive slices are implemented via gather()
if (slice.stride != 1)
- return sliceCopy(a, slice, axis);
+ return sliceCopy(a, axis, slice);
for (int i = 0; i < axis; ++i) {
if (shape[i] != 1) // this makes it non-consecutive
- return sliceCopy(a, slice, axis);
+ return sliceCopy(a, axis, slice);
}
#endif
- return sliceView(a, slice, axis);
+ return sliceView(a, axis, slice);
}
Expr sum(Expr a, int ax) {
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index e8c7d646..78aed834 100755
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -139,34 +139,35 @@ Expr flatten_2d(Expr a);
Expr stopGradient(Expr a);
-Expr gather(Expr a, Expr indices, int axis);
+Expr gather(Expr a, int axis, Expr indices);
-Expr index_select(Expr a, Expr indices, int axis);
+// Warning: Don't try to pass a scalar literal 0 as indices; it will compile but pass nullptr...
+Expr index_select(Expr a, int axis, Expr indices);
// convenience wrappers for index_select()
-Expr index_select(Expr a, const std::vector<IndexType>& indices, int axis);
+Expr index_select(Expr a, int axis, const std::vector<IndexType>& indices);
static inline Expr rows(Expr a, Expr indices) {
- return index_select(a, indices, 0);
+ return index_select(a, 0, indices);
}
static inline Expr rows(Expr a, const std::vector<IndexType>& indexVector) {
- return index_select(a, indexVector, 0);
+ return index_select(a, 0, indexVector);
}
static inline Expr cols(Expr a, Expr indices) {
- return index_select(a, indices, -1);
+ return index_select(a, -1, indices);
}
static inline Expr cols(Expr a, const std::vector<IndexType>& indexVector) {
- return index_select(a, indexVector, -1);
+ return index_select(a, -1, indexVector);
}
-Expr slice(Expr a, Slice slice, int axis);
+Expr slice(Expr a, int axis, Slice slice);
// convenience wrappers for slice()
-static inline Expr step(Expr a, int step, int axis) { // @TODO: name is too narrow
- return slice(a, Slice(step), axis);
+static inline Expr slice(Expr a, int axis, int index) { // single index @NOTE: This was formerlly called step()
+ return slice(a, axis, Slice(index));
}
-static inline Expr narrow(Expr a, size_t start, size_t length, int axis) { // PyTorch name
- return slice(a, Slice((int)start, (int)(start + length)), axis);
+static inline Expr narrow(Expr a, int axis, size_t start, size_t length) { // PyTorch name
+ return slice(a, axis, Slice((int)start, (int)(start + length)));
}
/*********************************************************/
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index b61ff730..7d090823 100755
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -576,8 +576,8 @@ struct RowsNodeOp : public NaryNodeOp {
// @TODO: The current implementation does not support batched indices (third scenario above).
// I.e. all axes of 'indices' except 'axis' must have dimension 1.
struct GatherNodeOp : public NaryNodeOp {
- GatherNodeOp(Expr a, Expr indices, int axis)
- : NaryNodeOp({a, indices}, newShape(a, indices, axis), a->value_type()),
+ GatherNodeOp(Expr a, int axis, Expr indices)
+ : NaryNodeOp({a, indices}, newShape(a, axis, indices), a->value_type()),
axis_(a->shape().axis(axis)) {
matchOrAbort<IndexType>(indices->value_type());
}
@@ -592,7 +592,7 @@ struct GatherNodeOp : public NaryNodeOp {
Insert(child(0)->grad(), adj_, child(1)->val(), axis_))};
}
- Shape newShape(Expr a, Expr indices, int axis) {
+ Shape newShape(Expr a, int axis, Expr indices) {
Shape shape = a->shape();
axis = shape.axis(axis);
auto rank = shape.size();
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 3530d2bf..6dd90faf 100755
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -761,15 +761,15 @@ private:
size_t byteOffset_, byteSize_; // viewed segment in bytes (memory-consecutive)
public:
- SliceViewNodeOp(Expr a, Slice slice, int axis)
- : UnaryNodeOp(a, newShape(a, slice, axis), a->value_type()), viewedNode_(a), slice_(slice), axis_(axis) {
+ SliceViewNodeOp(Expr a, int axis, Slice slice)
+ : UnaryNodeOp(a, newShape(a, axis, slice), a->value_type()), viewedNode_(a), slice_(slice), axis_(axis) {
Node::destroy_ = false;
auto byteStride = a->shape().stride(axis) * sizeOf(value_type());
byteOffset_ = slice.begin * byteStride;
byteSize_ = shape()[axis] * byteStride;
}
- static Shape newShape(Expr a, Slice& slice, int& axis) { // note: normalizes slice and axis in-place
+ static Shape newShape(Expr a, int& axis, Slice& slice) { // note: normalizes slice and axis in-place
const auto& shape = a->shape();
axis = shape.axis(axis); // normalize negative axis
slice = shape.slice(slice, axis); // normalize negative slice values