diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-04-25 01:38:18 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-04-25 01:39:24 +0300 |
commit | 6f920cb44b425c927c6a61b0b69c740c8ea1643f (patch) | |
tree | 3cd629932295afa2031549a3a0882f65faf10bde | |
parent | a44dd8607f81e652dc8310a80aa70f79bcdc97f2 (diff) |
Create generic CallbackRowMajorAccess class
-rw-r--r-- | tile/access.h | 28 | ||||
-rw-r--r-- | tile/multiply.h | 1 | ||||
-rw-r--r-- | tile/multiply.inl | 3 |
3 files changed, 24 insertions, 8 deletions
diff --git a/tile/access.h b/tile/access.h index 11f67ec..415eaa0 100644 --- a/tile/access.h +++ b/tile/access.h @@ -3,20 +3,21 @@ #include <type_traits> #include "../types.h" +#include "callbacks.h" namespace intgemm { // See also: RegisterRowMajorAccess is RowMajorAccess<Register> but without the // compiler warning. That is defined in dot.h. -template <class T> class RowMajorAccess { +template <class T, class Callback> class CallbackRowMajorAccess { public: typedef T Content; - RowMajorAccess(Content *data, Index cols) - : data_(data), cols_(cols) {} + CallbackRowMajorAccess(Content *data, Index cols, const typename Callback::Config& callback_config = {}) + : data_(data), cols_(cols), callback_config_(callback_config) {} - RowMajorAccess<Content> Add(Index row, Index col) const { - return RowMajorAccess<Content>(data_ + row * cols_ + col, cols_); + CallbackRowMajorAccess<Content, Callback> Add(Index row, Index col) const { + return CallbackRowMajorAccess<Content, Callback>(data_ + row * cols_ + col, cols_, callback_config_); } Index Cols() const { return cols_; } @@ -24,11 +25,28 @@ template <class T> class RowMajorAccess { const Content &Front() const { return *data_; } Content &Front() { return *data_; } + template <Index A_rows, Index B_cols> + void Write(const __m128i *from) { + Callback::template Run<A_rows, B_cols>(*this, from, callback_config_); + } + template <Index A_rows, Index B_cols> + void Write(const __m256i *from) { + Callback::template Run<A_rows, B_cols>(*this, from, callback_config_); + } + template <Index A_rows, Index B_cols> + void Write(const __m512i *from) { + Callback::template Run<A_rows, B_cols>(*this, from, callback_config_); + } + private: Content *data_; Index cols_; + typename Callback::Config callback_config_; }; +template <class T> +using RowMajorAccess = CallbackRowMajorAccess<T, WriteCallback>; + template <class T> class ColMajorAccess { public: typedef T Content; diff --git a/tile/multiply.h b/tile/multiply.h index 2c0ca46..294b3ed 100644 --- a/tile/multiply.h +++ b/tile/multiply.h @@ -1,7 +1,6 @@ #pragma once #include "access.h" -#include "callbacks.h" #include "dot.h" #include "reduce.h" #include "../types.h" diff --git a/tile/multiply.inl b/tile/multiply.inl index 4d208d9..a1a92cf 100644 --- a/tile/multiply.inl +++ b/tile/multiply.inl @@ -51,8 +51,7 @@ template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) s Sum16To32(c_regs, typename Kernel::Packed::C(), make_index_sequence<Outputs>()); // Horizontally add 32-bit values. Reduce32<Outputs, Sum32Op>(c_regs); - // col_row.CAccessor().template Write<Kernel::kTile.A_rows, Kernel::kTile.B_cols>(c_regs); - WriteCallback::template Run<Kernel::kTile.A_rows, Kernel::kTile.B_cols>(col_row.CAccessor(), c_regs); + col_row.CAccessor().template Write<Kernel::kTile.A_rows, Kernel::kTile.B_cols>(c_regs); } } } |