Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorGuillaume Klein <guillaumekln@users.noreply.github.com>2021-10-15 15:21:56 +0300
committerGitHub <noreply@github.com>2021-10-15 15:21:56 +0300
commiteec100d5113a497bd61281a9698ceec4af2c41b8 (patch)
tree34b58569366709f744b795e44ad4142aeb43fa61 /tests
parentba8aefdb54b6fb2d88dac321f3b109303d4654fd (diff)
Apply LogSoftMax in-place during decoding (#584)
Diffstat (limited to 'tests')
-rw-r--r--tests/ops_test.cc16
1 files changed, 10 insertions, 6 deletions
diff --git a/tests/ops_test.cc b/tests/ops_test.cc
index c621d287..943c3166 100644
--- a/tests/ops_test.cc
+++ b/tests/ops_test.cc
@@ -541,29 +541,33 @@ TEST_P(OpDeviceTest, TopKChangeK) {
TEST_P(OpDeviceFPTest, SoftMax) {
const Device device = GetParam().first;
const DataType dtype = GetParam().second;
- StorageView x({2, 5}, std::vector<float>{
+ StorageView x = StorageView({2, 5}, std::vector<float>{
-0.2, 3.0, 1.2, -1.1, 0.0,
- 4.6, 3.3, 0.2, -1.6, 1.0}, device);
+ 4.6, 3.3, 0.2, -1.6, 1.0}, device).to(dtype);
StorageView expected({2, 5}, std::vector<float>{
0.032035, 0.785904, 0.129909, 0.013025, 0.039128,
0.760941, 0.207381, 0.009342, 0.001544, 0.020792}, device);
StorageView y(dtype, device);
- ops::SoftMax()(x.to(dtype), y);
+ ops::SoftMax()(x, y);
expect_storage_eq(y.to_float(), expected, 1e-3);
+ ops::SoftMax()(x);
+ expect_storage_eq(x.to_float(), expected, 1e-3);
}
TEST_P(OpDeviceFPTest, LogSoftMax) {
const Device device = GetParam().first;
const DataType dtype = GetParam().second;
- StorageView x({2, 10}, std::vector<float>{
+ StorageView x = StorageView({2, 10}, std::vector<float>{
-0.2, 3.0, 1.2, -1.1, 0.0, 0.2, -3.0, -1.2, 1.1, 0.0,
- 4.6, 3.3, 0.2, -1.6, 1.0, -4.6, -3.3, -0.2, 1.6, -1.0}, device);
+ 4.6, 3.3, 0.2, -1.6, 1.0, -4.6, -3.3, -0.2, 1.6, -1.0}, device).to(dtype);
StorageView expected({2, 10}, std::vector<float>{
-3.638294, -0.438294, -2.238294, -4.538294, -3.438294, -3.238294, -6.438294, -4.638294, -2.338294, -3.438294,
-0.319434, -1.619434, -4.719434, -6.519434, -3.919434, -9.519434, -8.219434, -5.119434, -3.319434, -5.919434}, device);
StorageView y(dtype, device);
- ops::LogSoftMax()(x.to(dtype), y);
+ ops::LogSoftMax()(x, y);
expect_storage_eq(y.to_float(), expected, 1e-2);
+ ops::LogSoftMax()(x);
+ expect_storage_eq(x.to_float(), expected, 1e-2);
}
TEST_P(OpDeviceFPTest, MaskedSoftMax) {