Welcome to mirror list, hosted at ThFree Co, Russian Federation.

dispatch.h « ruy - github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 2fd50d06fc4b4b548d7ccd3ce62336da7a14673d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements the translation between Ruy's entry point (ruy::Mul) and
// the internal implementation of matrix multiplication.
//
// The primary elements of this dispatch are:
// - pick suitable gemm kernel and packing routines for the user-specified
// CompiledPaths based on the current CPU.
// - decide on the structure of the packed matrices needed by the internal
// implementation (see pack.h for more information on packing).
// - translate the Mul operation into TrMul (see trmul.h for why that is
// useful). This is done by changing the matrix Layout -- no matrix data is
// actually moved.
//
// This file is also factored to serve as a building block for the advanced
// API.
//
// This file also performs some checking of invariants to catch user errors.

#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_

#include <algorithm>
#include <cstdint>
#include <limits>  // IWYU pragma: keep
#include <type_traits>

#include "ruy/check_macros.h"
#include "ruy/common.h"
#include "ruy/context.h"
#include "ruy/internal_matrix.h"
#include "ruy/kernel.h"
#include "ruy/kernel_common.h"
#include "ruy/matrix.h"
#include "ruy/opt_set.h"
#include "ruy/pack.h"
#include "ruy/pack_common.h"
#include "ruy/path.h"
#include "ruy/profiler/instrumentation.h"
#include "ruy/side_pair.h"
#include "ruy/size_util.h"
#include "ruy/spec.h"
#include "ruy/trmul.h"
#include "ruy/trmul_params.h"

namespace ruy {

// Some Specs restrict which storage orders the operands may use. The only
// restricted mode handled here is LayoutSupport::kRCC, which requires a
// row-major LHS and column-major RHS and destination; for such Specs this
// RUY_DCHECKs that the matrix multiplication at hand actually matches.
template <typename Spec>
void EnforceLayoutSupport(const Layout& lhs_layout, const Layout& rhs_layout,
                          const Layout& dst_layout) {
  const bool requires_rcc = Spec::kLayoutSupport == LayoutSupport::kRCC;
  if (!requires_rcc) {
    return;
  }
  RUY_DCHECK(IsRowMajor(lhs_layout));
  RUY_DCHECK(IsColMajor(rhs_layout));
  RUY_DCHECK(IsColMajor(dst_layout));
}

// Returns true when `zero_point` equals the value returned by
// SymmetricZeroPoint<Scalar>().
template <typename Scalar>
bool IsSymmetricZeroPoint(Scalar zero_point) {
  const auto symmetric_value = SymmetricZeroPoint<Scalar>();
  return symmetric_value == zero_point;
}

// RUY_DCHECKs that `zero_point` is symmetric whenever a non-symmetric zero
// point is not allowed: always for floating-point Scalar types, and for any
// Scalar when the Spec declares ZeroPointSupport::kSymmetric.
template <typename Spec, typename Scalar>
void CheckZeroPoint(Scalar zero_point) {
  constexpr bool kIsFloatingPoint = std::is_floating_point<Scalar>::value;
  const bool spec_requires_symmetric =
      Spec::kZeroPointSupport == ZeroPointSupport::kSymmetric;
  if (!kIsFloatingPoint && !spec_requires_symmetric) {
    return;
  }
  RUY_DCHECK(IsSymmetricZeroPoint(zero_point));
}

// Debug-checks the zero points of all three operands against the Spec's
// ZeroPointSupport, and rejects the LHS/RHS zero-point combination that is
// unsafe for some kernels (see below).
template <typename Spec, typename LhsScalar, typename RhsScalar,
          typename DstScalar>
void EnforceZeroPointSupport(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point,
                             DstScalar dst_zero_point) {
  // A Spec whose ZeroPointSupport covers only special cases requires every
  // operand to fall into that special case; CheckZeroPoint verifies each one.
  CheckZeroPoint<Spec>(lhs_zero_point);
  CheckZeroPoint<Spec>(rhs_zero_point);
  CheckZeroPoint<Spec>(dst_zero_point);

  // The LHS and RHS zero points must not BOTH be the lowest representable
  // value. Padding with such zero points would produce the one bad input for
  // the fast pre-dotprod NEON int8 kernels, which multiply-accumulate two
  // pairs of int8 into an int16: -128*-128 + -128*-128 = 32768 does not fit.
  // See b/131609283. Only the kNeon path is affected, but the restriction is
  // applied on every path so that all paths accept the same parameter space.
  RUY_DCHECK(!(lhs_zero_point == std::numeric_limits<LhsScalar>::lowest() &&
               rhs_zero_point == std::numeric_limits<RhsScalar>::lowest()));
}

// When the destination scalar is std::int32_t the caller wants the raw
// accumulators, so the output-stage fields (destination zero point, clamp
// bounds and multipliers) are meaningless. This RUY_DCHECKs that they hold
// the values listed below: zero, the full int32 range, and null per-channel
// pointers. For any other destination scalar it checks nothing at runtime.
template <typename Spec, typename DstScalar>
void EnforceDstSpecSupport(const Spec& spec, DstScalar dst_zero_point) {
  static_assert(std::is_same<typename Spec::DstScalar, DstScalar>::value, "");
  constexpr bool kRawAccumulatorOutput =
      std::is_same<typename Spec::DstScalar, std::int32_t>::value;
  if (kRawAccumulatorOutput) {
    RUY_DCHECK_EQ(dst_zero_point, 0);
    RUY_DCHECK_EQ(spec.clamp_max, std::numeric_limits<std::int32_t>::max());
    RUY_DCHECK_EQ(spec.clamp_min, std::numeric_limits<std::int32_t>::min());
    RUY_DCHECK_EQ(spec.multiplier_fixedpoint, 0);
    RUY_DCHECK_EQ(spec.multiplier_exponent, 0);
    RUY_DCHECK_EQ(spec.multiplier_fixedpoint_perchannel, nullptr);
    RUY_DCHECK_EQ(spec.multiplier_exponent_perchannel, nullptr);
  }
}

inline bool IsColMajorTrMul(const TrMulParams& params) {
  return IsColMajor(params.src[Side::kLhs].layout) &&
         IsColMajor(params.src[Side::kRhs].layout) &&
         IsColMajor(params.dst.layout);
}

// Describes, in `*packed`, the layout of the packed copy of a matrix whose
// source layout is `src`, packed for a kernel with block shape
// `kernel_layout` and element type `scalar`.
//
// The packed matrix is always column-major, and its rows/cols are rounded up
// with round_up_pot to the kernel block's rows/cols. When
// RUY_OPT_AVOID_ALIASING is enabled and one column occupies a multiple of
// 1024 bytes, the stride is padded by 64 elements (presumably so consecutive
// columns do not alias in the cache).
inline void CreatePackedLayout(const Layout& src, const Type& scalar,
                               const KernelLayout& kernel_layout,
                               PackedLayout* packed) {
  packed->order = Order::kColMajor;
  packed->rows = round_up_pot(src.rows, kernel_layout.rows);
  packed->cols = round_up_pot(src.cols, kernel_layout.cols);
  packed->kernel = kernel_layout;

  const int column_length = packed->rows;
  int stride = column_length;
  if (RUY_OPT_ENABLED(RUY_OPT_AVOID_ALIASING)) {
    const bool column_bytes_multiple_of_1k =
        (column_length * scalar.size) % 1024 == 0;
    if (column_bytes_multiple_of_1k) {
      stride += 64;
    }
  }
  packed->stride = stride;
}

// Initializes params->packed[side] from params->src[side]: records the packed
// element type and sums type, computes the packed layout for
// `kernel_layout`, and converts the source zero point to PackedScalar.
template <typename Scalar, typename PackedScalar>
void CreatePackedMatrix(Side side, const KernelLayout& kernel_layout,
                        TrMulParams* params) {
  // Quantized multiplication in ruy always accumulates in 32-bit signed
  // integers, so std::int32_t is the sums type for non-floating-point
  // scalars. Floating-point scalars still need some reasonable sums type to
  // avoid tripping assertions elsewhere, so they reuse Scalar itself.
  using SumsType =
      typename std::conditional<std::is_floating_point<Scalar>::value, Scalar,
                                std::int32_t>::type;

  const DMatrix& source = params->src[side];
  PMatrix& packed_matrix = params->packed[side];
  packed_matrix.data_type = Type::Create<PackedScalar>();
  packed_matrix.sums_type = Type::Create<SumsType>();
  CreatePackedLayout(source.layout, packed_matrix.data_type, kernel_layout,
                     &packed_matrix.layout);
  packed_matrix.zero_point = Pack<PackedScalar, Scalar>(source.zero_point);
}

// Fills the path-specific parts of `params` for the compile-time path
// ThePath: the chosen path, cache-size hints from the Spec, both packed
// matrix descriptions, and the packing and kernel function pointers.
// Path::kReference is rejected at compile time because it implements a Mul,
// not a TrMul.
template <Path ThePath, typename LhsScalar, typename RhsScalar,
          typename DstScalar, typename Spec>
void PopulateTrMulParams(TrMulParams* params) {
  static_assert((ThePath & Path::kReference) == Path::kNone,
                "Path::kReference should not do TrMul");

  // The optimized paths only handle the case where every matrix is
  // column-major; anything else is delegated to Path::kStandardCpp, which
  // handles the full generality of Ruy's API.
  if (ThePath != Path::kStandardCpp && !IsColMajorTrMul(*params)) {
    PopulateTrMulParams<Path::kStandardCpp, LhsScalar, RhsScalar, DstScalar,
                        Spec>(params);
    return;
  }

  using PackedLhsScalar = PackedType<ThePath, LhsScalar>;
  using PackedRhsScalar = PackedType<ThePath, RhsScalar>;
  using KernelType =
      Kernel<ThePath, PackedLhsScalar, PackedRhsScalar, DstScalar, Spec>;
  using LhsKernelLayout = typename KernelType::LhsLayout;
  using RhsKernelLayout = typename KernelType::RhsLayout;

  params->path = ThePath;
  params->local_data_cache_size = Spec::local_data_cache_size();
  params->shared_data_cache_size = Spec::shared_data_cache_size();

  // First describe both packed operands, then install the functions that
  // pack them and run the kernel on the packed data.
  CreatePackedMatrix<LhsScalar, PackedLhsScalar>(
      Side::kLhs, ToKernelLayout<LhsKernelLayout>(), params);
  CreatePackedMatrix<RhsScalar, PackedRhsScalar>(
      Side::kRhs, ToKernelLayout<RhsKernelLayout>(), params);
  params->run_pack[Side::kLhs] =
      &RunPack<ThePath, LhsKernelLayout, LhsScalar, PackedLhsScalar>;
  params->run_pack[Side::kRhs] =
      &RunPack<ThePath, RhsKernelLayout, RhsScalar, PackedRhsScalar>;
  params->run_kernel =
      &RunKernel<ThePath, PackedLhsScalar, PackedRhsScalar, DstScalar, Spec>;
}

// PopulateTrMulParamsAllCompiledPaths calls into one of multiple
// instantiations of PopulateTrMulParams. For each bit that is set in
// CompiledPaths, it statically instantiates PopulateTrMulParams with a Path
// corresponding to that single bit. The call to PopulateTrMulParams is
// guarded by a runtime check that it is in fact the dynamically selected path.
//
// PopulateTrMulParamsAllCompiledPaths is implemented with template
// metaprogramming by mutual recursion between PathSearchCountdown and
// PathSearchOnlyCompiledPaths.
//
// PopulateTrMulParamsAllCompiledPaths is logically implementing the following
// computation:
//
// template <Path CompiledPaths>
// void PopulateTrMulParamsAllCompiledPaths(Path the_path,
//                                          TrMulParams* params) {
//   for (int bit = 8 * sizeof(Path) - 1; bit != -1; bit--) { // [1]
//     Path current_path = static_cast<Path>(1 << bit);
//     if ((CompiledPaths & current_path) != Path::kNone) { // [2]
//       if (current_path == the_path) { // [3]
//         PopulateTrMulParams<current_path, ...>(params);
//         return;
//       }
//     }
//   }
// }
//
//
//
// [1] - Done by the main definition of PathSearchCountdown. The `bit--` is
// done in the recursion of PathSearchOnlyCompiledPaths.
// [2] - Done by PathSearchOnlyCompiledPaths's partial template
// specialization on InCompiledPaths. This is the check which necessitates
// doing the whole computation at C++ compile time.
// [3] - Done by the `if` in the main definition of
// PathSearchOnlyCompiledPaths.
//
// The template metaprogramming is necessary because:
// - In `PopulateTrMulParams<current_path, ...>`, current_path must be a C++
// compile-time constant.
// - PopulateTrMulParamsAllCompiledPaths must not instantiate
// inner loops for paths that are not in CompiledPaths, since that can result in
// bogus instantiations which cause a compile time failure.
// Forward declaration. PathSearchOnlyCompiledPaths (defined next) calls
// PathSearchCountdown<..., BitNumber - 1, ...>::Search, and
// PathSearchCountdown's own definition refers back to
// PathSearchOnlyCompiledPaths, so one of them must be declared first.
template <Path CompiledPaths, int BitNumber, typename LhsScalar,
          typename RhsScalar, typename DstScalar, typename Spec>
struct PathSearchCountdown;

// Step of the path search for bit BitNumber when that bit's Path IS in
// CompiledPaths (the InCompiledPaths == false case is specialized below).
// If the runtime-selected `the_path` is this bit's Path, PopulateTrMulParams
// is run for it; otherwise the countdown continues at BitNumber - 1.
template <Path CompiledPaths, bool InCompiledPaths, int BitNumber,
          typename LhsScalar, typename RhsScalar, typename DstScalar,
          typename Spec>
struct PathSearchOnlyCompiledPaths {
  static constexpr Path kCurrentPath = static_cast<Path>(1 << BitNumber);
  static void Search(Path the_path, TrMulParams* params) {
    if (the_path != kCurrentPath) {
      using NextStep = PathSearchCountdown<CompiledPaths, BitNumber - 1,
                                           LhsScalar, RhsScalar, DstScalar,
                                           Spec>;
      NextStep::Search(the_path, params);
      return;
    }
    PopulateTrMulParams<kCurrentPath, LhsScalar, RhsScalar, DstScalar, Spec>(
        params);
  }
};

// Step of the path search for bit BitNumber when that bit's Path is NOT in
// CompiledPaths: no PopulateTrMulParams is instantiated for it (which avoids
// bogus instantiations), and the countdown simply moves on to BitNumber - 1.
template <Path CompiledPaths, int BitNumber, typename LhsScalar,
          typename RhsScalar, typename DstScalar, typename Spec>
struct PathSearchOnlyCompiledPaths<CompiledPaths, false, BitNumber, LhsScalar,
                                   RhsScalar, DstScalar, Spec> {
  static void Search(Path the_path, TrMulParams* params) {
    using NextStep = PathSearchCountdown<CompiledPaths, BitNumber - 1,
                                         LhsScalar, RhsScalar, DstScalar, Spec>;
    NextStep::Search(the_path, params);
  }
};

// One iteration of the countdown over the bits of Path. It decides at compile
// time whether bit BitNumber names a Path contained in CompiledPaths and
// forwards to the matching PathSearchOnlyCompiledPaths specialization.
template <Path CompiledPaths, int BitNumber, typename LhsScalar,
          typename RhsScalar, typename DstScalar, typename Spec>
struct PathSearchCountdown {
  static constexpr Path kCurrentPath = static_cast<Path>(1 << BitNumber);
  static void Search(Path the_path, TrMulParams* params) {
    constexpr bool kIsCompiledPath =
        (CompiledPaths & kCurrentPath) != Path::kNone;
    using Step =
        PathSearchOnlyCompiledPaths<CompiledPaths, kIsCompiledPath, BitNumber,
                                    LhsScalar, RhsScalar, DstScalar, Spec>;
    Step::Search(the_path, params);
  }
};

// Termination of the countdown. Reaching bit -1 means no bit of CompiledPaths
// matched `the_path`, i.e. the requested path was not found among the
// compiled paths. This is reported via RUY_DCHECK only; note that this
// function does not write to `params`, so nothing is populated in that case.
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
          typename DstScalar, typename Spec>
struct PathSearchCountdown<CompiledPaths, -1, LhsScalar, RhsScalar, DstScalar,
                           Spec> {
  static void Search(Path the_path, TrMulParams* params) {
    // The arguments are intentionally unused here; the casts keep
    // -Wunused-parameter quiet.
    (void)the_path;
    (void)params;
    RUY_DCHECK(false);
  }
};

// Entry point of the compile-time path search. It starts the countdown at
// the most significant bit of Path; each compiled path is compared against
// the runtime-selected `the_path` on the way down to bit 0.
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
          typename DstScalar, typename Spec>
void PopulateTrMulParamsAllCompiledPaths(Path the_path, TrMulParams* params) {
  constexpr int kMostSignificantBit = 8 * sizeof(Path) - 1;
  using Countdown = PathSearchCountdown<CompiledPaths, kMostSignificantBit,
                                        LhsScalar, RhsScalar, DstScalar, Spec>;
  Countdown::Search(the_path, params);
}

template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
          typename DstScalar, typename Spec>
void CreateTrMulParams(const Matrix<LhsScalar>& lhs,
                       const Matrix<RhsScalar>& rhs, const Spec& spec,
                       Context* context, Matrix<DstScalar>* dst, Path the_path,
                       TrMulParams* params) {
  // Fill in the fields we already know.
  params->src[Side::kLhs] = ToDMatrix(lhs);
  params->src[Side::kRhs] = ToDMatrix(rhs);
  params->dst = ToDMatrix(*dst);
  params->spec = ToVoidPtr(&spec);

  // Create inner loops and packed matrices based on the Path.
  PopulateTrMulParamsAllCompiledPaths<CompiledPaths, LhsScalar, RhsScalar,
                                      DstScalar, Spec>(the_path, params);
}

// Straightforward triple-loop Mul, used when Path::kReference is selected.
// For each destination entry (row, col) it:
//   1. accumulates (lhs - lhs.zero_point) * (rhs - rhs.zero_point) over the
//      depth dimension in Spec::AccumScalar,
//   2. adds spec.bias[row] when a bias is provided,
//   3. applies ApplyMultiplier, then adds dst->zero_point,
//   4. clamps to [spec.clamp_min, spec.clamp_max] and stores as DstScalar.
template <typename LhsScalar, typename RhsScalar, typename DstScalar,
          typename Spec>
void ReferenceMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
                  const Spec& spec, Matrix<DstScalar>* dst) {
  profiler::ScopeLabel label("ReferenceMul");
  using AccumScalar = typename Spec::AccumScalar;

  const int num_rows = lhs.layout.rows;
  const int depth = lhs.layout.cols;
  const int num_cols = rhs.layout.cols;

  for (int row = 0; row < num_rows; ++row) {
    for (int col = 0; col < num_cols; ++col) {
      AccumScalar sum = 0;
      for (int k = 0; k < depth; ++k) {
        const AccumScalar lhs_entry = Element(lhs, row, k);
        const AccumScalar rhs_entry = Element(rhs, k, col);
        sum += (lhs_entry - lhs.zero_point) * (rhs_entry - rhs.zero_point);
      }
      if (spec.bias) {
        sum += spec.bias[row];
      }
      ApplyMultiplier(spec, row, &sum);
      sum += dst->zero_point;
      const AccumScalar upper_clamped =
          std::min<AccumScalar>(sum, spec.clamp_max);
      sum = std::max<AccumScalar>(upper_clamped, spec.clamp_min);
      *ElementPtr(dst, row, col) = static_cast<DstScalar>(sum);
    }
  }
}

// Compile-time dispatch to ReferenceMul. This allows us to statically ensure
// that there is no call to ReferenceMul in the user's binary.
template <bool ReferenceMulIsEnabled>
struct CompileTimeEnabledReferenceMul {
  template <typename LhsScalar, typename RhsScalar, typename DstScalar,
            typename Spec>
  static void Run(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
                  const Spec& spec, Matrix<DstScalar>* dst) {
    ReferenceMul(lhs, rhs, spec, dst);
  }
};

// When this partial specialization is chosen, it ensures that ReferenceMul
// is never compiled.
template <>
struct CompileTimeEnabledReferenceMul</*ReferenceMulIsEnabled=*/false> {
  template <typename LhsScalar, typename RhsScalar, typename DstScalar,
            typename Spec>
  static void Run(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
                  const Spec& spec, Matrix<DstScalar>* dst) {
    RUY_DCHECK(false);
  }
};

// Optionally substitutes a cached, already-packed LHS into `params`.
//
// This only acts under CachePolicy::kCacheLHSOnNarrowMul, when the LHS is
// marked cacheable and the destination has at most 4 columns. The packed LHS
// is looked up in the context's PrepackedCache under the key
// (kernel function pointer, LHS data pointer):
//   - on a hit, the cached data/sums buffers are reused;
//   - on a miss, new buffers sized for the packed LHS are allocated, the LHS
//     is packed into them right away, and they are inserted into the cache.
// In both cases params->is_prepacked[Side::kLhs] is set to true (presumably
// so that TrMul does not pack the LHS again).
//
// NOTE(review): the key holds only the data pointer, not the layout or zero
// point; callers presumably must not reuse a cached address for different
// LHS contents.
inline void HandlePrepackedCaching(TrMulParams* params,
                                   const SidePair<bool>& cacheable,
                                   Context* context) {
  if (context->cache_policy != CachePolicy::kCacheLHSOnNarrowMul) {
    return;
  }

  // TODO(b/149304278) Cache on dst.cols <= selected kernel width.
  const bool is_narrow_mul = params->dst.layout.cols <= 4;
  if (!(cacheable[Side::kLhs] && is_narrow_mul)) {
    return;
  }

  auto& packed_lhs = params->packed[Side::kLhs];
  PrepackedCache* cache = context->GetPrepackedCache();
  auto cache_key = std::make_pair(reinterpret_cast<void*>(params->run_kernel),
                                  params->src[Side::kLhs].data);

  auto cached = cache->FindAndUpdate(cache_key);
  if (cached != cache->cend()) {
    packed_lhs.data = cached->second.first.data;
    packed_lhs.sums = cached->second.first.sums;
    params->is_prepacked[Side::kLhs] = true;
    return;
  }

  // Cache miss: allocate the prepacked matrix, pack into it now, and hand it
  // over to the cache.
  PrepackedMatrix fresh_lhs;
  fresh_lhs.data_size = DataSize(packed_lhs);
  fresh_lhs.sums_size = SumsSize(packed_lhs);
  cache->AllocatePrepackedMatrix(&fresh_lhs);
  packed_lhs.data = fresh_lhs.data;
  packed_lhs.sums = fresh_lhs.sums;
  params->is_prepacked[Side::kLhs] = true;
  Tuning tuning = context->GetMainThreadTuning();
  params->RunPack(Side::kLhs, tuning, 0, packed_lhs.layout.cols);
  cache->Insert(cache_key, fresh_lhs);
}

template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
          typename DstScalar, typename Spec>
void DispatchMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
                 const Spec& spec, Context* context, Matrix<DstScalar>* dst) {
  static_assert(CompiledPaths != Path::kNone, "Must compile at least one Path");
  static_assert((CompiledPaths & ~kAllPaths) == Path::kNone,
                "CompiledPaths must be a subset of ruy::kAllPaths");

  profiler::ScopeLabel mul_label("Mul");
  profiler::ScopeLabel shape_specific_label("matmul shape: %dx%dx%d",
                                            lhs.layout.rows, lhs.layout.cols,
                                            rhs.layout.cols);

  EnforceLayoutSupport<Spec>(lhs.layout, rhs.layout, dst->layout);
  EnforceZeroPointSupport<Spec>(lhs.zero_point, rhs.zero_point,
                                dst->zero_point);
  EnforceDstSpecSupport<Spec>(spec, dst->zero_point);

  // This should be a constant, for a given machine and CompiledPaths.
  // There is a back door to override it for testing, but in production it will
  // always be the "best" Path. I.e. the one with the newest SIMD instructions
  // available on the present machine, and avoiding Path::kReference unless
  // no other path is compiled.
  //
  // Unfortunately, it is not a *static* constant, since it depends on runtime
  // detection of the available SIMD instructions.
  Path the_path = context->GetPathToTake<CompiledPaths>();

  // Production code should probably never execute Path::kReference.
  // Path::kReference implements a Mul, not a TrMul like the rest of Ruy, so if
  // that's what we need to do, then get it out of the way before going down the
  // TrMul path.
  if (the_path == Path::kReference) {
    constexpr bool ReferenceMulIsEnabled =
        (CompiledPaths & Path::kReference) != Path::kNone;
    CompileTimeEnabledReferenceMul<ReferenceMulIsEnabled>::Run(lhs, rhs, spec,
                                                               dst);
    return;
  }

  // As described in the comment at the top of this file, Ruy internally
  // converts Mul into TrMul. We handle that here.
  //
  // This is Ruy's main code path.
  constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference;
  Matrix<LhsScalar> transposed_lhs(lhs);
  Transpose(&transposed_lhs);
  TrMulParams params;
  CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, spec, context, dst,
                                        the_path, &params);
  SidePair<bool> cacheable(lhs.cacheable, rhs.cacheable);
  HandlePrepackedCaching(&params, cacheable, context);
  TrMul(&params, context);
}

}  // namespace ruy

#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_