Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Mateusz Chudyk <mateuszchudyk@gmail.com> 2020-04-25 01:38:18 +0300
committer: Mateusz Chudyk <mateuszchudyk@gmail.com> 2020-04-25 01:39:24 +0300
commit: 6f920cb44b425c927c6a61b0b69c740c8ea1643f (patch)
tree: 3cd629932295afa2031549a3a0882f65faf10bde
parent: a44dd8607f81e652dc8310a80aa70f79bcdc97f2 (diff)
Create generic CallbackRowMajorAccess class
-rw-r--r-- tile/access.h 28
-rw-r--r-- tile/multiply.h 1
-rw-r--r-- tile/multiply.inl 3
3 files changed, 24 insertions, 8 deletions
diff --git a/tile/access.h b/tile/access.h
index 11f67ec..415eaa0 100644
--- a/tile/access.h
+++ b/tile/access.h
@@ -3,20 +3,21 @@
#include <type_traits>
#include "../types.h"
+#include "callbacks.h"
namespace intgemm {
// See also: RegisterRowMajorAccess is RowMajorAccess<Register> but without the
// compiler warning. That is defined in dot.h.
-template <class T> class RowMajorAccess {
+template <class T, class Callback> class CallbackRowMajorAccess {
public:
typedef T Content;
- RowMajorAccess(Content *data, Index cols)
- : data_(data), cols_(cols) {}
+ CallbackRowMajorAccess(Content *data, Index cols, const typename Callback::Config& callback_config = {})
+ : data_(data), cols_(cols), callback_config_(callback_config) {}
- RowMajorAccess<Content> Add(Index row, Index col) const {
- return RowMajorAccess<Content>(data_ + row * cols_ + col, cols_);
+ CallbackRowMajorAccess<Content, Callback> Add(Index row, Index col) const {
+ return CallbackRowMajorAccess<Content, Callback>(data_ + row * cols_ + col, cols_, callback_config_);
}
Index Cols() const { return cols_; }
@@ -24,11 +25,28 @@ template <class T> class RowMajorAccess {
const Content &Front() const { return *data_; }
Content &Front() { return *data_; }
+ template <Index A_rows, Index B_cols>
+ void Write(const __m128i *from) {
+ Callback::template Run<A_rows, B_cols>(*this, from, callback_config_);
+ }
+ template <Index A_rows, Index B_cols>
+ void Write(const __m256i *from) {
+ Callback::template Run<A_rows, B_cols>(*this, from, callback_config_);
+ }
+ template <Index A_rows, Index B_cols>
+ void Write(const __m512i *from) {
+ Callback::template Run<A_rows, B_cols>(*this, from, callback_config_);
+ }
+
private:
Content *data_;
Index cols_;
+ typename Callback::Config callback_config_;
};
+template <class T>
+using RowMajorAccess = CallbackRowMajorAccess<T, WriteCallback>;
+
template <class T> class ColMajorAccess {
public:
typedef T Content;
diff --git a/tile/multiply.h b/tile/multiply.h
index 2c0ca46..294b3ed 100644
--- a/tile/multiply.h
+++ b/tile/multiply.h
@@ -1,7 +1,6 @@
#pragma once
#include "access.h"
-#include "callbacks.h"
#include "dot.h"
#include "reduce.h"
#include "../types.h"
diff --git a/tile/multiply.inl b/tile/multiply.inl
index 4d208d9..a1a92cf 100644
--- a/tile/multiply.inl
+++ b/tile/multiply.inl
@@ -51,8 +51,7 @@ template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) s
Sum16To32(c_regs, typename Kernel::Packed::C(), make_index_sequence<Outputs>());
// Horizontally add 32-bit values.
Reduce32<Outputs, Sum32Op>(c_regs);
- // col_row.CAccessor().template Write<Kernel::kTile.A_rows, Kernel::kTile.B_cols>(c_regs);
- WriteCallback::template Run<Kernel::kTile.A_rows, Kernel::kTile.B_cols>(col_row.CAccessor(), c_regs);
+ col_row.CAccessor().template Write<Kernel::kTile.A_rows, Kernel::kTile.B_cols>(c_regs);
}
}
}