Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorGuillaume Klein <guillaumekln@users.noreply.github.com>2021-07-30 11:00:06 +0300
committerGitHub <noreply@github.com>2021-07-30 11:00:06 +0300
commit320b19825935f16e4bd2dd3f54a0a53a3897e489 (patch)
tree67b4ac62591e79184d5e5e7567b02cd9188b1cd9 /tests
parentd8e3af32cad26de72f9245090c6b88f8aef20d87 (diff)
Remove time dimension when decoding step by step (#530)
Diffstat (limited to 'tests')
-rw-r--r--tests/storage_view_test.cc22
1 files changed, 22 insertions, 0 deletions
diff --git a/tests/storage_view_test.cc b/tests/storage_view_test.cc
index ba017b9a..96fa2c0a 100644
--- a/tests/storage_view_test.cc
+++ b/tests/storage_view_test.cc
@@ -38,6 +38,28 @@ TEST(StorageViewTest, Reshape) {
assert_vector_eq(a.shape(), Shape{16});
}
+TEST(StorageViewTest, ExpandDimsAndSqueeze) {
+ {
+ StorageView a(Shape{4});
+ a.expand_dims(0);
+ assert_vector_eq(a.shape(), Shape{1, 4});
+ a.expand_dims(-1);
+ assert_vector_eq(a.shape(), Shape{1, 4, 1});
+ a.squeeze(0);
+ assert_vector_eq(a.shape(), Shape{4, 1});
+ a.squeeze(1);
+ assert_vector_eq(a.shape(), Shape{4});
+ }
+
+ {
+ StorageView a(Shape{4, 2});
+ a.expand_dims(1);
+ assert_vector_eq(a.shape(), Shape{4, 1, 2});
+ a.expand_dims(3);
+ assert_vector_eq(a.shape(), Shape{4, 1, 2, 1});
+ }
+}
+
class StorageViewDeviceTest : public ::testing::TestWithParam<Device> {
};