Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/tests
diff options
context:
space:
mode:
author: Guillaume Klein <guillaumekln@users.noreply.github.com> 2021-11-19 17:20:12 +0300
committer: GitHub <noreply@github.com> 2021-11-19 17:20:12 +0300
commit: f57d471e246e06fc3671747007ad61fd1adb5277 (patch)
tree: ffa6994962dee694a1ea02885f411eff48f1f44b /tests
parent: d8be651775d67bc155513777efaaa8b1c755afb6 (diff)
Add a CUDA kernel for Multinomial op (#631)
* Add a CUDA kernel for Multinomial op

  The current implementation can only return a single sample per batch.

* Install libcurand-devel to build the Python wheels
Diffstat (limited to 'tests')
-rw-r--r-- tests/ops_test.cc 20
1 files changed, 16 insertions, 4 deletions
diff --git a/tests/ops_test.cc b/tests/ops_test.cc
index dbc2c64c..73772455 100644
--- a/tests/ops_test.cc
+++ b/tests/ops_test.cc
@@ -693,11 +693,23 @@ TEST_P(OpDeviceTest, QuantizeINT8ZeroRow) {
TEST_P(OpDeviceFPTest, Multinomial) {
const Device device = GetParam().first;
const DataType dtype = GetParam().second;
- StorageView input({2, 4}, std::vector<float>{0, 0, 1, 0, 0, 0, 0, 1}, device);
+ StorageView input({2, 4}, std::vector<float>{0.2, 0.1, 0.6, 0.1, 0.7, 0.2, 0.0, 0.1}, device);
StorageView output(DataType::INT32, device);
- StorageView expected({2, 2}, std::vector<int32_t>{2, 2, 3, 3}, device);
- ops::Multinomial(2)(input.to(dtype), output);
- expect_storage_eq(output, expected);
+ StorageView counts(input.shape(), int32_t(0));
+
+ constexpr dim_t num_draws = 5000;
+ for (dim_t i = 0; i < num_draws; ++i) {
+ ops::Multinomial(1)(input.to(dtype), output);
+ for (dim_t b = 0; b < output.dim(0); ++b)
+ counts.at<int32_t>({b, output.scalar_at<int32_t>({b, 0})}) += 1;
+ }
+
+ std::vector<int32_t> counts_vec = counts.to_vector<int32_t>();
+ std::vector<float> frequencies(counts_vec.begin(), counts_vec.end());
+ for (auto& frequency : frequencies)
+ frequency /= num_draws;
+
+ expect_storage_eq(StorageView(input.shape(), frequencies), input, 0.05);
}
TEST_P(OpDeviceFPTest, ReLU) {