diff options
author | benoitjacob <benoitjacob@google.com> | 2020-04-24 23:21:49 +0300 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2020-04-29 21:38:28 +0300 |
commit | c6a65efcc42cbed6db9c2c10194a5b0b7600be42 (patch) | |
tree | 3911b49f1b86c3920109d075955ca8ee2c3f9217 | |
parent | 0ad580f6721912047c3612f4b202f3aabf8bdea9 (diff) |
Use the new ruy API for caching constant matrices.
tag: test_308313346
PiperOrigin-RevId: 308313346
-rw-r--r-- | ruy/context.cc | 4 | ||||
-rw-r--r-- | ruy/context.h | 6 | ||||
-rw-r--r-- | ruy/context_test.cc | 3 | ||||
-rw-r--r-- | ruy/ctx.cc | 4 | ||||
-rw-r--r-- | ruy/ctx.h | 3 | ||||
-rw-r--r-- | ruy/ctx_impl.h | 1 | ||||
-rw-r--r-- | ruy/dispatch.h | 98 | ||||
-rw-r--r-- | ruy/mat.h | 33 | ||||
-rw-r--r-- | ruy/matrix.h | 38 | ||||
-rw-r--r-- | ruy/matrix_test.cc | 6 | ||||
-rw-r--r-- | ruy/prepacked_cache.h | 2 | ||||
-rw-r--r-- | ruy/prepacked_cache_test.cc | 25 | ||||
-rw-r--r-- | ruy/side_pair.h | 4 |
13 files changed, 137 insertions, 90 deletions
diff --git a/ruy/context.cc b/ruy/context.cc index 1f70a9f..0e751e8 100644 --- a/ruy/context.cc +++ b/ruy/context.cc @@ -48,10 +48,6 @@ const TracingContext& Context::tracing() const { return ctx().tracing(); } TracingContext* Context::mutable_tracing() { return mutable_ctx()->mutable_tracing(); } -CachePolicy Context::cache_policy() const { return ctx().cache_policy(); } -void Context::set_cache_policy(CachePolicy value) { - mutable_ctx()->set_cache_policy(value); -} void Context::ClearPrepackedCache() { mutable_ctx()->ClearPrepackedCache(); } diff --git a/ruy/context.h b/ruy/context.h index b283108..b41ab54 100644 --- a/ruy/context.h +++ b/ruy/context.h @@ -28,7 +28,7 @@ class ThreadPool; class TracingContext; enum class Path : std::uint8_t; enum class Tuning; -enum class CachePolicy; +enum class CachePolicy : std::uint8_t; // A Context holds runtime information used by Ruy. It holds runtime resources // such as the workers thread pool and the allocator (which holds buffers for @@ -49,8 +49,8 @@ class Context final { void set_max_num_threads(int value); const TracingContext& tracing() const; TracingContext* mutable_tracing(); - CachePolicy cache_policy() const; - void set_cache_policy(CachePolicy value); + void set_cache_policy(CachePolicy) { /* do nothing, legacy */ + } void ClearPrepackedCache(); diff --git a/ruy/context_test.cc b/ruy/context_test.cc index 5ca0c4f..d2cc3b4 100644 --- a/ruy/context_test.cc +++ b/ruy/context_test.cc @@ -31,13 +31,10 @@ TEST(ContextTest, ContextClassSanity) { EXPECT_NE(context.mutable_thread_pool(), nullptr); EXPECT_EQ(context.max_num_threads(), 1); EXPECT_EQ(&context.tracing(), context.mutable_tracing()); - EXPECT_EQ(context.cache_policy(), CachePolicy::kNoCache); context.set_explicit_tuning(Tuning::kOutOfOrder); context.set_max_num_threads(2); - context.set_cache_policy(CachePolicy::kCacheLHSOnNarrowMul); EXPECT_EQ(context.explicit_tuning(), Tuning::kOutOfOrder); EXPECT_EQ(context.max_num_threads(), 2); - 
EXPECT_EQ(context.cache_policy(), CachePolicy::kCacheLHSOnNarrowMul); } } // namespace @@ -40,10 +40,6 @@ void Ctx::set_max_num_threads(int value) { } const TracingContext& Ctx::tracing() const { return impl().tracing_; } TracingContext* Ctx::mutable_tracing() { return &mutable_impl()->tracing_; } -CachePolicy Ctx::cache_policy() const { return impl().cache_policy_; } -void Ctx::set_cache_policy(CachePolicy value) { - mutable_impl()->cache_policy_ = value; -} void Ctx::SetRuntimeEnabledPaths(Path paths) { mutable_impl()->runtime_enabled_paths_ = paths; @@ -40,7 +40,6 @@ class TuningResolver; class PrepackedCache; enum class Path : std::uint8_t; enum class Tuning; -enum class CachePolicy; // Ctx is the internal context class used throughout ruy code. Whereas Context // is exposed to users, Ctx is internal to ruy. As many of ruy's internal @@ -60,8 +59,6 @@ class Ctx /* not final, subclassed by CtxImpl */ { void set_max_num_threads(int value); const TracingContext& tracing() const; TracingContext* mutable_tracing(); - CachePolicy cache_policy() const; - void set_cache_policy(CachePolicy value); void SetRuntimeEnabledPaths(Path paths); Path GetRuntimeEnabledPaths(); diff --git a/ruy/ctx_impl.h b/ruy/ctx_impl.h index 080e016..3352121 100644 --- a/ruy/ctx_impl.h +++ b/ruy/ctx_impl.h @@ -61,7 +61,6 @@ class CtxImpl final : public Ctx { ThreadPool thread_pool_; int max_num_threads_ = 1; TracingContext tracing_; - CachePolicy cache_policy_ = CachePolicy::kNoCache; // Allocator for main thread work before invoking the threadpool. 
// Our simple Allocator does not allow reserving/allocating more blocks // while it's already in committed state, so the main thread needs both diff --git a/ruy/dispatch.h b/ruy/dispatch.h index eb0b07b..383e52c 100644 --- a/ruy/dispatch.h +++ b/ruy/dispatch.h @@ -388,41 +388,72 @@ struct CompileTimeEnabledReferenceMul</*ReferenceMulIsEnabled=*/false> { } }; -inline void HandlePrepackedCaching(TrMulParams* params, - const SidePair<bool>& cacheable, Ctx* ctx) { - if (ctx->cache_policy() == CachePolicy::kNoCache) { - return; +// Returns true if the operand on the given side should use caching of the +// packed form. This may either be explicitly dictated by its cache_policy +// (if it is kNeverCache, the default, or kAlwaysCache), or it may depend +// on a heuristic decision based on the other operand's width. For example, +// in a matrix*vector product, for the LHS matrix operand, the other side is +// the RHS vector, with a width of 1, causing the packing of the LHS to be +// a large fraction of the overall work, so a heuristic would typically +// decide in favor of caching, if permitted at all by the cache_policy. +inline bool ShouldCache(const TrMulParams& params, Side side) { + const CachePolicy cache_policy = params.src[side].cache_policy; + // The width that matters is that of the other side, it is what determines + // the amortization of the packing work done on the present side. + const Side other_side = Other(side); + const int other_width = params.src[other_side].layout.cols; + const int other_kernel_width = params.packed[other_side].layout.kernel.cols; + switch (cache_policy) { + case CachePolicy::kNeverCache: + return false; + case CachePolicy::kAlwaysCache: + return true; + case CachePolicy::kCacheIfLargeSpeedup: + // The condition (other_width <= other_kernel_width) means that the kernel + // will traverse each value of the present side only once, meaning that + // the overhead of the packing work will be maximal, hence maximally + // worth caching. 
+ return (other_width <= other_kernel_width); + case CachePolicy::kCacheIfSignificantSpeedup: + // Variant of the heuristic used in the kCacheIfLargeSpeedup case. The + // kernel will run on each value of the present side only a few times, + // so packing overhead will be significant. + return (other_width <= 4 * other_kernel_width); + case CachePolicy::kCacheLikeTheOldCode: + return other_width <= 4; + default: + RUY_DCHECK(false); + return false; } +} - if (ctx->cache_policy() == CachePolicy::kCacheLHSOnNarrowMul) { - // TODO(b/149304278) Cache on dst.cols <= selected kernel width. - if (!cacheable[Side::kLhs] || params->dst.layout.cols > 4) { - return; - } - PrepackedCache* prepacked_cache = ctx->GetPrepackedCache(); - auto cache_key = std::make_pair(reinterpret_cast<void*>(params->run_kernel), - params->src[Side::kLhs].data); - auto it = prepacked_cache->FindAndUpdate(cache_key); - if (it != prepacked_cache->cend()) { - params->packed[Side::kLhs].data = it->second.first.data; - params->packed[Side::kLhs].sums = it->second.first.sums; - params->is_prepacked[Side::kLhs] = true; - return; +inline void HandlePrepackedCaching(TrMulParams* params, Ctx* ctx) { + for (Side side : {Side::kLhs, Side::kRhs}) { + if (ShouldCache(*params, side)) { + // Look up in cache. + PrepackedCache* prepacked_cache = ctx->GetPrepackedCache(); + auto cache_key = std::make_pair( + reinterpret_cast<void*>(params->run_kernel), params->src[side].data); + auto it = prepacked_cache->FindAndUpdate(cache_key); + if (it != prepacked_cache->cend()) { + // Already cached. + params->packed[side].data = it->second.first.data; + params->packed[side].sums = it->second.first.sums; + params->is_prepacked[side] = true; + return; + } + // Not already cached. Pack and cache now. 
+ PrepackedMatrix prepacked_lhs; + prepacked_lhs.data_size = DataSize(params->packed[side]); + prepacked_lhs.sums_size = SumsSize(params->packed[side]); + prepacked_cache->AllocatePrepackedMatrix(&prepacked_lhs); + params->packed[side].data = prepacked_lhs.data; + params->packed[side].sums = prepacked_lhs.sums; + params->is_prepacked[side] = true; + Tuning tuning = ctx->GetMainThreadTuning(); + params->RunPack(side, tuning, 0, params->packed[side].layout.cols); + prepacked_cache->Insert(cache_key, prepacked_lhs); } - - // Allocate the prepacked matrix. - PrepackedMatrix prepacked_lhs; - prepacked_lhs.data_size = DataSize(params->packed[Side::kLhs]); - prepacked_lhs.sums_size = SumsSize(params->packed[Side::kLhs]); - prepacked_cache->AllocatePrepackedMatrix(&prepacked_lhs); - params->packed[Side::kLhs].data = prepacked_lhs.data; - params->packed[Side::kLhs].sums = prepacked_lhs.sums; - params->is_prepacked[Side::kLhs] = true; - Tuning tuning = ctx->GetMainThreadTuning(); - params->RunPack(Side::kLhs, tuning, 0, - params->packed[Side::kLhs].layout.cols); - prepacked_cache->Insert(cache_key, prepacked_lhs); - return; } } @@ -477,8 +508,7 @@ void DispatchMul(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs, TrMulParams params; CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, mul_params, dst, the_path, ¶ms); - SidePair<bool> cacheable(lhs.cacheable, rhs.cacheable); - HandlePrepackedCaching(¶ms, cacheable, ctx); + HandlePrepackedCaching(¶ms, ctx); TrMul(¶ms, ctx); } @@ -135,7 +135,7 @@ struct Mat final { detail::ConstCheckingPtr<Scalar> data; MatLayout layout; Scalar zero_point = 0; - bool cacheable = false; + CachePolicy cache_policy = CachePolicy::kNeverCache; }; template <typename Scalar> @@ -144,7 +144,7 @@ inline Mat<Scalar> ToInternal(const Matrix<Scalar>& src) { ret.data.set(src.data()); ret.layout = ToInternal(src.layout()); ret.zero_point = src.zero_point(); - ret.cacheable = src.cacheable(); + ret.cache_policy = src.cache_policy(); return ret; } 
@@ -154,7 +154,7 @@ inline Mat<Scalar> ToInternal(Matrix<Scalar>& src) { ret.data.set(src.data()); ret.layout = ToInternal(src.layout()); ret.zero_point = src.zero_point(); - ret.cacheable = src.cacheable(); + ret.cache_policy = src.cache_policy(); return ret; } @@ -235,6 +235,7 @@ struct EMat final { void* data = nullptr; MatLayout layout; std::int32_t zero_point = 0; + CachePolicy cache_policy = CachePolicy::kNeverCache; }; // Type-erased packed matrix. @@ -277,29 +278,31 @@ EMat EraseType(const Mat<T>& matrix) { ret.data = ToVoidPtr(matrix.data.get()); ret.layout = matrix.layout; ret.zero_point = matrix.zero_point; + ret.cache_policy = matrix.cache_policy; return ret; } template <typename T> -Mat<T> UneraseType(const EMat& dmatrix) { - dmatrix.data_type.AssertIs<T>(); +Mat<T> UneraseType(const EMat& matrix) { + matrix.data_type.AssertIs<T>(); Mat<T> ret; - ret.data.set(static_cast<T*>(dmatrix.data)); - ret.layout = dmatrix.layout; - ret.zero_point = dmatrix.zero_point; + ret.data.set(static_cast<T*>(matrix.data)); + ret.layout = matrix.layout; + ret.zero_point = matrix.zero_point; + ret.cache_policy = matrix.cache_policy; return ret; } template <typename T> -PMat<T> UneraseType(const PEMat& pmatrix) { +PMat<T> UneraseType(const PEMat& matrix) { using SumsType = typename PMat<T>::SumsType; - pmatrix.data_type.AssertIs<T>(); - pmatrix.sums_type.AssertIs<SumsType>(); + matrix.data_type.AssertIs<T>(); + matrix.sums_type.AssertIs<SumsType>(); PMat<T> ret; - ret.data = static_cast<T*>(pmatrix.data); - ret.sums = static_cast<SumsType*>(pmatrix.sums); - ret.layout = pmatrix.layout; - ret.zero_point = pmatrix.zero_point; + ret.data = static_cast<T*>(matrix.data); + ret.sums = static_cast<SumsType*>(matrix.sums); + ret.layout = matrix.layout; + ret.zero_point = matrix.zero_point; return ret; } diff --git a/ruy/matrix.h b/ruy/matrix.h index 5f480b7..b63a22b 100644 --- a/ruy/matrix.h +++ b/ruy/matrix.h @@ -91,9 +91,7 @@ class ConstCheckingPtr final { ptr_ = ptr; 
set_mutable(false); } - void set(std::nullptr_t) { - ptr_ = nullptr; - } + void set(std::nullptr_t) { ptr_ = nullptr; } T* get() /* NOT const */ { assert_mutable(); return const_cast<T*>(ptr_); @@ -116,6 +114,18 @@ class ConstCheckingPtr final { } // namespace detail +enum class CachePolicy : std::uint8_t { + kNeverCache, + kCacheIfLargeSpeedup, + kCacheIfSignificantSpeedup, + kAlwaysCache, + // transitional value emulating old behavior + kCacheLikeTheOldCode, + // legacy values used when the CachePolicy was a Context property + kNoCache, + kCacheLHSOnNarrowMul +}; + // A Matrix merely wraps existing data as a matrix. It doesn't own any buffer. // The purpose of Matrix is only to be used in ruy's interface -- it's just // a structured way for the user to pass to ruy::Mul the matrix data pointers @@ -141,8 +151,15 @@ class Matrix final { Layout* mutable_layout() { return &layout_; } Scalar zero_point() const { return zero_point_; } void set_zero_point(Scalar value) { zero_point_ = value; } - bool cacheable() const { return cacheable_; } - void set_cacheable(bool value) { cacheable_ = value; } + CachePolicy cache_policy() const { return cache_policy_; } + void set_cache_policy(CachePolicy value) { cache_policy_ = value; } + + // legacy for compatibily, essentially preserving old behavior for existing + // callers during the transition to set_cache_threshold. + void set_cacheable(bool value) { + set_cache_policy(value ? CachePolicy::kCacheLikeTheOldCode + : CachePolicy::kNeverCache); + } private: // The underlying buffer wrapped by this matrix. @@ -152,10 +169,13 @@ class Matrix final { // The zero_point, i.e. which Scalar value is to be interpreted as zero. // When Scalar is floating-point, this must be 0. Scalar zero_point_ = 0; - // Clients of Ruy must set this flag to enable any caching behavior. Doesn't - // impact numerical results, but caching can impact observable metrics like - // latency, memory usage, power, etc. 
- bool cacheable_ = false; + // When the data pointed to by this matrix is constant data, so that it is + // valid to assume that equality of pointers implies equality of data, + // a CachePolicy may be used instead of the default kNeverCache, + // which will enable ruy to take advantage of this constancy of the data to + // cache the packing work, which can be a large speedup in matrix*vector + // and other narrow shapes. + CachePolicy cache_policy_ = CachePolicy::kNeverCache; }; inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) { diff --git a/ruy/matrix_test.cc b/ruy/matrix_test.cc index 04ee7ae..0f3fd13 100644 --- a/ruy/matrix_test.cc +++ b/ruy/matrix_test.cc @@ -70,7 +70,7 @@ TEST(MatrixTest, MatrixClassSanity) { Matrix<int> matrix; EXPECT_EQ(matrix.data(), nullptr); EXPECT_EQ(matrix.zero_point(), 0); - EXPECT_EQ(matrix.cacheable(), false); + EXPECT_EQ(matrix.cache_policy(), CachePolicy::kNeverCache); EXPECT_EQ(matrix.layout().rows(), 0); EXPECT_EQ(matrix.layout().cols(), 0); EXPECT_EQ(matrix.layout().stride(), 0); @@ -78,14 +78,14 @@ TEST(MatrixTest, MatrixClassSanity) { const int some_const = 0; matrix.set_data(&some_const); matrix.set_zero_point(123); - matrix.set_cacheable(true); + matrix.set_cache_policy(CachePolicy::kAlwaysCache); MakeSimpleLayout(12, 34, Order::kRowMajor, matrix.mutable_layout()); EXPECT_EQ(static_cast<const Matrix<int>&>(matrix).data(), &some_const); #ifndef NDEBUG RUY_ASSERT_DEATH(matrix.data(), ""); #endif EXPECT_EQ(matrix.zero_point(), 123); - EXPECT_EQ(matrix.cacheable(), true); + EXPECT_EQ(matrix.cache_policy(), CachePolicy::kAlwaysCache); EXPECT_EQ(matrix.layout().rows(), 12); EXPECT_EQ(matrix.layout().cols(), 34); EXPECT_EQ(matrix.layout().stride(), 34); diff --git a/ruy/prepacked_cache.h b/ruy/prepacked_cache.h index f8c731f..0e0fdda 100644 --- a/ruy/prepacked_cache.h +++ b/ruy/prepacked_cache.h @@ -62,8 +62,6 @@ class SystemBlockAllocator { } // namespace detail -enum class CachePolicy { kNoCache, 
kCacheLHSOnNarrowMul }; - // "Low effort" Least Recently Used Cache for Prepacked Matrices // A cache mechanism for prepacked matrices that ejects oldest entries. // The implementation is "low effort" in the following ways: diff --git a/ruy/prepacked_cache_test.cc b/ruy/prepacked_cache_test.cc index 9459b25..5b74236 100644 --- a/ruy/prepacked_cache_test.cc +++ b/ruy/prepacked_cache_test.cc @@ -136,11 +136,9 @@ TEST(PrepackedCacheTest, TestCacheEjection2) { EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key4), prepacked_cache.cend()); } -TEST(PrepackedCacheTest, TestCacheOnCacheable) { - // Create context and set the cache policy +void TestCacheOnCacheable(CachePolicy cache_policy, bool expected_cached) { ruy::Context context; ruy::Ctx* ctx = get_ctx(&context); - context.set_cache_policy(ruy::CachePolicy::kCacheLHSOnNarrowMul); PrepackedCache* cache = ctx->GetPrepackedCache(); EXPECT_EQ(cache->TotalSize(), 0); @@ -163,17 +161,26 @@ TEST(PrepackedCacheTest, TestCacheOnCacheable) { ruy::Mul<ruy::kAllPaths>(lhs, rhs, mul_params, &context, &dst); EXPECT_EQ(cache->TotalSize(), 0); - // Set cacheable for the LHS, repeat the multiplication, and see + // Set cache policy for the LHS, repeat the multiplication, and see // that caching did occur. 
- lhs.set_cacheable(true); + lhs.set_cache_policy(cache_policy); ruy::Mul<ruy::kAllPaths>(lhs, rhs, mul_params, &context, &dst); - EXPECT_NE(cache->TotalSize(), 0); + const bool actual_cached = cache->TotalSize() > 0; + EXPECT_EQ(actual_cached, expected_cached); +} + +TEST(PrepackedCacheTest, TestCacheOnCacheable) { + for (CachePolicy cache_policy : + {CachePolicy::kNeverCache, CachePolicy::kCacheIfLargeSpeedup, + CachePolicy::kCacheIfSignificantSpeedup, CachePolicy::kAlwaysCache, + CachePolicy::kCacheLikeTheOldCode}) { + TestCacheOnCacheable(cache_policy, + cache_policy != CachePolicy::kNeverCache); + } } TEST(PrepackedCacheTest, TestClearCache) { - // Create context and set the cache policy ruy::Context context; - context.set_cache_policy(ruy::CachePolicy::kCacheLHSOnNarrowMul); PrepackedCache* cache = get_ctx(&context)->GetPrepackedCache(); EXPECT_EQ(cache->TotalSize(), 0); @@ -193,7 +200,7 @@ TEST(PrepackedCacheTest, TestClearCache) { ruy::MulParams<float, float> mul_params; // Set cacheable for the LHS and see that caching occurs. - lhs.set_cacheable(true); + lhs.set_cache_policy(CachePolicy::kAlwaysCache); ruy::Mul<ruy::kAllPaths>(lhs, rhs, mul_params, &context, &dst); EXPECT_NE(cache->TotalSize(), 0); diff --git a/ruy/side_pair.h b/ruy/side_pair.h index 212c054..f0d0973 100644 --- a/ruy/side_pair.h +++ b/ruy/side_pair.h @@ -30,6 +30,10 @@ enum class Side { kRhs = 1 }; +inline Side Other(Side side) { + return side == Side::kLhs ? Side::kRhs : Side::kLhs; +} + // SidePair is a pair container where the two elements are indexed by a Side // enum. template <typename T> |