diff options
Diffstat (limited to 'ruy/kernel_x86.h')
-rw-r--r-- | ruy/kernel_x86.h | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h index d2045de..051c894 100644 --- a/ruy/kernel_x86.h +++ b/ruy/kernel_x86.h @@ -111,6 +111,24 @@ struct Kernel<Path::kAvx2Fma, std::int8_t, std::int8_t, std::int32_t, } }; +template <typename DstScalar> +struct Kernel<Path::kAvx2Fma, std::int8_t, std::int16_t, std::int32_t, + DstScalar> { + static constexpr Path kPath = Path::kAvx2Fma; + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; + using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PMat<std::int8_t>& lhs, const PMat<std::int16_t>& rhs, + const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, + int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { + KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; + MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, + end_col, dst, ¶ms); + Kernel8bitAvx2(params); + } +}; + void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params); void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params); |