Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/tests/operator_tests.cpp')
-rwxr-xr-xsrc/tests/operator_tests.cpp40
1 files changed, 20 insertions, 20 deletions
diff --git a/src/tests/operator_tests.cpp b/src/tests/operator_tests.cpp
index 09e10a6d..2607f41f 100755
--- a/src/tests/operator_tests.cpp
+++ b/src/tests/operator_tests.cpp
@@ -667,13 +667,13 @@ void tests(DeviceType device) {
std::vector<float> vS3({7, -8, 9, -10, 11, -12});
auto A = graph->param("4x3", {4,3}, inits::from_vector(vA));
- auto B1a = index_select(A, IndexVector({0}), 0); // always uses gather()
- auto B1b = step(A, 0, 0); // memory-consecutive view
- auto B2 = step(A, 0, 1); // not memory-consecutive
- auto B3 = step(A, 1, -1);
- auto B4a = index_select(A, IndexVector({0, 1}), 0);
- auto B4b = slice(A, Slice(0, 2), 0); // this is memory-consecutive
- auto B5 = slice(A, Slice(0, 4), 0); // this is a no-op
+ auto B1a = index_select(A, 0, IndexVector({0})); // always uses gather()
+ auto B1b = slice(A, 0, 0); // memory-consecutive view
+ auto B2 = slice(A, 1, 0); // not memory-consecutive
+ auto B3 = slice(A, -1, 1);
+ auto B4a = index_select(A, 0, IndexVector({0, 1}));
+ auto B4b = slice(A, 0, Slice(0, 2)); // this is memory-consecutive
+ auto B5 = slice(A, 0, Slice(0, 4)); // this is a no-op
CHECK(B1a->type() == "rows"); // actually optimized to rows()
CHECK(B1b->type() == "sliceView"); // must use view
CHECK(B2->type() == "gather"); // cannot use view
@@ -682,21 +682,21 @@ void tests(DeviceType device) {
CHECK(B5.get() == A.get()); // must be no-op
auto C = graph->param("2x3x2", {2, 3, 2}, inits::from_vector(vC));
- auto D1 = step(C, 0, 0);
- auto D2 = step(C, 2, -2);
- auto D3 = index_select(C, IndexVector({0, 2}), 1); // C[:,(0,2),:]
+ auto D1 = slice(C, 0, 0);
+ auto D2 = slice(C, -2, 2);
+ auto D3 = index_select(C, 1, IndexVector({0, 2})); // C[:,(0,2),:]
CHECK(D1->type() == "sliceView");
CHECK(D2->type() == "gather");
// enable this once gather() supports batched indices:
- //auto D4 = gather(C, graph->constant({2, 2, 1}, // [C[0,(2,1),:],C[1,(0,2),:]]
- // inits::from_vector(std::vector<IndexType>{
- // 2, 1,
- // 0, 2 }),
- // Type::uint32), 1);
+ //auto D4 = gather(C, 1, graph->constant({2, 2, 1}, // [C[0,(2,1),:],C[1,(0,2),:]]
+ // inits::from_vector(std::vector<IndexType>{
+ // 2, 1,
+ // 0, 2 }),
+ // Type::uint32));
- auto S1 = step(A, 2, 0);
- auto S2 = narrow(A, 1, 2, 0);
- auto S3 = slice(A, Slice(-2, Slice::END), 0);
+ auto S1 = slice(A, 0, 2);
+ auto S2 = narrow(A, 0, 1, 2);
+ auto S3 = slice(A, 0, Slice(-2, Slice::END));
graph->forward();
@@ -727,9 +727,9 @@ void tests(DeviceType device) {
auto A = graph->param("4x3", {4, 3}, inits::from_vector(vA));
auto B1 = rows(A, indices);
- auto B2 = gather(A, graph->indices(indices, A, 0), 0);
+ auto B2 = gather(A, 0, graph->indices(indices, A, 0));
auto C1 = cols(A, indices);
- auto C2 = gather(A, graph->indices(indices, A, 1), 1);
+ auto C2 = gather(A, 1, graph->indices(indices, A, 1));
graph->forward();
CHECK(B1->shape() == B2->shape());