diff options
author | Guillaume Klein <guillaumekln@users.noreply.github.com> | 2021-07-30 11:00:06 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-30 11:00:06 +0300 |
commit | 320b19825935f16e4bd2dd3f54a0a53a3897e489 (patch) | |
tree | 67b4ac62591e79184d5e5e7567b02cd9188b1cd9 /tests | |
parent | d8e3af32cad26de72f9245090c6b88f8aef20d87 (diff) |
Remove time dimension when decoding step by step (#530)
Diffstat (limited to 'tests')
-rw-r--r-- | tests/storage_view_test.cc | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/tests/storage_view_test.cc b/tests/storage_view_test.cc index ba017b9a..96fa2c0a 100644 --- a/tests/storage_view_test.cc +++ b/tests/storage_view_test.cc @@ -38,6 +38,28 @@ TEST(StorageViewTest, Reshape) { assert_vector_eq(a.shape(), Shape{16}); } +TEST(StorageViewTest, ExpandDimsAndSqueeze) { + { + StorageView a(Shape{4}); + a.expand_dims(0); + assert_vector_eq(a.shape(), Shape{1, 4}); + a.expand_dims(-1); + assert_vector_eq(a.shape(), Shape{1, 4, 1}); + a.squeeze(0); + assert_vector_eq(a.shape(), Shape{4, 1}); + a.squeeze(1); + assert_vector_eq(a.shape(), Shape{4}); + } + + { + StorageView a(Shape{4, 2}); + a.expand_dims(1); + assert_vector_eq(a.shape(), Shape{4, 1, 2}); + a.expand_dims(3); + assert_vector_eq(a.shape(), Shape{4, 1, 2, 1}); + } +} + class StorageViewDeviceTest : public ::testing::TestWithParam<Device> { }; |