diff options
author | Guillaume Klein <guillaumekln@users.noreply.github.com> | 2021-10-15 15:21:56 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-15 15:21:56 +0300 |
commit | eec100d5113a497bd61281a9698ceec4af2c41b8 (patch) | |
tree | 34b58569366709f744b795e44ad4142aeb43fa61 /tests | |
parent | ba8aefdb54b6fb2d88dac321f3b109303d4654fd (diff) |
Apply LogSoftMax in-place during decoding (#584)
Diffstat (limited to 'tests')
-rw-r--r-- | tests/ops_test.cc | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/tests/ops_test.cc b/tests/ops_test.cc index c621d287..943c3166 100644 --- a/tests/ops_test.cc +++ b/tests/ops_test.cc @@ -541,29 +541,33 @@ TEST_P(OpDeviceTest, TopKChangeK) { TEST_P(OpDeviceFPTest, SoftMax) { const Device device = GetParam().first; const DataType dtype = GetParam().second; - StorageView x({2, 5}, std::vector<float>{ + StorageView x = StorageView({2, 5}, std::vector<float>{ -0.2, 3.0, 1.2, -1.1, 0.0, - 4.6, 3.3, 0.2, -1.6, 1.0}, device); + 4.6, 3.3, 0.2, -1.6, 1.0}, device).to(dtype); StorageView expected({2, 5}, std::vector<float>{ 0.032035, 0.785904, 0.129909, 0.013025, 0.039128, 0.760941, 0.207381, 0.009342, 0.001544, 0.020792}, device); StorageView y(dtype, device); - ops::SoftMax()(x.to(dtype), y); + ops::SoftMax()(x, y); expect_storage_eq(y.to_float(), expected, 1e-3); + ops::SoftMax()(x); + expect_storage_eq(x.to_float(), expected, 1e-3); } TEST_P(OpDeviceFPTest, LogSoftMax) { const Device device = GetParam().first; const DataType dtype = GetParam().second; - StorageView x({2, 10}, std::vector<float>{ + StorageView x = StorageView({2, 10}, std::vector<float>{ -0.2, 3.0, 1.2, -1.1, 0.0, 0.2, -3.0, -1.2, 1.1, 0.0, - 4.6, 3.3, 0.2, -1.6, 1.0, -4.6, -3.3, -0.2, 1.6, -1.0}, device); + 4.6, 3.3, 0.2, -1.6, 1.0, -4.6, -3.3, -0.2, 1.6, -1.0}, device).to(dtype); StorageView expected({2, 10}, std::vector<float>{ -3.638294, -0.438294, -2.238294, -4.538294, -3.438294, -3.238294, -6.438294, -4.638294, -2.338294, -3.438294, -0.319434, -1.619434, -4.719434, -6.519434, -3.919434, -9.519434, -8.219434, -5.119434, -3.319434, -5.919434}, device); StorageView y(dtype, device); - ops::LogSoftMax()(x.to(dtype), y); + ops::LogSoftMax()(x, y); expect_storage_eq(y.to_float(), expected, 1e-2); + ops::LogSoftMax()(x); + expect_storage_eq(x.to_float(), expected, 1e-2); } TEST_P(OpDeviceFPTest, MaskedSoftMax) { |