diff options
Diffstat (limited to 'tile/dot.inl')
-rw-r--r-- | tile/dot.inl | 48 |
1 files changed, 19 insertions, 29 deletions
diff --git a/tile/dot.inl b/tile/dot.inl index adbac37..9af7b1b 100644 --- a/tile/dot.inl +++ b/tile/dot.inl @@ -132,41 +132,31 @@ struct Signed16 { }; }; -/* These would normally be outside arch namespaces but gcc refuses to inline - * functions with target attributes into functions without target attributes. - * So they all need target attributes. */ -template <class Backend> struct MatrixTileBody { - template <typename Iterator, typename Access> INTGEMM_TARGET static inline void body(Access access) { - Backend::Run(access - .AAdd(Iterator::template I<0>() * Backend::kTile.A_rows, 0) - .BAdd(0, Iterator::template I<1>() * Backend::kTile.B_cols) - .CAdd(Iterator::template I<0>() * Backend::kTile.A_rows, Iterator::template I<1>() * Backend::kTile.B_cols)); - } -}; -// Wrap a tile to statically loop over rows of A and columns of B. -template <Index A_rows, Index B_cols, class Backend> struct MatrixTile { +// Unroll an arbitrary amount of +// Can't have Tile as a value until C++20. +template <Index A_rows, Index inner, Index B_cols, class Backend> struct UnrollKernel { template <class Access> INTGEMM_TARGET __attribute__((flatten)) static inline void Run(Access access) { - StaticLoop<MatrixTileBody<Backend>, MakeStaticLoopIterator<A_rows, B_cols>>(access); + body(access, make_index_sequence<A_rows * inner * B_cols>()); } - static constexpr Tile kTile { A_rows * Backend::kTile.A_rows, Backend::kTile.inner, B_cols * Backend::kTile.B_cols }; + static constexpr Tile kTile { A_rows * Backend::kTile.A_rows, inner * Backend::kTile.inner, B_cols * Backend::kTile.B_cols }; typedef typename Backend::Packed Packed; -}; -template <class Backend> struct InnerTileBody { - template <typename Iterator, typename Access> INTGEMM_TARGET static inline void body(Access access) { - Backend::Run(access - .AAdd(0, Iterator::template I<0>() * Backend::kTile.inner) - .BAdd(Iterator::template I<0>() * Backend::kTile.inner, 0)); + private: + template <class Access, size_t... Index> INTGEMM_TARGET __attribute__((flatten)) static inline void body( + Access access, + index_sequence<Index...>) { + // for each inner computed as (Index / A_rows / B_cols) + // for each A_row computed as (Index % (A_rows * B_cols)) / B_cols + // for each B_col computed as (Index % B_cols) + unordered_unfurl(( + Backend::Run(access + .AAdd((Index % (A_rows * B_cols)) / B_cols * Backend::kTile.A_rows, (Index / A_rows / B_cols) * Backend::kTile.inner) + .BAdd((Index / A_rows / B_cols) * Backend::kTile.inner, (Index % B_cols) * Backend::kTile.B_cols) + .CAdd((Index % (A_rows * B_cols)) / B_cols * Backend::kTile.A_rows, (Index % B_cols) * Backend::kTile.B_cols)) + // Backend returns void, so use a tuple to make 0. + , 0)...); } }; -// Wrap a tile to statically loop over the inner dimension. -template <Index inner, class Backend> struct InnerTile { - template <class Access> INTGEMM_TARGET __attribute__((flatten)) static inline void Run(Access access) { - StaticLoop<InnerTileBody<Backend>, MakeStaticLoopIterator<inner>>(access); - } - static constexpr Tile kTile { Backend::kTile.A_rows, inner * Backend::kTile.inner, Backend::kTile.B_cols }; - typedef typename Backend::Packed Packed; -}; } // namespace INTGEMM_ARCH } // namespace |