Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Mateusz Chudyk <mateuszchudyk@gmail.com> 2020-05-08 18:14:56 +0300
committer: Mateusz Chudyk <mateuszchudyk@gmail.com> 2020-05-08 18:14:56 +0300
commit: f9c11df2cef68e333c0a853acea8c7e8bab37d0c (patch)
tree: 70b20a4ad73ab576651640532c70e96c9d516bd2
parent: 81bfb8b293a3773653ea0f2fdeb372ef583501d7 (diff)
Add support for writing floats in Access::Write
-rw-r--r-- tile/access.h | 6
1 files changed, 4 insertions, 2 deletions
diff --git a/tile/access.h b/tile/access.h
index 5faf9a1..f8ae709 100644
--- a/tile/access.h
+++ b/tile/access.h
@@ -77,7 +77,8 @@ template <class T> class RowMajorAccess {
// The offsets add B_cols - ColRemain so they can be correct modulo the number of columns.
// So we subtract that from the data pointer.
int32_t *go_back = data_ - (B_cols - ColRemain);
- _mm512_mask_i32scatter_epi32(go_back, mask, offsets, callback_impl(*from, callbacks::OutputBufferInfo(row_idx_, col_idx_, 0, 0)), sizeof(int32_t));
+ auto result = reinterpret_cast<__m512i>(callback_impl(*from, callbacks::OutputBufferInfo(row_idx_, col_idx_, 0, 0)));
+ _mm512_mask_i32scatter_epi32(go_back, mask, offsets, result, sizeof(int32_t));
// We just wrote 16 values: ColRemain, the next row (all or partial), possibly the next etc.
// 16 - ColRemain of the next row and whatever followed.
constexpr Index Wrote = ((remaining < 16) ? remaining : 16);
@@ -97,7 +98,8 @@ template <class T> class RowMajorAccess {
template <Index A_rows, Index B_cols, Index ColRemain, typename CallbackImpl> INTGEMM_AVX512BW
typename std::enable_if<(A_rows == 1) && B_cols && (ColRemain < 16 && ColRemain > 0)>::type
WriteImpl(const __m512i *from, CallbackImpl& callback_impl) {
- _mm512_mask_storeu_epi32(data_, (1 << ColRemain) - 1, callback_impl(*from, callbacks::OutputBufferInfo(row_idx_, col_idx_, 0, 0)));
+ auto result = reinterpret_cast<__m512i>(callback_impl(*from, callbacks::OutputBufferInfo(row_idx_, col_idx_, 0, 0)));
+ _mm512_mask_storeu_epi32(data_, (1 << ColRemain) - 1, result);
}
// Nothing to write.