github.com/google/ruy.git
author     benoitjacob <benoitjacob@google.com>           2020-04-24 23:21:49 +0300
committer  Copybara-Service <copybara-worker@google.com>  2020-04-29 21:38:28 +0300
commit     c6a65efcc42cbed6db9c2c10194a5b0b7600be42 (patch)
tree       3911b49f1b86c3920109d075955ca8ee2c3f9217
parent     0ad580f6721912047c3612f4b202f3aabf8bdea9 (diff)

Use the new ruy API for caching constant matrices. (test_308313346)

PiperOrigin-RevId: 308313346
-rw-r--r--  ruy/context.cc               |   4
-rw-r--r--  ruy/context.h                |   6
-rw-r--r--  ruy/context_test.cc          |   3
-rw-r--r--  ruy/ctx.cc                   |   4
-rw-r--r--  ruy/ctx.h                    |   3
-rw-r--r--  ruy/ctx_impl.h               |   1
-rw-r--r--  ruy/dispatch.h               |  98
-rw-r--r--  ruy/mat.h                    |  33
-rw-r--r--  ruy/matrix.h                 |  38
-rw-r--r--  ruy/matrix_test.cc           |   6
-rw-r--r--  ruy/prepacked_cache.h        |   2
-rw-r--r--  ruy/prepacked_cache_test.cc  |  25
-rw-r--r--  ruy/side_pair.h              |   4
13 files changed, 137 insertions, 90 deletions
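
This commit moves the caching decision from a Context-wide setting (Context::set_cache_policy, now a no-op legacy stub) to a per-matrix CachePolicy on ruy::Matrix, so callers mark each constant operand individually. The following is a minimal usage sketch of the new API, assembled from the tests in this commit; the "ruy/ruy.h" header path and the buffer shapes are assumptions for illustration, not code from the commit.

#include "ruy/ruy.h"  // assumed public header; the tests in this commit use the same calls

// Sketch (not part of the commit): multiply with a constant LHS opted into
// prepacked-matrix caching via the new per-matrix API.
void MulWithCachedLhs(const float* lhs_data, const float* rhs_data,
                      float* dst_data, int rows, int depth, int cols) {
  ruy::Context context;

  ruy::Matrix<float> lhs;
  ruy::MakeSimpleLayout(rows, depth, ruy::Order::kRowMajor,
                        lhs.mutable_layout());
  lhs.set_data(lhs_data);
  // New API: the caching decision lives on the matrix, not on the Context.
  // kCacheIfLargeSpeedup lets ruy's heuristic decide; kAlwaysCache forces it.
  lhs.set_cache_policy(ruy::CachePolicy::kCacheIfLargeSpeedup);

  ruy::Matrix<float> rhs;
  ruy::MakeSimpleLayout(depth, cols, ruy::Order::kColMajor,
                        rhs.mutable_layout());
  rhs.set_data(rhs_data);

  ruy::Matrix<float> dst;
  ruy::MakeSimpleLayout(rows, cols, ruy::Order::kColMajor,
                        dst.mutable_layout());
  dst.set_data(dst_data);

  ruy::MulParams<float, float> mul_params;
  ruy::Mul<ruy::kAllPaths>(lhs, rhs, mul_params, &context, &dst);
}
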
diff --git a/ruy/context.cc b/ruy/context.cc
index 1f70a9f..0e751e8 100644
--- a/ruy/context.cc
+++ b/ruy/context.cc
@@ -48,10 +48,6 @@ const TracingContext& Context::tracing() const { return ctx().tracing(); }
TracingContext* Context::mutable_tracing() {
return mutable_ctx()->mutable_tracing();
}
-CachePolicy Context::cache_policy() const { return ctx().cache_policy(); }
-void Context::set_cache_policy(CachePolicy value) {
- mutable_ctx()->set_cache_policy(value);
-}
void Context::ClearPrepackedCache() { mutable_ctx()->ClearPrepackedCache(); }
diff --git a/ruy/context.h b/ruy/context.h
index b283108..b41ab54 100644
--- a/ruy/context.h
+++ b/ruy/context.h
@@ -28,7 +28,7 @@ class ThreadPool;
class TracingContext;
enum class Path : std::uint8_t;
enum class Tuning;
-enum class CachePolicy;
+enum class CachePolicy : std::uint8_t;
// A Context holds runtime information used by Ruy. It holds runtime resources
// such as the workers thread pool and the allocator (which holds buffers for
@@ -49,8 +49,8 @@ class Context final {
void set_max_num_threads(int value);
const TracingContext& tracing() const;
TracingContext* mutable_tracing();
- CachePolicy cache_policy() const;
- void set_cache_policy(CachePolicy value);
+ void set_cache_policy(CachePolicy) { /* do nothing, legacy */
+ }
void ClearPrepackedCache();
diff --git a/ruy/context_test.cc b/ruy/context_test.cc
index 5ca0c4f..d2cc3b4 100644
--- a/ruy/context_test.cc
+++ b/ruy/context_test.cc
@@ -31,13 +31,10 @@ TEST(ContextTest, ContextClassSanity) {
EXPECT_NE(context.mutable_thread_pool(), nullptr);
EXPECT_EQ(context.max_num_threads(), 1);
EXPECT_EQ(&context.tracing(), context.mutable_tracing());
- EXPECT_EQ(context.cache_policy(), CachePolicy::kNoCache);
context.set_explicit_tuning(Tuning::kOutOfOrder);
context.set_max_num_threads(2);
- context.set_cache_policy(CachePolicy::kCacheLHSOnNarrowMul);
EXPECT_EQ(context.explicit_tuning(), Tuning::kOutOfOrder);
EXPECT_EQ(context.max_num_threads(), 2);
- EXPECT_EQ(context.cache_policy(), CachePolicy::kCacheLHSOnNarrowMul);
}
} // namespace
diff --git a/ruy/ctx.cc b/ruy/ctx.cc
index fccdf9c..d6842df 100644
--- a/ruy/ctx.cc
+++ b/ruy/ctx.cc
@@ -40,10 +40,6 @@ void Ctx::set_max_num_threads(int value) {
}
const TracingContext& Ctx::tracing() const { return impl().tracing_; }
TracingContext* Ctx::mutable_tracing() { return &mutable_impl()->tracing_; }
-CachePolicy Ctx::cache_policy() const { return impl().cache_policy_; }
-void Ctx::set_cache_policy(CachePolicy value) {
- mutable_impl()->cache_policy_ = value;
-}
void Ctx::SetRuntimeEnabledPaths(Path paths) {
mutable_impl()->runtime_enabled_paths_ = paths;
diff --git a/ruy/ctx.h b/ruy/ctx.h
index 6fc62e8..393f7a4 100644
--- a/ruy/ctx.h
+++ b/ruy/ctx.h
@@ -40,7 +40,6 @@ class TuningResolver;
class PrepackedCache;
enum class Path : std::uint8_t;
enum class Tuning;
-enum class CachePolicy;
// Ctx is the internal context class used throughout ruy code. Whereas Context
// is exposed to users, Ctx is internal to ruy. As many of ruy's internal
@@ -60,8 +59,6 @@ class Ctx /* not final, subclassed by CtxImpl */ {
void set_max_num_threads(int value);
const TracingContext& tracing() const;
TracingContext* mutable_tracing();
- CachePolicy cache_policy() const;
- void set_cache_policy(CachePolicy value);
void SetRuntimeEnabledPaths(Path paths);
Path GetRuntimeEnabledPaths();
diff --git a/ruy/ctx_impl.h b/ruy/ctx_impl.h
index 080e016..3352121 100644
--- a/ruy/ctx_impl.h
+++ b/ruy/ctx_impl.h
@@ -61,7 +61,6 @@ class CtxImpl final : public Ctx {
ThreadPool thread_pool_;
int max_num_threads_ = 1;
TracingContext tracing_;
- CachePolicy cache_policy_ = CachePolicy::kNoCache;
// Allocator for main thread work before invoking the threadpool.
// Our simple Allocator does not allow reserving/allocating more blocks
// while it's already in committed state, so the main thread needs both
diff --git a/ruy/dispatch.h b/ruy/dispatch.h
index eb0b07b..383e52c 100644
--- a/ruy/dispatch.h
+++ b/ruy/dispatch.h
@@ -388,41 +388,72 @@ struct CompileTimeEnabledReferenceMul</*ReferenceMulIsEnabled=*/false> {
}
};
-inline void HandlePrepackedCaching(TrMulParams* params,
- const SidePair<bool>& cacheable, Ctx* ctx) {
- if (ctx->cache_policy() == CachePolicy::kNoCache) {
- return;
+// Returns true if the operand on the given side should use caching of the
+// packed form. This may either be explicitly dictated by its cache_policy
+// (if it is kNeverCache, the default, or kAlwaysCache), or it may depend
+// on a heuristic decision based on the other operand's width. For example,
+// in a matrix*vector product, for the LHS matrix operand, the other side is
+// the RHS vector, with a width of 1, causing the packing of the LHS to be
+// a large fraction of the overall work, so a heuristic would typically
+// decide in favor of caching, if permitted at all by the cache_policy.
+inline bool ShouldCache(const TrMulParams& params, Side side) {
+ const CachePolicy cache_policy = params.src[side].cache_policy;
+ // The width that matters is that of the other side, it is what determines
+ // the amortization of the packing work done on the present side.
+ const Side other_side = Other(side);
+ const int other_width = params.src[other_side].layout.cols;
+ const int other_kernel_width = params.packed[other_side].layout.kernel.cols;
+ switch (cache_policy) {
+ case CachePolicy::kNeverCache:
+ return false;
+ case CachePolicy::kAlwaysCache:
+ return true;
+ case CachePolicy::kCacheIfLargeSpeedup:
+ // The condition (other_width <= other_kernel_width) means that the kernel
+ // will traverse each value of the present side only once, meaning that
+ // the overhead of the packing work will be maximal, hence maximally
+ // worth caching.
+ return (other_width <= other_kernel_width);
+ case CachePolicy::kCacheIfSignificantSpeedup:
+ // Variant of the heuristic used in the kCacheIfLargeSpeedup case. The
+ // kernel will run on each value of the present side only a few times,
+ // so packing overhead will be significant.
+ return (other_width <= 4 * other_kernel_width);
+ case CachePolicy::kCacheLikeTheOldCode:
+ return other_width <= 4;
+ default:
+ RUY_DCHECK(false);
+ return false;
}
+}
- if (ctx->cache_policy() == CachePolicy::kCacheLHSOnNarrowMul) {
- // TODO(b/149304278) Cache on dst.cols <= selected kernel width.
- if (!cacheable[Side::kLhs] || params->dst.layout.cols > 4) {
- return;
- }
- PrepackedCache* prepacked_cache = ctx->GetPrepackedCache();
- auto cache_key = std::make_pair(reinterpret_cast<void*>(params->run_kernel),
- params->src[Side::kLhs].data);
- auto it = prepacked_cache->FindAndUpdate(cache_key);
- if (it != prepacked_cache->cend()) {
- params->packed[Side::kLhs].data = it->second.first.data;
- params->packed[Side::kLhs].sums = it->second.first.sums;
- params->is_prepacked[Side::kLhs] = true;
- return;
+inline void HandlePrepackedCaching(TrMulParams* params, Ctx* ctx) {
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ if (ShouldCache(*params, side)) {
+ // Look up in cache.
+ PrepackedCache* prepacked_cache = ctx->GetPrepackedCache();
+ auto cache_key = std::make_pair(
+ reinterpret_cast<void*>(params->run_kernel), params->src[side].data);
+ auto it = prepacked_cache->FindAndUpdate(cache_key);
+ if (it != prepacked_cache->cend()) {
+ // Already cached.
+ params->packed[side].data = it->second.first.data;
+ params->packed[side].sums = it->second.first.sums;
+ params->is_prepacked[side] = true;
+ return;
+ }
+ // Not already cached. Pack and cache now.
+ PrepackedMatrix prepacked_lhs;
+ prepacked_lhs.data_size = DataSize(params->packed[side]);
+ prepacked_lhs.sums_size = SumsSize(params->packed[side]);
+ prepacked_cache->AllocatePrepackedMatrix(&prepacked_lhs);
+ params->packed[side].data = prepacked_lhs.data;
+ params->packed[side].sums = prepacked_lhs.sums;
+ params->is_prepacked[side] = true;
+ Tuning tuning = ctx->GetMainThreadTuning();
+ params->RunPack(side, tuning, 0, params->packed[side].layout.cols);
+ prepacked_cache->Insert(cache_key, prepacked_lhs);
}
-
- // Allocate the prepacked matrix.
- PrepackedMatrix prepacked_lhs;
- prepacked_lhs.data_size = DataSize(params->packed[Side::kLhs]);
- prepacked_lhs.sums_size = SumsSize(params->packed[Side::kLhs]);
- prepacked_cache->AllocatePrepackedMatrix(&prepacked_lhs);
- params->packed[Side::kLhs].data = prepacked_lhs.data;
- params->packed[Side::kLhs].sums = prepacked_lhs.sums;
- params->is_prepacked[Side::kLhs] = true;
- Tuning tuning = ctx->GetMainThreadTuning();
- params->RunPack(Side::kLhs, tuning, 0,
- params->packed[Side::kLhs].layout.cols);
- prepacked_cache->Insert(cache_key, prepacked_lhs);
- return;
}
}
@@ -477,8 +508,7 @@ void DispatchMul(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
TrMulParams params;
CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, mul_params, dst,
the_path, &params);
- SidePair<bool> cacheable(lhs.cacheable, rhs.cacheable);
- HandlePrepackedCaching(&params, cacheable, ctx);
+ HandlePrepackedCaching(&params, ctx);
TrMul(&params, ctx);
}
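
The ShouldCache heuristic above keys off the width of the other operand: in a matrix*vector product the RHS is one column wide, so packing the LHS is a large share of the total work and caching it pays off, whereas with a wide RHS the packing cost is amortized and caching is skipped. Below is a standalone sketch of that decision using plain integers instead of ruy's internal TrMulParams; it is an illustration under stated assumptions, not ruy code.

#include <cassert>

// Simplified re-declaration for this standalone sketch; the real enum is the
// one added to ruy/matrix.h in this commit.
enum class CachePolicy {
  kNeverCache,
  kCacheIfLargeSpeedup,
  kCacheIfSignificantSpeedup,
  kAlwaysCache,
  kCacheLikeTheOldCode
};

// Mirrors the ShouldCache logic above: other_width is the column count of the
// other operand, other_kernel_width the kernel tile width on that other side.
bool ShouldCacheSketch(CachePolicy policy, int other_width,
                       int other_kernel_width) {
  switch (policy) {
    case CachePolicy::kNeverCache:
      return false;
    case CachePolicy::kAlwaysCache:
      return true;
    case CachePolicy::kCacheIfLargeSpeedup:
      // The kernel traverses each packed value of the present side only once,
      // so the packing cost is maximal relative to the kernel work.
      return other_width <= other_kernel_width;
    case CachePolicy::kCacheIfSignificantSpeedup:
      return other_width <= 4 * other_kernel_width;
    case CachePolicy::kCacheLikeTheOldCode:
      return other_width <= 4;
  }
  return false;
}

int main() {
  // matrix*vector: the other side (RHS) is 1 column wide; with an assumed
  // kernel width of 8, packing the LHS dominates, so the heuristic caches.
  assert(ShouldCacheSketch(CachePolicy::kCacheIfLargeSpeedup,
                           /*other_width=*/1, /*other_kernel_width=*/8));
  // Wide matrix*matrix: packing is amortized over 512 RHS columns; no cache.
  assert(!ShouldCacheSketch(CachePolicy::kCacheIfLargeSpeedup,
                            /*other_width=*/512, /*other_kernel_width=*/8));
  return 0;
}
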
diff --git a/ruy/mat.h b/ruy/mat.h
index ba8bc38..1c88272 100644
--- a/ruy/mat.h
+++ b/ruy/mat.h
@@ -135,7 +135,7 @@ struct Mat final {
detail::ConstCheckingPtr<Scalar> data;
MatLayout layout;
Scalar zero_point = 0;
- bool cacheable = false;
+ CachePolicy cache_policy = CachePolicy::kNeverCache;
};
template <typename Scalar>
@@ -144,7 +144,7 @@ inline Mat<Scalar> ToInternal(const Matrix<Scalar>& src) {
ret.data.set(src.data());
ret.layout = ToInternal(src.layout());
ret.zero_point = src.zero_point();
- ret.cacheable = src.cacheable();
+ ret.cache_policy = src.cache_policy();
return ret;
}
@@ -154,7 +154,7 @@ inline Mat<Scalar> ToInternal(Matrix<Scalar>& src) {
ret.data.set(src.data());
ret.layout = ToInternal(src.layout());
ret.zero_point = src.zero_point();
- ret.cacheable = src.cacheable();
+ ret.cache_policy = src.cache_policy();
return ret;
}
@@ -235,6 +235,7 @@ struct EMat final {
void* data = nullptr;
MatLayout layout;
std::int32_t zero_point = 0;
+ CachePolicy cache_policy = CachePolicy::kNeverCache;
};
// Type-erased packed matrix.
@@ -277,29 +278,31 @@ EMat EraseType(const Mat<T>& matrix) {
ret.data = ToVoidPtr(matrix.data.get());
ret.layout = matrix.layout;
ret.zero_point = matrix.zero_point;
+ ret.cache_policy = matrix.cache_policy;
return ret;
}
template <typename T>
-Mat<T> UneraseType(const EMat& dmatrix) {
- dmatrix.data_type.AssertIs<T>();
+Mat<T> UneraseType(const EMat& matrix) {
+ matrix.data_type.AssertIs<T>();
Mat<T> ret;
- ret.data.set(static_cast<T*>(dmatrix.data));
- ret.layout = dmatrix.layout;
- ret.zero_point = dmatrix.zero_point;
+ ret.data.set(static_cast<T*>(matrix.data));
+ ret.layout = matrix.layout;
+ ret.zero_point = matrix.zero_point;
+ ret.cache_policy = matrix.cache_policy;
return ret;
}
template <typename T>
-PMat<T> UneraseType(const PEMat& pmatrix) {
+PMat<T> UneraseType(const PEMat& matrix) {
using SumsType = typename PMat<T>::SumsType;
- pmatrix.data_type.AssertIs<T>();
- pmatrix.sums_type.AssertIs<SumsType>();
+ matrix.data_type.AssertIs<T>();
+ matrix.sums_type.AssertIs<SumsType>();
PMat<T> ret;
- ret.data = static_cast<T*>(pmatrix.data);
- ret.sums = static_cast<SumsType*>(pmatrix.sums);
- ret.layout = pmatrix.layout;
- ret.zero_point = pmatrix.zero_point;
+ ret.data = static_cast<T*>(matrix.data);
+ ret.sums = static_cast<SumsType*>(matrix.sums);
+ ret.layout = matrix.layout;
+ ret.zero_point = matrix.zero_point;
return ret;
}
diff --git a/ruy/matrix.h b/ruy/matrix.h
index 5f480b7..b63a22b 100644
--- a/ruy/matrix.h
+++ b/ruy/matrix.h
@@ -91,9 +91,7 @@ class ConstCheckingPtr final {
ptr_ = ptr;
set_mutable(false);
}
- void set(std::nullptr_t) {
- ptr_ = nullptr;
- }
+ void set(std::nullptr_t) { ptr_ = nullptr; }
T* get() /* NOT const */ {
assert_mutable();
return const_cast<T*>(ptr_);
@@ -116,6 +114,18 @@ class ConstCheckingPtr final {
} // namespace detail
+enum class CachePolicy : std::uint8_t {
+ kNeverCache,
+ kCacheIfLargeSpeedup,
+ kCacheIfSignificantSpeedup,
+ kAlwaysCache,
+ // transitional value emulating old behavior
+ kCacheLikeTheOldCode,
+ // legacy values used when the CachePolicy was a Context property
+ kNoCache,
+ kCacheLHSOnNarrowMul
+};
+
// A Matrix merely wraps existing data as a matrix. It doesn't own any buffer.
// The purpose of Matrix is only to be used in ruy's interface -- it's just
// a structured way for the user to pass to ruy::Mul the matrix data pointers
@@ -141,8 +151,15 @@ class Matrix final {
Layout* mutable_layout() { return &layout_; }
Scalar zero_point() const { return zero_point_; }
void set_zero_point(Scalar value) { zero_point_ = value; }
- bool cacheable() const { return cacheable_; }
- void set_cacheable(bool value) { cacheable_ = value; }
+ CachePolicy cache_policy() const { return cache_policy_; }
+ void set_cache_policy(CachePolicy value) { cache_policy_ = value; }
+
+ // legacy for compatibility, essentially preserving old behavior for existing
+ // callers during the transition to set_cache_threshold.
+ void set_cacheable(bool value) {
+ set_cache_policy(value ? CachePolicy::kCacheLikeTheOldCode
+ : CachePolicy::kNeverCache);
+ }
private:
// The underlying buffer wrapped by this matrix.
@@ -152,10 +169,13 @@ class Matrix final {
// The zero_point, i.e. which Scalar value is to be interpreted as zero.
// When Scalar is floating-point, this must be 0.
Scalar zero_point_ = 0;
- // Clients of Ruy must set this flag to enable any caching behavior. Doesn't
- // impact numerical results, but caching can impact observable metrics like
- // latency, memory usage, power, etc.
- bool cacheable_ = false;
+ // When the data pointed to by this matrix is constant data, so that it is
+ // valid to assume that equality of pointers implies equality of data,
+ // a CachePolicy may be used instead of the default kNeverCache,
+ // which will enable ruy to take advantage of this constancy of the data to
+ // cache the packing work, which can be a large speedup in matrix*vector
+ // and other narrow shapes.
+ CachePolicy cache_policy_ = CachePolicy::kNeverCache;
};
inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) {
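
For existing callers, the legacy boolean set_cacheable(bool) now maps onto the new policy exactly as the shim added above: true becomes kCacheLikeTheOldCode, false becomes kNeverCache. A short hedged sketch of that mapping follows; the helper name is hypothetical and not part of the commit.

#include "ruy/matrix.h"  // defines ruy::Matrix and, after this commit, ruy::CachePolicy

// Hypothetical migration helper (illustration only): forwards the old boolean
// flag to the new per-matrix policy, matching the set_cacheable() shim above.
template <typename Scalar>
void SetCacheableCompat(ruy::Matrix<Scalar>* matrix, bool cacheable) {
  matrix->set_cache_policy(cacheable ? ruy::CachePolicy::kCacheLikeTheOldCode
                                     : ruy::CachePolicy::kNeverCache);
}
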
diff --git a/ruy/matrix_test.cc b/ruy/matrix_test.cc
index 04ee7ae..0f3fd13 100644
--- a/ruy/matrix_test.cc
+++ b/ruy/matrix_test.cc
@@ -70,7 +70,7 @@ TEST(MatrixTest, MatrixClassSanity) {
Matrix<int> matrix;
EXPECT_EQ(matrix.data(), nullptr);
EXPECT_EQ(matrix.zero_point(), 0);
- EXPECT_EQ(matrix.cacheable(), false);
+ EXPECT_EQ(matrix.cache_policy(), CachePolicy::kNeverCache);
EXPECT_EQ(matrix.layout().rows(), 0);
EXPECT_EQ(matrix.layout().cols(), 0);
EXPECT_EQ(matrix.layout().stride(), 0);
@@ -78,14 +78,14 @@ TEST(MatrixTest, MatrixClassSanity) {
const int some_const = 0;
matrix.set_data(&some_const);
matrix.set_zero_point(123);
- matrix.set_cacheable(true);
+ matrix.set_cache_policy(CachePolicy::kAlwaysCache);
MakeSimpleLayout(12, 34, Order::kRowMajor, matrix.mutable_layout());
EXPECT_EQ(static_cast<const Matrix<int>&>(matrix).data(), &some_const);
#ifndef NDEBUG
RUY_ASSERT_DEATH(matrix.data(), "");
#endif
EXPECT_EQ(matrix.zero_point(), 123);
- EXPECT_EQ(matrix.cacheable(), true);
+ EXPECT_EQ(matrix.cache_policy(), CachePolicy::kAlwaysCache);
EXPECT_EQ(matrix.layout().rows(), 12);
EXPECT_EQ(matrix.layout().cols(), 34);
EXPECT_EQ(matrix.layout().stride(), 34);
diff --git a/ruy/prepacked_cache.h b/ruy/prepacked_cache.h
index f8c731f..0e0fdda 100644
--- a/ruy/prepacked_cache.h
+++ b/ruy/prepacked_cache.h
@@ -62,8 +62,6 @@ class SystemBlockAllocator {
} // namespace detail
-enum class CachePolicy { kNoCache, kCacheLHSOnNarrowMul };
-
// "Low effort" Least Recently Used Cache for Prepacked Matrices
// A cache mechanism for prepacked matrices that ejects oldest entries.
// The implementation is "low effort" in the following ways:
diff --git a/ruy/prepacked_cache_test.cc b/ruy/prepacked_cache_test.cc
index 9459b25..5b74236 100644
--- a/ruy/prepacked_cache_test.cc
+++ b/ruy/prepacked_cache_test.cc
@@ -136,11 +136,9 @@ TEST(PrepackedCacheTest, TestCacheEjection2) {
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key4), prepacked_cache.cend());
}
-TEST(PrepackedCacheTest, TestCacheOnCacheable) {
- // Create context and set the cache policy
+void TestCacheOnCacheable(CachePolicy cache_policy, bool expected_cached) {
ruy::Context context;
ruy::Ctx* ctx = get_ctx(&context);
- context.set_cache_policy(ruy::CachePolicy::kCacheLHSOnNarrowMul);
PrepackedCache* cache = ctx->GetPrepackedCache();
EXPECT_EQ(cache->TotalSize(), 0);
@@ -163,17 +161,26 @@ TEST(PrepackedCacheTest, TestCacheOnCacheable) {
ruy::Mul<ruy::kAllPaths>(lhs, rhs, mul_params, &context, &dst);
EXPECT_EQ(cache->TotalSize(), 0);
- // Set cacheable for the LHS, repeat the multiplication, and see
+ // Set cache policy for the LHS, repeat the multiplication, and see
// that caching did occur.
- lhs.set_cacheable(true);
+ lhs.set_cache_policy(cache_policy);
ruy::Mul<ruy::kAllPaths>(lhs, rhs, mul_params, &context, &dst);
- EXPECT_NE(cache->TotalSize(), 0);
+ const bool actual_cached = cache->TotalSize() > 0;
+ EXPECT_EQ(actual_cached, expected_cached);
+}
+
+TEST(PrepackedCacheTest, TestCacheOnCacheable) {
+ for (CachePolicy cache_policy :
+ {CachePolicy::kNeverCache, CachePolicy::kCacheIfLargeSpeedup,
+ CachePolicy::kCacheIfSignificantSpeedup, CachePolicy::kAlwaysCache,
+ CachePolicy::kCacheLikeTheOldCode}) {
+ TestCacheOnCacheable(cache_policy,
+ cache_policy != CachePolicy::kNeverCache);
+ }
}
TEST(PrepackedCacheTest, TestClearCache) {
- // Create context and set the cache policy
ruy::Context context;
- context.set_cache_policy(ruy::CachePolicy::kCacheLHSOnNarrowMul);
PrepackedCache* cache = get_ctx(&context)->GetPrepackedCache();
EXPECT_EQ(cache->TotalSize(), 0);
@@ -193,7 +200,7 @@ TEST(PrepackedCacheTest, TestClearCache) {
ruy::MulParams<float, float> mul_params;
// Set cacheable for the LHS and see that caching occurs.
- lhs.set_cacheable(true);
+ lhs.set_cache_policy(CachePolicy::kAlwaysCache);
ruy::Mul<ruy::kAllPaths>(lhs, rhs, mul_params, &context, &dst);
EXPECT_NE(cache->TotalSize(), 0);
diff --git a/ruy/side_pair.h b/ruy/side_pair.h
index 212c054..f0d0973 100644
--- a/ruy/side_pair.h
+++ b/ruy/side_pair.h
@@ -30,6 +30,10 @@ enum class Side {
kRhs = 1
};
+inline Side Other(Side side) {
+ return side == Side::kLhs ? Side::kRhs : Side::kLhs;
+}
+
// SidePair is a pair container where the two elements are indexed by a Side
// enum.
template <typename T>