Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrank Seide <fseide@microsoft.com>2019-01-23 02:25:53 +0300
committerFrank Seide <fseide@microsoft.com>2019-01-23 02:25:53 +0300
commit49668f1587b2bf08b182d4753a3ff48f76f3403c (patch)
tree44dbbde4409859c9841fb314e8c110853b429a73
parentc1c175f99522da1611c0847c6fc3152d423a24fa (diff)
parent7ae9709043cdcc4f9bf38e9519f06e9eccaf58eb (diff)
Merge branch 'fseide/indexops' into fseide/factoredembeddings
-rwxr-xr-xsrc/common/shape.h2
-rwxr-xr-xsrc/graph/expression_operators.cpp30
-rwxr-xr-xsrc/graph/expression_operators.h25
-rwxr-xr-xsrc/graph/node_operators_binary.h6
-rwxr-xr-xsrc/graph/node_operators_unary.h6
-rwxr-xr-xsrc/models/transformer.h4
-rwxr-xr-xsrc/rnn/rnn.h4
-rwxr-xr-xsrc/tests/operator_tests.cpp40
8 files changed, 59 insertions, 58 deletions
diff --git a/src/common/shape.h b/src/common/shape.h
index 20fb2607..c8a4bdd3 100755
--- a/src/common/shape.h
+++ b/src/common/shape.h
@@ -17,7 +17,7 @@ struct Slice // Python-like slice/index descriptor
Slice(int b, int e, int s) : begin(b), end(e), stride(s) {}
Slice(int b, int e) : Slice(b, e, 1) {}
Slice() : Slice(0, END) {}
- Slice(int i) : Slice(i, i + 1) {}
+ explicit Slice(int i) : Slice(i, i + 1) {}
Slice(const Slice& other) : Slice(other.begin, other.end, other.stride) {}
const Slice& operator=(const Slice& other) { begin = other.begin; end = other.end; stride = other.stride; return *this; }
const Slice& operator=(int i) { begin = i; end = i + 1; stride = 1; return *this; }
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index 41c86f74..6a07611d 100755
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -246,17 +246,17 @@ Expr constant_like(Expr a, const NodeInitializer& init) {
}
// gather() -- gather arbitrary elements along an axis; batched or non-batched
-Expr gather(Expr a, Expr indices, int axis) {
- return Expression<GatherNodeOp>(a, indices, axis);
+Expr gather(Expr a, int axis, Expr indices) {
+ return Expression<GatherNodeOp>(a, axis, indices);
}
// index_select() -- gather arbitrary elements along an axis; unbatched (indices are specified as a 1D vector)
-Expr index_select(Expr a, Expr indices, int axis) {
+Expr index_select(Expr a, int axis, Expr indices) {
ABORT_IF(indices->shape().size() != 1, "Indices must be a 1D tensor");
// We have specialized kernels for non-batched indexing of first or last axis of a 2D tensor.
auto rank = a->shape().size();
if (rank == 2) {
- if (axis == 0)
+ if (axis == 0 || axis == -2)
return Expression<RowsNodeOp>(a, indices);
else if (axis == -1 || axis == 1)
return Expression<ColsNodeOp>(a, indices);
@@ -266,29 +266,29 @@ Expr index_select(Expr a, Expr indices, int axis) {
shape.resize(a->shape().size());
shape.set(axis, indices->shape()[0]);
indices = reshape(indices, shape); // move index to axis
- return gather(a, indices, axis);
+ return gather(a, axis, indices);
}
-Expr index_select(Expr a, const std::vector<IndexType>& indices, int axis) {
+Expr index_select(Expr a, int axis, const std::vector<IndexType>& indices) {
auto indexExpr = a->graph()->indices(indices);
- return index_select(a, indexExpr, axis);
+ return index_select(a, axis, indexExpr);
}
-static Expr sliceCopy(Expr a, const Slice& slice, int axis) { // copy a Slice via gather()
+static Expr sliceCopy(Expr a, int axis, const Slice& slice) { // copy a Slice via gather()
ABORT_IF(slice.stride < 0, "Negative strides are not supported yet");
ABORT_IF(slice.begin == slice.end, "Empty slices are not allowed"); // @TODO: Or are they?
std::vector<IndexType> indices;
indices.reserve((slice.end - slice.begin - 1) / slice.stride + 1);
for (int i = slice.begin; i < slice.end; i += slice.stride)
indices.push_back((IndexType)i);
- return gather(a, a->graph()->indices(indices, a, axis), axis);
+ return gather(a, axis, a->graph()->indices(indices, a, axis));
}
-static Expr sliceView(Expr a, const Slice& slice, int axis) { // view a slice (must be memory-consecutive)
- return Expression<SliceViewNodeOp>(a, slice, axis);
+static Expr sliceView(Expr a, int axis, const Slice& slice) { // view a slice (must be memory-consecutive)
+ return Expression<SliceViewNodeOp>(a, axis, slice);
}
// slice() -- gather a slice along an axis (step size > 1 allowed)
-Expr slice(Expr a, Slice slice, int axis) { // numpy __getslice__ semantics, but with axis parameter
+Expr slice(Expr a, int axis, Slice slice) { // numpy __getslice__ semantics, but with axis parameter
const auto& shape = a->shape();
axis = shape.axis(axis); // normalize negative axis
slice = shape.slice(slice, axis); // normalize negative slice values
@@ -296,13 +296,13 @@ Expr slice(Expr a, Slice slice, int axis) { // numpy __getslice__ semantics, but
return a; // it's a no-op
#if 1 // until strided views are supported, non-consecutive slices are implemented via gather()
if (slice.stride != 1)
- return sliceCopy(a, slice, axis);
+ return sliceCopy(a, axis, slice);
for (int i = 0; i < axis; ++i) {
if (shape[i] != 1) // this makes it non-consecutive
- return sliceCopy(a, slice, axis);
+ return sliceCopy(a, axis, slice);
}
#endif
- return sliceView(a, slice, axis);
+ return sliceView(a, axis, slice);
}
Expr sum(Expr a, int ax) {
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index e8c7d646..78aed834 100755
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -139,34 +139,35 @@ Expr flatten_2d(Expr a);
Expr stopGradient(Expr a);
-Expr gather(Expr a, Expr indices, int axis);
+Expr gather(Expr a, int axis, Expr indices);
-Expr index_select(Expr a, Expr indices, int axis);
+// Warning: Don't try to pass a scalar literal 0 as indices; it will compile but pass nullptr...
+Expr index_select(Expr a, int axis, Expr indices);
// convenience wrappers for index_select()
-Expr index_select(Expr a, const std::vector<IndexType>& indices, int axis);
+Expr index_select(Expr a, int axis, const std::vector<IndexType>& indices);
static inline Expr rows(Expr a, Expr indices) {
- return index_select(a, indices, 0);
+ return index_select(a, 0, indices);
}
static inline Expr rows(Expr a, const std::vector<IndexType>& indexVector) {
- return index_select(a, indexVector, 0);
+ return index_select(a, 0, indexVector);
}
static inline Expr cols(Expr a, Expr indices) {
- return index_select(a, indices, -1);
+ return index_select(a, -1, indices);
}
static inline Expr cols(Expr a, const std::vector<IndexType>& indexVector) {
- return index_select(a, indexVector, -1);
+ return index_select(a, -1, indexVector);
}
-Expr slice(Expr a, Slice slice, int axis);
+Expr slice(Expr a, int axis, Slice slice);
// convenience wrappers for slice()
-static inline Expr step(Expr a, int step, int axis) { // @TODO: name is too narrow
- return slice(a, Slice(step), axis);
+static inline Expr slice(Expr a, int axis, int index) { // single index @NOTE: This was formerly called step()
+ return slice(a, axis, Slice(index));
}
-static inline Expr narrow(Expr a, size_t start, size_t length, int axis) { // PyTorch name
- return slice(a, Slice((int)start, (int)(start + length)), axis);
+static inline Expr narrow(Expr a, int axis, size_t start, size_t length) { // PyTorch name
+ return slice(a, axis, Slice((int)start, (int)(start + length)));
}
/*********************************************************/
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index b61ff730..7d090823 100755
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -576,8 +576,8 @@ struct RowsNodeOp : public NaryNodeOp {
// @TODO: The current implementation does not support batched indices (third scenario above).
// I.e. all axes of 'indices' except 'axis' must have dimension 1.
struct GatherNodeOp : public NaryNodeOp {
- GatherNodeOp(Expr a, Expr indices, int axis)
- : NaryNodeOp({a, indices}, newShape(a, indices, axis), a->value_type()),
+ GatherNodeOp(Expr a, int axis, Expr indices)
+ : NaryNodeOp({a, indices}, newShape(a, axis, indices), a->value_type()),
axis_(a->shape().axis(axis)) {
matchOrAbort<IndexType>(indices->value_type());
}
@@ -592,7 +592,7 @@ struct GatherNodeOp : public NaryNodeOp {
Insert(child(0)->grad(), adj_, child(1)->val(), axis_))};
}
- Shape newShape(Expr a, Expr indices, int axis) {
+ Shape newShape(Expr a, int axis, Expr indices) {
Shape shape = a->shape();
axis = shape.axis(axis);
auto rank = shape.size();
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 3530d2bf..6dd90faf 100755
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -761,15 +761,15 @@ private:
size_t byteOffset_, byteSize_; // viewed segment in bytes (memory-consecutive)
public:
- SliceViewNodeOp(Expr a, Slice slice, int axis)
- : UnaryNodeOp(a, newShape(a, slice, axis), a->value_type()), viewedNode_(a), slice_(slice), axis_(axis) {
+ SliceViewNodeOp(Expr a, int axis, Slice slice)
+ : UnaryNodeOp(a, newShape(a, axis, slice), a->value_type()), viewedNode_(a), slice_(slice), axis_(axis) {
Node::destroy_ = false;
auto byteStride = a->shape().stride(axis) * sizeOf(value_type());
byteOffset_ = slice.begin * byteStride;
byteSize_ = shape()[axis] * byteStride;
}
- static Shape newShape(Expr a, Slice& slice, int& axis) { // note: normalizes slice and axis in-place
+ static Shape newShape(Expr a, int& axis, Slice& slice) { // note: normalizes slice and axis in-place
const auto& shape = a->shape();
axis = shape.axis(axis); // normalize negative axis
slice = shape.slice(slice, axis); // normalize negative slice values
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 6a7db643..f6aa4ff1 100755
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -177,7 +177,7 @@ public:
void collectOneHead(Expr weights, int dimBeam) {
// select first head, this is arbitrary as the choice does not really matter
- auto head0 = index_select(weights, 0, -3);
+ auto head0 = slice(weights, -3, 0);
int dimBatchBeam = head0->shape()[-4];
int srcWords = head0->shape()[-1];
@@ -193,7 +193,7 @@ public:
// @TODO: make splitting obsolete
alignments_.clear();
for(int i = 0; i < trgWords; ++i) {
- alignments_.push_back(marian::step(head0, i, -1)); // [tgt index][-4: beam depth, -3: max src length, -2: batch size, -1: 1]
+ alignments_.push_back(slice(head0, -1, i)); // [tgt index][-4: beam depth, -3: max src length, -2: batch size, -1: 1]
}
}
diff --git a/src/rnn/rnn.h b/src/rnn/rnn.h
index 74a535d3..c9131d87 100755
--- a/src/rnn/rnn.h
+++ b/src/rnn/rnn.h
@@ -75,11 +75,11 @@ private:
std::vector<Expr> steps(xWs.size());
std::transform(xWs.begin(), xWs.end(), steps.begin(), [j](Expr e) {
- return step(e, j, -3);
+ return slice(e, -3, j);
});
if(mask)
- state = cell_->applyState(steps, state, step(mask, j, -3));
+ state = cell_->applyState(steps, state, slice(mask, -3, j));
else
state = cell_->applyState(steps, state);
diff --git a/src/tests/operator_tests.cpp b/src/tests/operator_tests.cpp
index 09e10a6d..2607f41f 100755
--- a/src/tests/operator_tests.cpp
+++ b/src/tests/operator_tests.cpp
@@ -667,13 +667,13 @@ void tests(DeviceType device) {
std::vector<float> vS3({7, -8, 9, -10, 11, -12});
auto A = graph->param("4x3", {4,3}, inits::from_vector(vA));
- auto B1a = index_select(A, IndexVector({0}), 0); // always uses gather()
- auto B1b = step(A, 0, 0); // memory-consecutive view
- auto B2 = step(A, 0, 1); // not memory-consecutive
- auto B3 = step(A, 1, -1);
- auto B4a = index_select(A, IndexVector({0, 1}), 0);
- auto B4b = slice(A, Slice(0, 2), 0); // this is memory-consecutive
- auto B5 = slice(A, Slice(0, 4), 0); // this is a no-op
+ auto B1a = index_select(A, 0, IndexVector({0})); // always uses gather()
+ auto B1b = slice(A, 0, 0); // memory-consecutive view
+ auto B2 = slice(A, 1, 0); // not memory-consecutive
+ auto B3 = slice(A, -1, 1);
+ auto B4a = index_select(A, 0, IndexVector({0, 1}));
+ auto B4b = slice(A, 0, Slice(0, 2)); // this is memory-consecutive
+ auto B5 = slice(A, 0, Slice(0, 4)); // this is a no-op
CHECK(B1a->type() == "rows"); // actually optimized to rows()
CHECK(B1b->type() == "sliceView"); // must use view
CHECK(B2->type() == "gather"); // cannot use view
@@ -682,21 +682,21 @@ void tests(DeviceType device) {
CHECK(B5.get() == A.get()); // must be no-op
auto C = graph->param("2x3x2", {2, 3, 2}, inits::from_vector(vC));
- auto D1 = step(C, 0, 0);
- auto D2 = step(C, 2, -2);
- auto D3 = index_select(C, IndexVector({0, 2}), 1); // C[:,(0,2),:]
+ auto D1 = slice(C, 0, 0);
+ auto D2 = slice(C, -2, 2);
+ auto D3 = index_select(C, 1, IndexVector({0, 2})); // C[:,(0,2),:]
CHECK(D1->type() == "sliceView");
CHECK(D2->type() == "gather");
// enable this once gather() supports batched indices:
- //auto D4 = gather(C, graph->constant({2, 2, 1}, // [C[0,(2,1),:],C[1,(0,2),:]]
- // inits::from_vector(std::vector<IndexType>{
- // 2, 1,
- // 0, 2 }),
- // Type::uint32), 1);
+ //auto D4 = gather(C, 1, graph->constant({2, 2, 1}, // [C[0,(2,1),:],C[1,(0,2),:]]
+ // inits::from_vector(std::vector<IndexType>{
+ // 2, 1,
+ // 0, 2 }),
+ // Type::uint32));
- auto S1 = step(A, 2, 0);
- auto S2 = narrow(A, 1, 2, 0);
- auto S3 = slice(A, Slice(-2, Slice::END), 0);
+ auto S1 = slice(A, 0, 2);
+ auto S2 = narrow(A, 0, 1, 2);
+ auto S3 = slice(A, 0, Slice(-2, Slice::END));
graph->forward();
@@ -727,9 +727,9 @@ void tests(DeviceType device) {
auto A = graph->param("4x3", {4, 3}, inits::from_vector(vA));
auto B1 = rows(A, indices);
- auto B2 = gather(A, graph->indices(indices, A, 0), 0);
+ auto B2 = gather(A, 0, graph->indices(indices, A, 0));
auto C1 = cols(A, indices);
- auto C2 = gather(A, graph->indices(indices, A, 1), 1);
+ auto C2 = gather(A, 1, graph->indices(indices, A, 1));
graph->forward();
CHECK(B1->shape() == B2->shape());