author     Mateusz Chudyk <mateuszchudyk@gmail.com>   2020-04-30 21:09:51 +0300
committer  Mateusz Chudyk <mateuszchudyk@gmail.com>   2020-04-30 21:13:54 +0300
commit     1280a3715a9338caddc29f04176fe66daa79895e (patch)
tree       975f64ea7af7c5dbd91efcecb3ff350d0a1b3fcc
parent     e0829175a6c72ab96f4c1c44379c444978a9d57b (diff)
Remove need of passing template parameter which can be deduced by a compiler
-rw-r--r--  benchmarks/benchmark_tile.cc |  4 ++--
-rw-r--r--  test/tile_test.inl           |  8 ++++----
-rw-r--r--  tile/multiply.inl            | 10 +++++-----
3 files changed, 11 insertions(+), 11 deletions(-)
diff --git a/benchmarks/benchmark_tile.cc b/benchmarks/benchmark_tile.cc
index 05169c1..f6ff0b7 100644
--- a/benchmarks/benchmark_tile.cc
+++ b/benchmarks/benchmark_tile.cc
@@ -93,9 +93,9 @@ template <Index A_rows, Index B_cols> static inline double BenchmarkNoOverhang(A
   typedef AVX512VNNI::UnrollKernel<A_rows, 1, B_cols, AVX512VNNI::Shifted8> Kernel;
   // Burn in.
   // TODO: different arches, guard against old compilers, etc.
-  AVX512VNNI::MultiplyNoOverhang<Accessor, Kernel>(access, shape);
+  AVX512VNNI::MultiplyNoOverhang<Kernel>(access, shape);
   for (std::size_t t = 0; t < kTries; ++t) {
-    AVX512VNNI::MultiplyNoOverhang<Accessor, Kernel>(access, shape);
+    AVX512VNNI::MultiplyNoOverhang<Kernel>(access, shape);
   }
   auto end = std::chrono::steady_clock::now();
   return std::chrono::duration<double>(end - start).count() / kTries;
diff --git a/test/tile_test.inl b/test/tile_test.inl
index 267870d..865e922 100644
--- a/test/tile_test.inl
+++ b/test/tile_test.inl
@@ -160,7 +160,7 @@ struct TestMatricesRef : TestMatrices8 {
         RowMajorAccess<int8_t>(A.begin(), shape.inner),
         ColMajorAccess<int8_t>(B.begin(), shape.inner),
         RowMajorAccess<int32_t>(C_reference.begin(), shape.B_cols));
-    Signed8ReferenceMult<AccessT>(ref_access, shape);
+    Signed8ReferenceMult(ref_access, shape);
   }
 
   AlignedVector<int32_t> C_reference;
@@ -173,7 +173,7 @@ template <class Kernel> void TestMultiplyNoOverhang(Tile shape) {
   CHECK(shape.inner % Kernel::kTile.inner == 0);
   CHECK(shape.B_cols % Kernel::kTile.B_cols == 0);
   TestMatricesRef t(shape);
-  MultiplyNoOverhang<TestMatricesRef::AccessT, Kernel>(t.Accessor(), shape);
+  MultiplyNoOverhang<Kernel>(t.Accessor(), shape);
   CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t)));
 /*  for (Index i = 0; i < shape.A_rows; ++i) {
     for (Index j = 0; j < shape.B_cols; ++j) {
@@ -286,10 +286,10 @@ TEST_CASE("Multiply " INTGEMM_TEST_NAME, "[tile][multiply]") {
   for (shape.A_rows = 1; shape.A_rows < 33; ++shape.A_rows) {
     for (shape.B_cols = 1; shape.B_cols < 33; ++shape.B_cols) {
       TestMatricesRef t(shape);
-      Multiply<TestMatricesRef::AccessT, Signed8, 7, 3>(t.Accessor(), shape);
+      Multiply<Signed8, 7, 3>(t.Accessor(), shape);
       CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t)));
       memset(t.C.begin(), 0, shape.A_rows * shape.B_cols * sizeof(int32_t));
-      Multiply<TestMatricesRef::AccessT, Signed8, 4, 5>(t.Accessor(), shape);
+      Multiply<Signed8, 4, 5>(t.Accessor(), shape);
       CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t)));
     }
   }
diff --git a/tile/multiply.inl b/tile/multiply.inl
index a1a92cf..be86255 100644
--- a/tile/multiply.inl
+++ b/tile/multiply.inl
@@ -25,7 +25,7 @@ template <std::size_t... i> INTGEMM_TARGET static inline void Sum16To32(Register
 template <std::size_t... i> INTGEMM_TARGET static inline void Sum16To32(Register *, int32_t, index_sequence<i...>) {}
 /* Multiply assuming the matrix sizes are a multiple of the kernel size.
  */
-template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) static inline void MultiplyNoOverhang(AccessT access, const Tile shape) {
+template <class Kernel, class AccessT> INTGEMM_TARGET __attribute__((flatten)) static inline void MultiplyNoOverhang(AccessT access, const Tile shape) {
   assert(shape.A_rows % Kernel::kTile.A_rows == 0);
   assert(shape.inner % Kernel::kTile.inner == 0);
   assert(shape.B_cols % Kernel::kTile.B_cols == 0);
@@ -64,7 +64,7 @@ template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) s
  * A_rows and B_cols specify the unrolled kernel size to use for most of the
  * multiply; these impact speed but not output.
  */
-template <class Access, class Kernel, Index A_rows, Index B_cols> INTGEMM_TARGET static inline void Multiply(Access access, const Tile shape) {
+template <class Kernel, Index A_rows, Index B_cols, class AccessT> INTGEMM_TARGET static inline void Multiply(AccessT access, const Tile shape) {
   // Still has to be a multiple of the underlying Kernel, but usually that's just 1 x sizeof(Register) x 1.
   assert(shape.A_rows % Kernel::kTile.A_rows == 0);
   assert(shape.inner % Kernel::kTile.inner == 0);
@@ -82,13 +82,13 @@ template <class Access, class Kernel, Index A_rows, Index B_cols> INTGEMM_TARGET
     shape.B_cols - overhang.B_cols
   };
   // Top left corner.
-  MultiplyNoOverhang<Access, Big>(access, big_shape);
+  MultiplyNoOverhang<Big>(access, big_shape);
   // Bottom currently including right side. TODO: unrolled kernel, rather than dumb loop.
-  MultiplyNoOverhang<Access, Kernel>(
+  MultiplyNoOverhang<Kernel>(
       access.AAdd(big_shape.A_rows, 0).CAdd(big_shape.A_rows, 0),
       Tile {overhang.A_rows, shape.inner, shape.B_cols});
   // Right side except bottom. TODO: unrolled kernel, rather than dumb loop.
-  MultiplyNoOverhang<Access, Kernel>(
+  MultiplyNoOverhang<Kernel>(
       access.BAdd(0, big_shape.B_cols).CAdd(0, big_shape.B_cols),
       Tile {big_shape.A_rows, shape.inner, overhang.B_cols});
 }
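
The reordering works because C++ deduces template parameters only from the function arguments, and explicit template arguments must be supplied left to right: once the non-deducible Kernel (and, for Multiply, the A_rows/B_cols unroll sizes) comes first and the accessor type comes last, callers name just the kernel and the compiler deduces the rest. Below is a minimal sketch of the idea; DemoKernel, DemoAccess, and MultiplyNoOverhangDemo are hypothetical stand-ins, not intgemm's real types.

// Build with: g++ -std=c++17 demo.cc (C++17 makes static constexpr members inline)
#include <cassert>
#include <cstdio>

struct Tile { int A_rows, inner, B_cols; };

// Stand-in for an intgemm kernel with a compile-time tile size.
struct DemoKernel {
  static constexpr Tile kTile{1, 4, 1};
};

// Stand-in for an accessor such as RowMajorAccess.
struct DemoAccess { const int *a; const int *b; int *c; };

// Old order <class AccessT, class Kernel> forced callers to write
// MultiplyNoOverhangDemo<DemoAccess, DemoKernel>(...): Kernel is not
// deducible, and every explicit argument before it must be spelled out too.
// New order puts Kernel first and AccessT last, so AccessT is deduced.
template <class Kernel, class AccessT>
static inline void MultiplyNoOverhangDemo(AccessT access, const Tile shape) {
  assert(shape.inner % Kernel::kTile.inner == 0);
  (void)access;  // A real implementation would loop over tiles here.
  std::printf("kernel inner = %d, shape inner = %d\n",
              Kernel::kTile.inner, shape.inner);
}

int main() {
  int a[8] = {}, b[8] = {}, c[4] = {};
  DemoAccess access{a, b, c};
  // Only the kernel is written explicitly; DemoAccess is deduced.
  MultiplyNoOverhangDemo<DemoKernel>(access, Tile{2, 8, 2});
}

The same rule explains the Signed8ReferenceMult<AccessT>(ref_access, shape) to Signed8ReferenceMult(ref_access, shape) change: with the accessor as the only template parameter and deducible from ref_access, no explicit argument is needed at all.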