diff options
author | Kenneth Heafield <github@kheafield.com> | 2020-04-19 14:02:39 +0300 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2020-04-19 14:02:39 +0300 |
commit | a3a6a9b845ed5dc51e81e7cd8b9b9ba84855edaa (patch) | |
tree | 80e436b907af14d8c94552e92b9fd993cf3e2448 | |
parent | 51011c1a6be683fdf761c1d365d27c57c44d99f3 (diff) |
Change to integer sequence for unrolling kernels
-rw-r--r-- | test/tile_test.inl | 18 | ||||
-rw-r--r-- | tile/dot.inl | 48 | ||||
-rw-r--r-- | tile/multiply.inl | 17 | ||||
-rw-r--r-- | types.h | 3 | ||||
-rw-r--r-- | utils.h | 24 |
5 files changed, 61 insertions, 49 deletions
diff --git a/test/tile_test.inl b/test/tile_test.inl index 0968bef..18263b1 100644 --- a/test/tile_test.inl +++ b/test/tile_test.inl @@ -44,7 +44,7 @@ INTGEMM_TARGET void OneIteration() { InputA(A.begin(), sizeof(Register)), InputB(B.begin(), sizeof(Register)), Output(reinterpret_cast<Register*>(C.begin()), 1)); - MatrixTile<1, 1, Shifted8>::Run(access); + UnrollKernel<1, 1, 1, Shifted8>::Run(access); const std::size_t kStride = sizeof(int32_t) / sizeof(int8_t); for (std::size_t i = 0; i < sizeof(Register) / sizeof(int32_t); ++i) { @@ -218,14 +218,16 @@ TEST_CASE("MultiplyNoOverhang Signed8 " INTGEMM_TEST_NAME, "[tile]") { TestMultiplyNoOverhangShapes<Signed8>(); } -TEST_CASE("MultiplyNoOverhang Tiled Signed8 " INTGEMM_TEST_NAME, "[tile]") { +TEST_CASE("MultiplyNoOverhang Unrolled Signed8 " INTGEMM_TEST_NAME, "[tile]") { if (kCPU < CPUType::INTGEMM_ARCH) return; - TestMultiplyNoOverhangShapes<InnerTile<1, Signed8> >(); - TestMultiplyNoOverhangShapes<InnerTile<2, Signed8> >(); - TestMultiplyNoOverhangShapes<InnerTile<3, Signed8> >(); - TestMultiplyNoOverhangShapes<MatrixTile<3, 3, Signed8> >(); - TestMultiplyNoOverhangShapes<InnerTile<2, MatrixTile<3, 3, Signed8> > >(); - TestMultiplyNoOverhangShapes<MatrixTile<4, 4, InnerTile<3, Signed8> > >(); + TestMultiplyNoOverhangShapes<UnrollKernel<1, 1, 1, Signed8> >(); + + TestMultiplyNoOverhangShapes<UnrollKernel<2, 1, 1, Signed8> >(); + TestMultiplyNoOverhangShapes<UnrollKernel<1, 2, 1, Signed8> >(); + TestMultiplyNoOverhangShapes<UnrollKernel<1, 1, 2, Signed8> >(); + + TestMultiplyNoOverhangShapes<UnrollKernel<2, 2, 2, Signed8> >(); + TestMultiplyNoOverhangShapes<UnrollKernel<4, 4, 3, Signed8> >(); } #endif diff --git a/tile/dot.inl b/tile/dot.inl index adbac37..9af7b1b 100644 --- a/tile/dot.inl +++ b/tile/dot.inl @@ -132,41 +132,31 @@ struct Signed16 { }; }; -/* These would normally be outside arch namespaces but gcc refuses to inline - * functions with target attributes into functions without target attributes. 
- * So they all need target attributes. */ -template <class Backend> struct MatrixTileBody { - template <typename Iterator, typename Access> INTGEMM_TARGET static inline void body(Access access) { - Backend::Run(access - .AAdd(Iterator::template I<0>() * Backend::kTile.A_rows, 0) - .BAdd(0, Iterator::template I<1>() * Backend::kTile.B_cols) - .CAdd(Iterator::template I<0>() * Backend::kTile.A_rows, Iterator::template I<1>() * Backend::kTile.B_cols)); - } -}; -// Wrap a tile to statically loop over rows of A and columns of B. -template <Index A_rows, Index B_cols, class Backend> struct MatrixTile { +// Unroll an arbitrary amount of +// Can't have Tile as a value until C++20. +template <Index A_rows, Index inner, Index B_cols, class Backend> struct UnrollKernel { template <class Access> INTGEMM_TARGET __attribute__((flatten)) static inline void Run(Access access) { - StaticLoop<MatrixTileBody<Backend>, MakeStaticLoopIterator<A_rows, B_cols>>(access); + body(access, make_index_sequence<A_rows * inner * B_cols>()); } - static constexpr Tile kTile { A_rows * Backend::kTile.A_rows, Backend::kTile.inner, B_cols * Backend::kTile.B_cols }; + static constexpr Tile kTile { A_rows * Backend::kTile.A_rows, inner * Backend::kTile.inner, B_cols * Backend::kTile.B_cols }; typedef typename Backend::Packed Packed; -}; -template <class Backend> struct InnerTileBody { - template <typename Iterator, typename Access> INTGEMM_TARGET static inline void body(Access access) { - Backend::Run(access - .AAdd(0, Iterator::template I<0>() * Backend::kTile.inner) - .BAdd(Iterator::template I<0>() * Backend::kTile.inner, 0)); + private: + template <class Access, size_t... 
Index> INTGEMM_TARGET __attribute__((flatten)) static inline void body( + Access access, + index_sequence<Index...>) { + // for each inner computed as (Index / A_rows / B_cols) + // for each A_row computed as (Index % (A_rows * B_cols)) / B_cols + // for each B_col computed as (Index % B_cols) + unordered_unfurl(( + Backend::Run(access + .AAdd((Index % (A_rows * B_cols)) / B_cols * Backend::kTile.A_rows, (Index / A_rows / B_cols) * Backend::kTile.inner) + .BAdd((Index / A_rows / B_cols) * Backend::kTile.inner, (Index % B_cols) * Backend::kTile.B_cols) + .CAdd((Index % (A_rows * B_cols)) / B_cols * Backend::kTile.A_rows, (Index % B_cols) * Backend::kTile.B_cols)) + // Backend returns void, so use a tuple to make 0. + , 0)...); } }; -// Wrap a tile to statically loop over the inner dimension. -template <Index inner, class Backend> struct InnerTile { - template <class Access> INTGEMM_TARGET __attribute__((flatten)) static inline void Run(Access access) { - StaticLoop<InnerTileBody<Backend>, MakeStaticLoopIterator<inner>>(access); - } - static constexpr Tile kTile { Backend::kTile.A_rows, inner * Backend::kTile.inner, Backend::kTile.B_cols }; - typedef typename Backend::Packed Packed; -}; } // namespace INTGEMM_ARCH } // namespace diff --git a/tile/multiply.inl b/tile/multiply.inl index f1a22cf..3bec91f 100644 --- a/tile/multiply.inl +++ b/tile/multiply.inl @@ -19,20 +19,19 @@ namespace intgemm { namespace INTGEMM_ARCH { // Upcast 16 to 32 if needed. 
-template <class T> INTGEMM_TARGET static inline void SumTo32(Register &reg); -template <> INTGEMM_TARGET inline void SumTo32<int16_t>(Register &reg) { - reg = madd_epi16(reg, set1_epi16<Register>(1)); -} -template <> INTGEMM_TARGET inline void SumTo32<int32_t>(Register &) {} - -template <class T> struct SumTo32Body { +template <class T> struct SumTo32Body; +template <> struct SumTo32Body<int16_t> { template <class Iterator> INTGEMM_TARGET static inline void body(Register *regs) { - SumTo32<T>(regs[Iterator::template I<0>()]); + Register &reg = regs[Iterator::template I<0>()]; + reg = madd_epi16(reg, set1_epi16<Register>(1)); } }; +template <> struct SumTo32Body<int32_t> { + template <class Iterator> INTGEMM_TARGET static inline void body(Register *) {} +}; /* Multiply assuming the matrix sizes are a multiple of the kernel size. */ -template <class AccessT, class Kernel> INTGEMM_TARGET static inline void MultiplyNoOverhang(AccessT access, const Tile shape) { +template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) static inline void MultiplyNoOverhang(AccessT access, const Tile shape) { assert(shape.A_rows % Kernel::kTile.A_rows == 0); assert(shape.inner % Kernel::kTile.inner == 0); assert(shape.B_cols % Kernel::kTile.B_cols == 0); @@ -49,9 +49,6 @@ enum class CPUType { extern const CPUType kCPU; struct Tile { -/* Tile() {} - Tile(Index in_A_rows, Index in_inner, Index in_B_cols) - : A_rows(in_A_rows), inner(in_inner), B_cols(in_B_cols) {} */ Index A_rows, inner, B_cols; }; @@ -1,9 +1,33 @@ #pragma once +#include "types.h" #include <tuple> namespace intgemm { +// Function to absorb arguments from integer sequences. +template<typename... Args> void unordered_unfurl(Args&&...) {} + +// C++11 implementation of C++14's make_index_sequence. +// This is a bugfix from a stackoverflow post that did [0, N] while the standard does [0, N). 
+// https://stackoverflow.com/questions/52844615/is-that-possible-to-have-a-for-loop-in-compile-time-with-runtime-or-even-compile +template <size_t... Is> +struct index_sequence{}; + +namespace detail { + template <size_t I,size_t...Is> + struct make_index_sequence_impl : make_index_sequence_impl<I-1,I-1,Is...> {}; + + template <size_t...Is> + struct make_index_sequence_impl<0,Is...> + { + using type = index_sequence<Is...>; + }; +} + +template<size_t N> +using make_index_sequence = typename detail::make_index_sequence_impl<N>::type; + /* * Sequence of unsigned integers * |