diff options
Diffstat (limited to 'src/tests/operator_tests.cpp')
-rwxr-xr-x | src/tests/operator_tests.cpp | 40 |
1 files changed, 20 insertions, 20 deletions
diff --git a/src/tests/operator_tests.cpp b/src/tests/operator_tests.cpp index 09e10a6d..2607f41f 100755 --- a/src/tests/operator_tests.cpp +++ b/src/tests/operator_tests.cpp @@ -667,13 +667,13 @@ void tests(DeviceType device) { std::vector<float> vS3({7, -8, 9, -10, 11, -12}); auto A = graph->param("4x3", {4,3}, inits::from_vector(vA)); - auto B1a = index_select(A, IndexVector({0}), 0); // always uses gather() - auto B1b = step(A, 0, 0); // memory-consecutive view - auto B2 = step(A, 0, 1); // not memory-consecutive - auto B3 = step(A, 1, -1); - auto B4a = index_select(A, IndexVector({0, 1}), 0); - auto B4b = slice(A, Slice(0, 2), 0); // this is memory-consecutive - auto B5 = slice(A, Slice(0, 4), 0); // this is a no-op + auto B1a = index_select(A, 0, IndexVector({0})); // always uses gather() + auto B1b = slice(A, 0, 0); // memory-consecutive view + auto B2 = slice(A, 1, 0); // not memory-consecutive + auto B3 = slice(A, -1, 1); + auto B4a = index_select(A, 0, IndexVector({0, 1})); + auto B4b = slice(A, 0, Slice(0, 2)); // this is memory-consecutive + auto B5 = slice(A, 0, Slice(0, 4)); // this is a no-op CHECK(B1a->type() == "rows"); // actually optimized to rows() CHECK(B1b->type() == "sliceView"); // must use view CHECK(B2->type() == "gather"); // cannot use view @@ -682,21 +682,21 @@ void tests(DeviceType device) { CHECK(B5.get() == A.get()); // must be no-op auto C = graph->param("2x3x2", {2, 3, 2}, inits::from_vector(vC)); - auto D1 = step(C, 0, 0); - auto D2 = step(C, 2, -2); - auto D3 = index_select(C, IndexVector({0, 2}), 1); // C[:,(0,2),:] + auto D1 = slice(C, 0, 0); + auto D2 = slice(C, -2, 2); + auto D3 = index_select(C, 1, IndexVector({0, 2})); // C[:,(0,2),:] CHECK(D1->type() == "sliceView"); CHECK(D2->type() == "gather"); // enable this once gather() supports batched indices: - //auto D4 = gather(C, graph->constant({2, 2, 1}, // [C[0,(2,1),:],C[1,(0,2),:]] - // inits::from_vector(std::vector<IndexType>{ - // 2, 1, - // 0, 2 }), - // Type::uint32), 1); + //auto D4 = gather(C, 1, graph->constant({2, 2, 1}, // [C[0,(2,1),:],C[1,(0,2),:]] + // inits::from_vector(std::vector<IndexType>{ + // 2, 1, + // 0, 2 }), + // Type::uint32)); - auto S1 = step(A, 2, 0); - auto S2 = narrow(A, 1, 2, 0); - auto S3 = slice(A, Slice(-2, Slice::END), 0); + auto S1 = slice(A, 0, 2); + auto S2 = narrow(A, 0, 1, 2); + auto S3 = slice(A, 0, Slice(-2, Slice::END)); graph->forward(); @@ -727,9 +727,9 @@ void tests(DeviceType device) { auto A = graph->param("4x3", {4, 3}, inits::from_vector(vA)); auto B1 = rows(A, indices); - auto B2 = gather(A, graph->indices(indices, A, 0), 0); + auto B2 = gather(A, 0, graph->indices(indices, A, 0)); auto C1 = cols(A, indices); - auto C2 = gather(A, graph->indices(indices, A, 1), 1); + auto C2 = gather(A, 1, graph->indices(indices, A, 1)); graph->forward(); CHECK(B1->shape() == B2->shape()); |