Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMateusz Chudyk <mateuszchudyk@gmail.com>2019-07-09 21:48:51 +0300
committerMateusz Chudyk <mateuszchudyk@gmail.com>2019-07-09 22:53:05 +0300
commit6de7be50a5ba393512206a41b58634e418492bba (patch)
treed25d690ba85987a2c6fa5a32e444878f77a5a1ce
parent545275446d37c701aa29f6a01978896f0320c401 (diff)
Fix floor_ff kernel for SSE2
-rw-r--r--kernels/implementations.inl17
-rw-r--r--test/kernels/floor_ff_test.cc10
2 files changed, 18 insertions, 9 deletions
diff --git a/kernels/implementations.inl b/kernels/implementations.inl
index 504b650..4a84484 100644
--- a/kernels/implementations.inl
+++ b/kernels/implementations.inl
@@ -152,11 +152,20 @@ CPU_ATTR static inline dvd relu(dvd input) {
/*
* Calculate floor: float -> float
*/
-CPU_ATTR static inline vf floor_ff(vf a) {
-#if defined(THIS_IS_AVX2)
- return _mm256_floor_ps(a);
+CPU_ATTR static inline vf floor_ff(vf input) {
+#if defined(THIS_IS_SSE2)
+ static const auto vconst_zero = setzero_ps<vf>();
+ static const auto vconst_one = set1_ps<vf>(1.f);
+
+ auto result = cvtepi32_ps(cvttps_epi32(input));
+ auto negatives = _mm_cmplt_ps(input, vconst_zero);
+ auto nonintegers = _mm_cmpneq_ps(input, result);
+
+ return sub_ps(result, and_ps(vconst_one, and_ps(negatives, nonintegers)));
+#elif defined(THIS_IS_AVX2)
+ return _mm256_floor_ps(input);
#else
- return cvtepi32_ps(cvttps_epi32(a)); // TODO: Doesn't work for negative numbers
+ assert(false && "AVX512BW is not supported");
#endif
}
diff --git a/test/kernels/floor_ff_test.cc b/test/kernels/floor_ff_test.cc
index 748fb5f..674e35b 100644
--- a/test/kernels/floor_ff_test.cc
+++ b/test/kernels/floor_ff_test.cc
@@ -21,7 +21,7 @@ void kernel_floor_ff_test() {
*output.template as<vec_t>() = kernels::floor_ff(*input.template as<vec_t>());
for (auto i = 0; i < output.size(); ++i)
- CHECK(output[i] == i - 2);
+ CHECK(output[i] == std::floor(input[i]));
}
template INTGEMM_SSE2 void kernel_floor_ff_test<CPUType::SSE2>();
@@ -30,9 +30,9 @@ KERNEL_TEST_CASE("floor_ff SSE2") { return kernel_floor_ff_test<CPUType::SSE2>()
template INTGEMM_AVX2 void kernel_floor_ff_test<CPUType::AVX2>();
KERNEL_TEST_CASE("floor_ff AVX2") { return kernel_floor_ff_test<CPUType::AVX2>(); }
-#ifndef INTGEMM_NO_AVX512
-template INTGEMM_AVX512BW void kernel_floor_ff_test<CPUType::AVX512BW>();
-KERNEL_TEST_CASE("floor_ff AVX512BW") { return kernel_floor_ff_test<CPUType::AVX512BW>(); }
-#endif
+// #ifndef INTGEMM_NO_AVX512
+// template INTGEMM_AVX512BW void kernel_floor_ff_test<CPUType::AVX512BW>();
+// KERNEL_TEST_CASE("floor_ff AVX512BW") { return kernel_floor_ff_test<CPUType::AVX512BW>(); }
+// #endif
}