
github.com/google/ruy.git
author     Benoit Jacob <benoitjacob@google.com>  2020-03-28 04:58:51 +0300
committer  Benoit Jacob <benoitjacob@google.com>  2020-03-30 23:51:39 +0300
commit     f7ea583082c670103fb2cebd6035b944c71d64c4 (patch)
tree       4a58b4b3a210fc78776e16cf19b46e671714ac4a /ruy
parent     299a33a5c2affb88c75726c77be6dd4491418b17 (diff)
Move ruy's code to a ruy/ subdirectory.
The motivation is that having source files in the repository root runs into a number of corner cases with copybara setups and with external CMake build systems. Enclosing all code in ruy/ avoids that, while generally making our setup much more similar to that of other related projects (TensorFlow, IREE).

PiperOrigin-RevId: 303448881
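As a concrete illustration of what the move implies (a sketch: the "after" form is taken from this diff, the "before" form is the root-relative include it replaces and is shown for illustration only), C++ includes and Bazel labels now consistently carry the ruy/ prefix, as in #include "ruy/allocator.h" and deps entries such as "//ruy/profiler:instrumentation" in the new BUILD file below.

    // Before (sources at the repository root; illustrative):
    #include "allocator.h"
    // After this commit (all sources under ruy/):
    #include "ruy/allocator.h"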
Diffstat (limited to 'ruy')
-rw-r--r--  ruy/BUILD  954
-rw-r--r--  ruy/allocator.cc  51
-rw-r--r--  ruy/allocator.h  185
-rw-r--r--  ruy/allocator_test.cc  103
-rw-r--r--  ruy/benchmark.cc  196
-rw-r--r--  ruy/block_map.cc  486
-rw-r--r--  ruy/block_map.h  161
-rw-r--r--  ruy/block_map_test.cc  263
-rw-r--r--  ruy/blocking_counter.cc  49
-rw-r--r--  ruy/blocking_counter.h  62
-rw-r--r--  ruy/build_defs.bzl  54
-rw-r--r--  ruy/build_defs.bzl.opensource  40
-rw-r--r--  ruy/check_macros.h  138
-rw-r--r--  ruy/check_macros_test.cc  153
-rw-r--r--  ruy/common.h  73
-rw-r--r--  ruy/context.cc  109
-rw-r--r--  ruy/context.h  109
-rw-r--r--  ruy/context_test.cc  63
-rw-r--r--  ruy/cpu_cache_size.h  81
-rw-r--r--  ruy/detect_arm.cc  73
-rw-r--r--  ruy/detect_arm.h  29
-rw-r--r--  ruy/detect_x86.cc  101
-rw-r--r--  ruy/detect_x86.h  49
-rw-r--r--  ruy/dispatch.h  482
-rw-r--r--  ruy/example.cc  136
-rw-r--r--  ruy/example_advanced.cc  83
-rw-r--r--  ruy/have_built_path_for.h  32
-rw-r--r--  ruy/have_built_path_for_avx2.cc  35
-rw-r--r--  ruy/have_built_path_for_avx512.cc  35
-rw-r--r--  ruy/have_built_path_for_avxvnni.cc  39
-rw-r--r--  ruy/have_built_path_for_sse42.cc  39
-rw-r--r--  ruy/internal_matrix.h  388
-rw-r--r--  ruy/kernel.h  31
-rw-r--r--  ruy/kernel_arm.h  211
-rw-r--r--  ruy/kernel_arm32.cc  2499
-rw-r--r--  ruy/kernel_arm64.cc  7835
-rw-r--r--  ruy/kernel_avx2.cc  1664
-rw-r--r--  ruy/kernel_avx512.cc  1820
-rw-r--r--  ruy/kernel_avxvnni.cc  435
-rw-r--r--  ruy/kernel_common.h  481
-rw-r--r--  ruy/kernel_sse42.cc  428
-rw-r--r--  ruy/kernel_x86.h  222
-rw-r--r--  ruy/matrix.h  182
-rw-r--r--  ruy/opt_set.h  51
-rw-r--r--  ruy/pack.h  98
-rw-r--r--  ruy/pack_arm.cc  1936
-rw-r--r--  ruy/pack_arm.h  497
-rw-r--r--  ruy/pack_avx2.cc  816
-rw-r--r--  ruy/pack_avx512.cc  693
-rw-r--r--  ruy/pack_avxvnni.cc  478
-rw-r--r--  ruy/pack_common.h  246
-rw-r--r--  ruy/pack_sse42.cc  471
-rw-r--r--  ruy/pack_x86.h  461
-rw-r--r--  ruy/path.h  162
-rw-r--r--  ruy/platform.h  156
-rw-r--r--  ruy/pmu.cc  281
-rw-r--r--  ruy/pmu.h  44
-rw-r--r--  ruy/prepack.h  108
-rw-r--r--  ruy/prepacked_cache.cc  82
-rw-r--r--  ruy/prepacked_cache.h  130
-rw-r--r--  ruy/prepacked_cache_test.cc  210
-rw-r--r--  ruy/profiler/BUILD  52
-rw-r--r--  ruy/profiler/README.md  149
-rw-r--r--  ruy/profiler/instrumentation.cc  130
-rw-r--r--  ruy/profiler/instrumentation.h  203
-rw-r--r--  ruy/profiler/profiler.cc  109
-rw-r--r--  ruy/profiler/profiler.h  106
-rw-r--r--  ruy/profiler/test.cc  167
-rw-r--r--  ruy/profiler/test_instrumented_library.cc  59
-rw-r--r--  ruy/profiler/test_instrumented_library.h  23
-rw-r--r--  ruy/profiler/treeview.cc  248
-rw-r--r--  ruy/profiler/treeview.h  130
-rw-r--r--  ruy/ruy.h  42
-rw-r--r--  ruy/ruy_advanced.h  69
-rw-r--r--  ruy/ruy_test.bzl  34
-rw-r--r--  ruy/ruy_test_ext.bzl  19
-rw-r--r--  ruy/ruy_test_ext.bzl.opensource  7
-rw-r--r--  ruy/side_pair.h  64
-rw-r--r--  ruy/size_util.h  93
-rw-r--r--  ruy/size_util_test.cc  101
-rw-r--r--  ruy/spec.h  118
-rw-r--r--  ruy/test.h  2125
-rw-r--r--  ruy/test_fast.cc  110
-rw-r--r--  ruy/test_slow.cc  71
-rw-r--r--  ruy/test_special_specs.cc  163
-rw-r--r--  ruy/thread_pool.cc  200
-rw-r--r--  ruy/thread_pool.h  102
-rw-r--r--  ruy/time.h  81
-rw-r--r--  ruy/trace.cc  325
-rw-r--r--  ruy/trace.h  73
-rw-r--r--  ruy/trmul.cc  401
-rw-r--r--  ruy/trmul.h  38
-rw-r--r--  ruy/trmul_params.h  67
-rw-r--r--  ruy/tune.cc  161
-rw-r--r--  ruy/tune.h  163
-rw-r--r--  ruy/tune_test.cc  53
-rw-r--r--  ruy/tune_tool.cc  56
-rw-r--r--  ruy/wait.cc  69
-rw-r--r--  ruy/wait.h  73
-rw-r--r--  ruy/wait_test.cc  117
100 files changed, 33950 insertions, 0 deletions
diff --git a/ruy/BUILD b/ruy/BUILD
new file mode 100644
index 0000000..0b19193
--- /dev/null
+++ b/ruy/BUILD
@@ -0,0 +1,954 @@
+# Ruy is not BLAS
+
+load(":build_defs.bzl", "ruy_copts_avx2", "ruy_copts_avxvnni", "ruy_copts_base", "ruy_copts_skylake", "ruy_copts_sse42")
+load(":ruy_test_ext.bzl", "ruy_test_ext_defines", "ruy_test_ext_deps")
+load(":ruy_test.bzl", "ruy_benchmark", "ruy_test")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+config_setting(
+ name = "windows",
+ values = {"cpu": "x64_windows"},
+)
+
+config_setting(
+ name = "armeabi-v7a",
+ values = {"cpu": "armeabi-v7a"},
+)
+
+config_setting(
+ name = "x86_64",
+ values = {"cpu": "k8"},
+)
+
+config_setting(
+ name = "optimized",
+ values = {
+ "compilation_mode": "opt",
+ },
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "platform",
+ hdrs = ["platform.h"],
+ copts = ruy_copts_base(),
+)
+
+cc_library(
+ name = "check_macros",
+ hdrs = ["check_macros.h"],
+ copts = ruy_copts_base(),
+)
+
+cc_test(
+ name = "check_macros_test",
+ srcs = ["check_macros_test.cc"],
+ copts = ruy_copts_base(),
+ deps = [
+ ":check_macros",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "opt_set",
+ hdrs = ["opt_set.h"],
+ copts = ruy_copts_base(),
+)
+
+cc_library(
+ name = "time",
+ hdrs = ["time.h"],
+ copts = ruy_copts_base(),
+)
+
+cc_library(
+ name = "wait",
+ srcs = ["wait.cc"],
+ hdrs = ["wait.h"],
+ copts = ruy_copts_base(),
+ deps = [":time"],
+)
+
+cc_test(
+ name = "wait_test",
+ srcs = ["wait_test.cc"],
+ deps = [
+ ":platform",
+ ":wait",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "size_util",
+ hdrs = ["size_util.h"],
+ copts = ruy_copts_base(),
+ deps = [":check_macros"],
+)
+
+cc_test(
+ name = "size_util_test",
+ srcs = ["size_util_test.cc"],
+ deps = [
+ ":size_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "tune",
+ srcs = [
+ "tune.cc",
+ ],
+ hdrs = [
+ "tune.h",
+ ],
+ copts = ruy_copts_base(),
+ deps = [
+ ":opt_set",
+ ":platform",
+ ":time",
+ ],
+)
+
+cc_library(
+ name = "prepacked_cache",
+ srcs = [
+ "prepacked_cache.cc",
+ ],
+ hdrs = [
+ "prepacked_cache.h",
+ ],
+ copts = ruy_copts_base(),
+ deps = [
+ ":allocator",
+ ":matrix",
+ ":opt_set",
+ ":platform",
+ ":time",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_test(
+ name = "tune_test",
+ srcs = ["tune_test.cc"],
+ deps = [
+ ":tune",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_test(
+ name = "prepacked_cache_test",
+ srcs = ["prepacked_cache_test.cc"],
+ deps = [
+ ":prepacked_cache",
+ ":ruy",
+ ":time",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
+ name = "tune_tool",
+ srcs = ["tune_tool.cc"],
+ deps = [
+ ":tune",
+ ],
+)
+
+cc_library(
+ name = "allocator",
+ srcs = [
+ "allocator.cc",
+ ],
+ hdrs = [
+ "allocator.h",
+ ],
+ copts = ruy_copts_base(),
+ deps = [
+ ":check_macros",
+ ":size_util",
+ ],
+)
+
+cc_test(
+ name = "allocator_test",
+ srcs = ["allocator_test.cc"],
+ deps = [
+ ":allocator",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "side_pair",
+ hdrs = ["side_pair.h"],
+ copts = ruy_copts_base(),
+ deps = [":check_macros"],
+)
+
+cc_library(
+ name = "block_map",
+ srcs = [
+ "block_map.cc",
+ ],
+ hdrs = [
+ "block_map.h",
+ ],
+ copts = ruy_copts_base(),
+ deps = [
+ ":check_macros",
+ ":opt_set",
+ ":path",
+ ":side_pair",
+ ":size_util",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_test(
+ name = "block_map_test",
+ srcs = ["block_map_test.cc"],
+ deps = [
+ ":block_map",
+ ":cpu_cache_size",
+ ":path",
+ ":side_pair",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "blocking_counter",
+ srcs = [
+ "blocking_counter.cc",
+ ],
+ hdrs = [
+ "blocking_counter.h",
+ ],
+ copts = ruy_copts_base(),
+ deps = [
+ ":check_macros",
+ ":wait",
+ ],
+)
+
+cc_library(
+ name = "thread_pool",
+ srcs = [
+ "thread_pool.cc",
+ ],
+ hdrs = [
+ "thread_pool.h",
+ ],
+ copts = ruy_copts_base(),
+ deps = [
+ ":blocking_counter",
+ ":check_macros",
+ ":wait",
+ ],
+)
+
+cc_library(
+ name = "detect_arm",
+ srcs = [
+ "detect_arm.cc",
+ ],
+ hdrs = [
+ "detect_arm.h",
+ ],
+ copts = ruy_copts_base(),
+)
+
+cc_library(
+ name = "detect_x86",
+ srcs = [
+ "detect_x86.cc",
+ ],
+ hdrs = [
+ "detect_x86.h",
+ ],
+ copts = ruy_copts_base(),
+ deps = [
+ ":platform",
+ ],
+)
+
+cc_library(
+ name = "path",
+ hdrs = ["path.h"],
+ copts = ruy_copts_base(),
+ deps = [
+ ":platform",
+ ":size_util",
+ ],
+)
+
+cc_library(
+ name = "cpu_cache_size",
+ hdrs = ["cpu_cache_size.h"],
+ copts = ruy_copts_base(),
+ deps = [
+ ":path",
+ ":platform",
+ ],
+)
+
+cc_library(
+ name = "trace",
+ srcs = [
+ "trace.cc",
+ ],
+ hdrs = [
+ "trace.h",
+ ],
+ copts = ruy_copts_base(),
+ deps = [
+ ":block_map",
+ ":check_macros",
+ ":side_pair",
+ ":time",
+ ],
+)
+
+cc_library(
+ name = "matrix",
+ hdrs = ["matrix.h"],
+ copts = ruy_copts_base(),
+ deps = [":check_macros"],
+)
+
+cc_library(
+ name = "spec",
+ hdrs = ["spec.h"],
+ copts = ruy_copts_base(),
+ deps = [
+ ":cpu_cache_size",
+ ":matrix",
+ ],
+)
+
+cc_library(
+ name = "internal_matrix",
+ hdrs = ["internal_matrix.h"],
+ copts = ruy_copts_base(),
+ deps = [
+ ":check_macros",
+ ":common",
+ ":matrix",
+ ":size_util",
+ ],
+)
+
+cc_library(
+ name = "common",
+ hdrs = [
+ "common.h",
+ ],
+ copts = ruy_copts_base(),
+ deps = [
+ ":check_macros",
+ ":matrix",
+ ":opt_set",
+ ":path",
+ ":platform",
+ ],
+)
+
+cc_library(
+ name = "kernel_common",
+ hdrs = [
+ "kernel.h",
+ "kernel_arm.h",
+ "kernel_common.h",
+ "kernel_x86.h",
+ ],
+ copts = ruy_copts_base(),
+ deps = [
+ ":check_macros",
+ ":common",
+ ":internal_matrix",
+ ":matrix",
+ ":opt_set",
+ ":path",
+ ":platform",
+ ":side_pair",
+ ":size_util",
+ ":spec",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "pack_common",
+ hdrs = [
+ "pack.h",
+ "pack_arm.h",
+ "pack_common.h",
+ "pack_x86.h",
+ ],
+ copts = ruy_copts_base(),
+ deps = [
+ ":check_macros",
+ ":common",
+ ":internal_matrix",
+ ":matrix",
+ ":opt_set",
+ ":path",
+ ":platform",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "kernel_arm",
+ srcs = [
+ "kernel_arm32.cc",
+ "kernel_arm64.cc",
+ ],
+ copts = ruy_copts_base(),
+ deps = [
+ ":common",
+ ":kernel_common",
+ ":opt_set",
+ ":platform",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "pack_arm",
+ srcs = [
+ "pack_arm.cc",
+ ],
+ copts = ruy_copts_base(),
+ deps = [
+ ":common",
+ ":opt_set",
+ ":pack_common",
+ ":platform",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+# AVX-512 compilation units.
+#
+# These must use the same compiler options.
+RUY_COPTS_BUILT_FOR_AVX512 = ruy_copts_base() + ruy_copts_skylake()
+
+cc_library(
+ name = "kernel_avx512",
+ srcs = [
+ "kernel_avx512.cc",
+ ],
+ copts = RUY_COPTS_BUILT_FOR_AVX512,
+ deps = [
+ ":check_macros",
+ ":kernel_common",
+ ":opt_set",
+ ":platform",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "pack_avx512",
+ srcs = [
+ "pack_avx512.cc",
+ ],
+ copts = RUY_COPTS_BUILT_FOR_AVX512,
+ deps = [
+ ":check_macros",
+ ":matrix",
+ ":opt_set",
+ ":pack_common",
+ ":path",
+ ":platform",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "have_built_path_for_avx512",
+ srcs = [
+ "have_built_path_for_avx512.cc",
+ ],
+ hdrs = [
+ "have_built_path_for.h",
+ ],
+ copts = RUY_COPTS_BUILT_FOR_AVX512,
+ deps = [
+ ":opt_set",
+ ":platform",
+ ],
+)
+# End: AVX-512 compilation units.
+
+# AVX2 compilation units.
+#
+# These must use the same compiler options.
+RUY_COPTS_BUILT_FOR_AVX2 = ruy_copts_base() + ruy_copts_avx2()
+
+cc_library(
+ name = "kernel_avx2",
+ srcs = [
+ "kernel_avx2.cc",
+ ],
+ copts = RUY_COPTS_BUILT_FOR_AVX2,
+ deps = [
+ ":check_macros",
+ ":kernel_common",
+ ":opt_set",
+ ":platform",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "pack_avx2",
+ srcs = [
+ "pack_avx2.cc",
+ ],
+ copts = RUY_COPTS_BUILT_FOR_AVX2,
+ deps = [
+ ":check_macros",
+ ":matrix",
+ ":opt_set",
+ ":pack_common",
+ ":path",
+ ":platform",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "have_built_path_for_avx2",
+ srcs = [
+ "have_built_path_for_avx2.cc",
+ ],
+ hdrs = [
+ "have_built_path_for.h",
+ ],
+ copts = RUY_COPTS_BUILT_FOR_AVX2,
+ deps = [
+ ":opt_set",
+ ":platform",
+ ],
+)
+# End: AVX2 compilation units.
+
+# SSE42 compilation units.
+#
+# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+# Optimization is not finished. In particular the dimensions of the kernel
+# blocks can be changed as desired.
+#
+# These must use the same compiler options.
+RUY_COPTS_BUILT_FOR_SSE42 = ruy_copts_base() + ruy_copts_sse42()
+
+cc_library(
+ name = "kernel_sse42",
+ srcs = [
+ "kernel_sse42.cc",
+ ],
+ copts = RUY_COPTS_BUILT_FOR_SSE42,
+ deps = [
+ ":check_macros",
+ ":kernel_common",
+ ":opt_set",
+ ":platform",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "pack_sse42",
+ srcs = [
+ "pack_sse42.cc",
+ ],
+ copts = RUY_COPTS_BUILT_FOR_SSE42,
+ deps = [
+ ":check_macros",
+ ":matrix",
+ ":opt_set",
+ ":pack_common",
+ ":path",
+ ":platform",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "have_built_path_for_sse42",
+ srcs = [
+ "have_built_path_for_sse42.cc",
+ ],
+ hdrs = [
+ "have_built_path_for.h",
+ ],
+ copts = RUY_COPTS_BUILT_FOR_SSE42,
+ deps = [
+ ":opt_set",
+ ":platform",
+ ],
+)
+# End: SSE42 compilation units.
+
+# AVX-VNNI compilation units.
+#
+# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+# Optimization is not finished. In particular the dimensions of the kernel
+# blocks can be changed as desired.
+#
+# These must use the same compiler options.
+RUY_COPTS_BUILT_FOR_AVX_VNNI = ruy_copts_base() + ruy_copts_avxvnni()
+
+cc_library(
+ name = "kernel_avxvnni",
+ srcs = [
+ "kernel_avxvnni.cc",
+ ],
+ copts = RUY_COPTS_BUILT_FOR_AVX_VNNI,
+ deps = [
+ ":check_macros",
+ ":kernel_common",
+ ":opt_set",
+ ":platform",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "pack_avxvnni",
+ srcs = [
+ "pack_avxvnni.cc",
+ ],
+ copts = RUY_COPTS_BUILT_FOR_AVX_VNNI,
+ deps = [
+ ":check_macros",
+ ":matrix",
+ ":opt_set",
+ ":pack_common",
+ ":path",
+ ":platform",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "have_built_path_for_avxvnni",
+ srcs = [
+ "have_built_path_for_avxvnni.cc",
+ ],
+ hdrs = [
+ "have_built_path_for.h",
+ ],
+ copts = RUY_COPTS_BUILT_FOR_AVX_VNNI,
+ deps = [
+ ":opt_set",
+ ":platform",
+ ],
+)
+# End: AVX-VNNI compilation units.
+
+cc_library(
+ name = "kernel",
+ hdrs = [
+ "kernel.h",
+ "kernel_common.h",
+ ],
+ copts = ruy_copts_base(),
+ deps = [
+ ":check_macros",
+ ":common",
+ ":internal_matrix",
+ ":kernel_arm", # fixdeps: keep
+ ":kernel_avx2", # fixdeps: keep
+ ":kernel_avx512", # fixdeps: keep
+ ":kernel_avxvnni", # fixdeps: keep
+ ":kernel_common",
+ ":kernel_sse42", # fixdeps: keep
+ ":matrix",
+ ":opt_set",
+ ":path",
+ ":platform",
+ ":side_pair",
+ ":size_util",
+ ":spec",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "pack",
+ hdrs = [
+ "pack.h",
+ "pack_common.h",
+ ],
+ copts = ruy_copts_base(),
+ deps = [
+ ":check_macros",
+ ":common",
+ ":internal_matrix",
+ ":matrix",
+ ":opt_set",
+ ":pack_arm", # fixdeps: keep
+ ":pack_avx2", # fixdeps: keep
+ ":pack_avx512", # fixdeps: keep
+ ":pack_avxvnni", # fixdeps: keep
+ ":pack_common",
+ ":pack_sse42", # fixdeps: keep
+ ":path",
+ ":platform",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+cc_library(
+ name = "have_built_path_for",
+ hdrs = [
+ "have_built_path_for.h",
+ ],
+ deps = [
+ ":have_built_path_for_avx2",
+ ":have_built_path_for_avx512",
+ ":have_built_path_for_avxvnni",
+ ":have_built_path_for_sse42",
+ ":platform",
+ ],
+)
+
+cc_library(
+ name = "context",
+ srcs = [
+ "context.cc",
+ ],
+ hdrs = [
+ "context.h",
+ ],
+ copts = ruy_copts_base(),
+ deps = [
+ ":allocator",
+ ":check_macros",
+ ":detect_arm",
+ ":detect_x86",
+ ":have_built_path_for",
+ ":path",
+ ":platform",
+ ":prepacked_cache",
+ ":thread_pool",
+ ":trace",
+ ":tune",
+ ],
+)
+
+cc_test(
+ name = "context_test",
+ srcs = ["context_test.cc"],
+ deps = [
+ ":context",
+ ":path",
+ ":platform",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "trmul_params",
+ hdrs = ["trmul_params.h"],
+ copts = ruy_copts_base(),
+ deps = [
+ ":internal_matrix",
+ ":side_pair",
+ ":tune",
+ ],
+)
+
+cc_library(
+ name = "trmul",
+ srcs = ["trmul.cc"],
+ hdrs = ["trmul.h"],
+ copts = ruy_copts_base(),
+ deps = [
+ ":allocator",
+ ":block_map",
+ ":check_macros",
+ ":common",
+ ":context",
+ ":internal_matrix",
+ ":matrix",
+ ":opt_set",
+ ":side_pair",
+ ":size_util",
+ ":spec",
+ ":thread_pool",
+ ":trace",
+ ":trmul_params",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+# The main library.
+cc_library(
+ name = "ruy",
+ srcs = [
+ "dispatch.h",
+ "prepack.h",
+ ],
+ hdrs = [
+ "ruy.h",
+ "ruy_advanced.h",
+ ],
+ copts = ruy_copts_base(),
+ deps = [
+ ":check_macros",
+ ":common",
+ ":context",
+ ":internal_matrix",
+ ":kernel",
+ ":matrix",
+ ":opt_set",
+ ":pack",
+ ":path",
+ ":prepacked_cache",
+ ":side_pair",
+ ":size_util",
+ ":spec",
+ ":trmul",
+ ":trmul_params",
+ ":tune",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+# Usage examples.
+cc_binary(
+ name = "example",
+ srcs = ["example.cc"],
+ deps = [":ruy"],
+)
+
+# Usage examples of the advanced API.
+cc_binary(
+ name = "example_advanced",
+ srcs = ["example_advanced.cc"],
+ deps = [":ruy"],
+)
+
+# Small library to query PMU counters, for benchmark only
+cc_library(
+ name = "pmu",
+ testonly = True,
+ srcs = ["pmu.cc"],
+ hdrs = ["pmu.h"],
+ copts = ruy_copts_base(),
+ deps = [":check_macros"],
+)
+
+# Testing framework.
+cc_library(
+ name = "test_lib",
+ testonly = True,
+ hdrs = ["test.h"],
+ copts = ruy_copts_base(),
+ # need defines, not copts, because it's controlling a header, test.h
+ defines = ruy_test_ext_defines(),
+ linkopts = select({
+ ":windows": [],
+ "//conditions:default": ["-lm"],
+ }),
+ deps = [
+ ":matrix",
+ ":pmu",
+ ":ruy",
+ ":spec",
+ ":time",
+ "@com_google_googletest//:gtest",
+ ":platform",
+ "//ruy/profiler:profiler",
+ ] + ruy_test_ext_deps(),
+)
+
+ruy_benchmark(
+ name = "benchmark",
+ srcs = ["benchmark.cc"],
+ copts = ruy_copts_base(),
+ lhs_rhs_accum_dst = [
+ ("f32", "f32", "f32", "f32"),
+ ("u8", "u8", "i32", "u8"),
+ ("i8", "i8", "i32", "u8"),
+ ("i8", "i8", "i32", "i8"),
+ ("u8", "u8", "i32", "i16"),
+ ("i8", "i8", "i32", "i32"),
+ ],
+ deps = [
+ "//ruy:test_lib",
+ "//ruy/profiler:instrumentation",
+ ],
+)
+
+ruy_test(
+ name = "test_fast",
+ srcs = ["test_fast.cc"],
+ copts = ruy_copts_base(),
+ lhs_rhs_accum_dst = [
+ ("f32", "f32", "f32", "f32"),
+ ("f64", "f32", "f64", "f32"),
+ ("f32", "f64", "f64", "f64"),
+ ("u8", "u8", "i32", "u8"),
+ ("i8", "i8", "i32", "i8"),
+ ("i8", "u8", "i32", "i8"),
+ ("u8", "u8", "i32", "i16"),
+ ("i8", "i8", "i32", "i32"),
+ ("i8", "u8", "i32", "i32"),
+ ],
+ deps = [
+ "@com_google_googletest//:gtest_main",
+ "//ruy:test_lib",
+ ],
+)
+
+ruy_test(
+ name = "test_slow",
+ srcs = ["test_slow.cc"],
+ copts = ruy_copts_base(),
+ lhs_rhs_accum_dst = [
+ ("f32", "f32", "f32", "f32"),
+ ("u8", "u8", "i32", "u8"),
+ ("i8", "i8", "i32", "i8"),
+ ("u8", "u8", "i32", "i16"),
+ ("i8", "i8", "i32", "i32"),
+ ],
+ tags = ["slow"],
+ deps = [
+ "@com_google_googletest//:gtest_main",
+ "//ruy:test_lib",
+ ],
+)
+
+ruy_test(
+ name = "test_special_specs",
+ srcs = ["test_special_specs.cc"],
+ copts = ruy_copts_base(),
+ lhs_rhs_accum_dst = [
+ ("f32", "f32", "f32", "f32"),
+ ("u8", "u8", "i32", "u8"),
+ ("u8", "u8", "i32", "i16"),
+ ],
+ deps = [
+ "@com_google_googletest//:gtest_main",
+ "//ruy:test_lib",
+ ],
+)
diff --git a/ruy/allocator.cc b/ruy/allocator.cc
new file mode 100644
index 0000000..d8fb738
--- /dev/null
+++ b/ruy/allocator.cc
@@ -0,0 +1,51 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/allocator.h"
+
+#include <cstdint>
+#include <cstdlib>
+
+#ifdef _WIN32
+#include <malloc.h>
+#endif
+
+namespace ruy {
+
+namespace detail {
+
+void *SystemAlignedAlloc(std::ptrdiff_t num_bytes) {
+#ifdef _WIN32
+ return _aligned_malloc(num_bytes, kMinimumBlockAlignment);
+#else
+ void *ptr;
+ if (posix_memalign(&ptr, kMinimumBlockAlignment, num_bytes)) {
+ return nullptr;
+ }
+ return ptr;
+#endif
+}
+
+void SystemAlignedFree(void *ptr) {
+#ifdef _WIN32
+ _aligned_free(ptr);
+#else
+ free(ptr);
+#endif
+}
+
+} // namespace detail
+
+} // namespace ruy
diff --git a/ruy/allocator.h b/ruy/allocator.h
new file mode 100644
index 0000000..b0379b1
--- /dev/null
+++ b/ruy/allocator.h
@@ -0,0 +1,185 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ALLOCATOR_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ALLOCATOR_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "ruy/check_macros.h"
+#include "ruy/size_util.h"
+
+namespace ruy {
+
+namespace detail {
+
+inline void* VoidPtrAdd(void* p, std::ptrdiff_t offset) {
+ RUY_DCHECK(p);
+ std::uintptr_t addr = reinterpret_cast<std::uintptr_t>(p) + offset;
+ return reinterpret_cast<void*>(addr);
+}
+
+// Minimum alignment for blocks.
+//
+// Considerations:
+// - This needs to be at least the alignment of any usual data type.
+// - It's useful that this is at least the size of a cache line to limit
+// possible cache side effects (if only on performance behavior).
+// - It's useful that this is at least the size of SIMD registers, as
+// some SIMD instruction sets have at least performance behavior
+// differences (e.g. NEON) or even different requirements (e.g. SSE)
+// based on that.
+// - It's useful that this is at least the size of an "exclusive reservation
+// granule" on ARM, meaning that if we use this Allocator to allocate
+// an atomic variable, there will be no side effects from other things
+// contending for exclusive/atomic memory accesses to it. While the
+// ARM reference manual mentions that this granule size may be as large
+// as 2048 bytes, in practice we observe it to be 64 bytes. It can
+// be queried cheaply, at runtime, from userspace, if needed.
+static constexpr std::ptrdiff_t kMinimumBlockAlignment = 64;
+
+// Primitive allocation functions obtaining aligned memory from the
+// operating system.
+void* SystemAlignedAlloc(std::ptrdiff_t num_bytes);
+void SystemAlignedFree(void* ptr);
+
+// Specialized allocator designed to converge to a steady-state where all
+// allocations are bump-ptr allocations from an already-allocated buffer.
+//
+// To converge to that steady state, this allocator supports only two
+// operations.
+// - AllocateAlignedBytes: allocates a pointer to storage of a specified
+// size, which must be aligned to kMinimumBlockAlignment.
+// - FreeAll: frees all previous allocations (but retains the internal
+// buffer to minimize future calls into the system allocator).
+//
+// This class is specialized for supporting just those two operations
+// under this specific steady-state usage pattern. Extending this class
+// with new allocation interfaces that don't fit that pattern is probably not
+// the right choice. Instead, build a new class on top of
+// SystemAlignedAlloc/SystemAlignedFree.
+//
+// All operations happen on aligned blocks for simplicity.
+class AlignedAllocator {
+ public:
+ void operator=(const AlignedAllocator&) = delete;
+ ~AlignedAllocator() {
+ FreeAll();
+ SystemAlignedFree(ptr_);
+ }
+
+ void* AllocateAlignedBytes(std::ptrdiff_t num_bytes) {
+ RUY_DCHECK_GT(num_bytes, 0);
+ RUY_DCHECK((num_bytes & (kMinimumBlockAlignment - 1)) == 0);
+ if (void* p = AllocateFast(num_bytes)) {
+ return p;
+ }
+ return AllocateSlow(num_bytes);
+ }
+
+ void FreeAll() {
+ current_ = 0;
+ if (fallback_blocks_.empty()) {
+ return;
+ }
+
+ // Not rounding up the size means a linear instead of logarithmic
+ // bound on the number of allocations in some worst-case calling patterns.
+ // This is considered worth it because minimizing memory usage is important
+ // and actual calling patterns in applications that we care about still
+ // reach the no-further-allocations steady state in a small finite number
+ // of iterations.
+ std::ptrdiff_t new_size = size_ + fallback_blocks_total_size_;
+ SystemAlignedFree(ptr_);
+ ptr_ = SystemAlignedAlloc(new_size);
+ size_ = new_size;
+
+ for (void* p : fallback_blocks_) {
+ SystemAlignedFree(p);
+ }
+ fallback_blocks_.clear();
+ fallback_blocks_total_size_ = 0;
+ }
+
+ private:
+ void* AllocateFast(std::ptrdiff_t num_bytes) {
+ if (current_ + num_bytes > size_) {
+ return nullptr;
+ }
+ void* ret = VoidPtrAdd(ptr_, current_);
+ current_ += num_bytes;
+ return ret;
+ }
+
+ void* AllocateSlow(std::ptrdiff_t num_bytes) {
+ void* p = SystemAlignedAlloc(num_bytes);
+ fallback_blocks_total_size_ += num_bytes;
+ fallback_blocks_.push_back(p);
+ return p;
+ }
+
+ // Theory of operation:
+ //
+ // - ptr_, current_, and size_ implement a basic bump-ptr allocator.
+ //
+ // - in AllocateAlignedBytes, the fast path is just a bump-ptr
+ // allocation. If our bump-ptr allocator doesn't have enough space for an
+ // allocation, then we allocate a block from the system allocator to
+ // service the allocation request. We save that block in fallback_blocks_
+ // and track the total size of the fallback blocks in
+ // fallback_blocks_total_size_.
+ //
+ // - in FreeAll, the fast path just resets the bump-ptr allocator. If
+ // there are any fallback blocks, we free them and reallocate the
+ // bump-ptr allocator's buffer so that the next sequence of allocations
+ // will hopefully not need any fallback blocks.
+ void* ptr_ = nullptr;
+ std::ptrdiff_t current_ = 0;
+ std::ptrdiff_t size_ = 0;
+ std::vector<void*> fallback_blocks_;
+ std::ptrdiff_t fallback_blocks_total_size_ = 0;
+};
+
+} // namespace detail
+
+// The main Allocator class, with a convenient interface for allocating a
+// typed buffer.
+class Allocator {
+ public:
+ void* AllocateBytes(std::ptrdiff_t num_bytes) {
+ if (num_bytes == 0) {
+ return nullptr;
+ }
+ return aligned.AllocateAlignedBytes(
+ round_up_pot(num_bytes, detail::kMinimumBlockAlignment));
+ }
+ template <typename Pointer>
+ void Allocate(std::ptrdiff_t count, Pointer* out) {
+ using T = typename std::pointer_traits<Pointer>::element_type;
+ *out = static_cast<T*>(AllocateBytes(count * sizeof(T)));
+ }
+
+ void FreeAll() { aligned.FreeAll(); }
+
+ private:
+ detail::AlignedAllocator aligned;
+};
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ALLOCATOR_H_
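The header above fully specifies the allocator's two operations, so a short usage sketch follows directly from it (a hypothetical standalone program, not part of this commit): AllocateBytes rounds every request up to kMinimumBlockAlignment, so returned pointers are 64-byte aligned, and after a FreeAll the consolidated buffer serves subsequent allocations of the same total size on the fast bump-ptr path.

    #include <cassert>
    #include <cstdint>

    #include "ruy/allocator.h"

    int main() {
      ruy::Allocator allocator;
      float* buf = nullptr;
      allocator.Allocate(100, &buf);           // typed wrapper over AllocateBytes
      void* raw = allocator.AllocateBytes(1);  // rounded up to 64 bytes
      assert(reinterpret_cast<std::uintptr_t>(buf) % 64 == 0);
      assert(reinterpret_cast<std::uintptr_t>(raw) % 64 == 0);
      allocator.FreeAll();  // keeps a consolidated buffer for reuse
      // The same allocations now fit in the retained buffer, so they take the
      // fast bump-ptr path and perform no system allocation.
      allocator.Allocate(100, &buf);
      allocator.FreeAll();
      return 0;
    }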
diff --git a/ruy/allocator_test.cc b/ruy/allocator_test.cc
new file mode 100644
index 0000000..7f46a66
--- /dev/null
+++ b/ruy/allocator_test.cc
@@ -0,0 +1,103 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/allocator.h"
+
+#include "testing/base/public/gunit.h"
+
+namespace ruy {
+namespace {
+
+TEST(AllocatorTest, ReturnsValidMemory) {
+ Allocator allocator;
+ int *p;
+ allocator.Allocate(1, &p);
+ ASSERT_NE(p, nullptr);
+
+ // If this is bogus memory, ASan will cause this test to fail.
+ *p = 42;
+
+ allocator.FreeAll();
+}
+
+TEST(AllocatorTest, NoLeak) {
+ Allocator allocator;
+ // Allocate and free some ridiculously large total amount of memory, so
+ // that a leak will hopefully cause some sort of resource exhaustion.
+ //
+ // Despite the large number of allocations, this test is actually quite
+ // fast, since our fast-path allocation logic is very fast.
+ constexpr int kNumAllocations = 100 * 1024;
+ constexpr int kAllocationSize = 1024 * 1024;
+ for (int i = 0; i < kNumAllocations; i++) {
+ char *p;
+ allocator.Allocate(kAllocationSize, &p);
+ allocator.FreeAll();
+ }
+}
+
+TEST(AllocatorTest, IncreasingSizes) {
+ Allocator allocator;
+ // Allocate sizes that increase by small amounts across FreeAll calls.
+ for (int i = 1; i < 100 * 1024; i++) {
+ char *p;
+ allocator.Allocate(i, &p);
+ allocator.FreeAll();
+ }
+}
+
+TEST(AllocatorTest, ManySmallAllocations) {
+ Allocator allocator;
+ // Allocate many small allocations between FreeAll calls.
+ for (int i = 0; i < 10 * 1024; i += 100) {
+ for (int j = 0; j < i; j++) {
+ char *p;
+ allocator.Allocate(1, &p);
+ }
+ allocator.FreeAll();
+ }
+}
+
+TEST(AllocatorTest, DestructorHandlesMainBumpPtr) {
+ // This is a white-box test.
+ Allocator allocator;
+ allocator.AllocateBytes(1);
+ allocator.FreeAll();
+ // After the call to FreeAll, the allocator will consolidate all of the memory
+ // into the main bump-ptr allocator's block, which we then expect to be freed
+ // in the destructor.
+ //
+ // We have no test assertions -- we primarily rely on a leak checker to
+ // catch any leaked memory here and fail the test.
+}
+
+TEST(AllocatorTest, DestructorHandlesFallbackBlocks) {
+ // This is a white-box test.
+ Allocator allocator;
+ // Since we just created the allocator, this will allocate a fallback block,
+ // which we then expect to be freed in the destructor.
+ //
+ // We have no test assertions -- we primarily rely on a leak checker to
+ // catch any leaked memory here and fail the test.
+ allocator.AllocateBytes(1);
+}
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char **argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/benchmark.cc b/ruy/benchmark.cc
new file mode 100644
index 0000000..6ce0b32
--- /dev/null
+++ b/ruy/benchmark.cc
@@ -0,0 +1,196 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdio>
+#include <cstdlib>
+#include <string>
+
+#include "ruy/test.h"
+
+namespace ruy {
+
+using LhsScalar = RUY_TEST_LHSSCALAR;
+using RhsScalar = RUY_TEST_RHSSCALAR;
+using AccumScalar = RUY_TEST_ACCUMSCALAR;
+using DstScalar = RUY_TEST_DSTSCALAR;
+using TestSetType =
+ TestSet<LhsScalar, RhsScalar, BasicSpec<AccumScalar, DstScalar>>;
+
+struct BenchmarkShape {
+ int rows;
+ int depth;
+ int cols;
+ int symm_lhs;
+ int symm_rhs;
+};
+
+template <typename TestSetType>
+std::vector<std::unique_ptr<TestResult<DstScalar>>> BenchmarkRCC(
+ const BenchmarkShape& shape) {
+ TestSetType test_set;
+ test_set.rows = shape.rows;
+ test_set.depth = shape.depth;
+ test_set.cols = shape.cols;
+ test_set.lhs_order = Order::kRowMajor;
+ test_set.rhs_order = Order::kColMajor;
+ test_set.dst_order = Order::kColMajor;
+ test_set.layout_style = LayoutStyle::kPackedLinear;
+ test_set.benchmark = true;
+ const int asymmetry_lhs = shape.symm_lhs ? 0 : 1;
+ const int asymmetry_rhs = shape.symm_rhs ? 0 : 1;
+ test_set.lhs_zero_point = SymmetricZeroPoint<LhsScalar>() + asymmetry_lhs;
+ test_set.rhs_zero_point = SymmetricZeroPoint<RhsScalar>() + asymmetry_rhs;
+ test_set.use_specified_zero_points = true;
+ test_set.perchannel = GetBoolEnvVarOrFalse("PERCHANNEL");
+ test_set.benchmark_prepack_lhs = GetBoolEnvVarOrFalse("PREPACK_LHS");
+ test_set.benchmark_prepack_rhs = GetBoolEnvVarOrFalse("PREPACK_RHS");
+ test_set.Run();
+ return std::move(test_set.results);
+}
+
+std::vector<int> ParseCommaSeparatedInts(
+ const std::string& comma_separated_ints) {
+ std::vector<int> result;
+ for (std::size_t pos = 0; pos < comma_separated_ints.size();) {
+ std::size_t delim_pos = comma_separated_ints.find(',', pos);
+ if (delim_pos == std::string::npos) {
+ delim_pos = comma_separated_ints.size();
+ }
+ result.push_back(
+ std::stoi(comma_separated_ints.substr(pos, delim_pos - pos)));
+ pos = delim_pos + 1;
+ }
+ return result;
+}
+
+void Benchmark() {
+ const bool symm_lhs = std::is_floating_point<LhsScalar>::value ||
+ GetBoolEnvVarOrFalse("SYMM_LHS");
+ const bool symm_rhs = std::is_floating_point<RhsScalar>::value ||
+ GetBoolEnvVarOrFalse("SYMM_RHS");
+ const bool benchmark_cubic = GetBoolEnvVarOrFalse("RUY_BENCHMARK_CUBIC") ||
+ GetBoolEnvVarOrFalse("RUY_BENCHMARK_CUBIC_LIST");
+ const int explicit_rows = GetIntEnvVarOrZero("ROWS");
+ const int explicit_cols = GetIntEnvVarOrZero("COLS");
+ const int explicit_depth = GetIntEnvVarOrZero("DEPTH");
+
+ std::vector<BenchmarkShape> shapes;
+
+ if (benchmark_cubic) {
+ std::vector<int> sizes;
+ const char* benchmark_cubic_list_env = getenv("RUY_BENCHMARK_CUBIC_LIST");
+ if (benchmark_cubic_list_env) {
+ sizes = ParseCommaSeparatedInts(benchmark_cubic_list_env);
+ } else {
+ // Often 8 is used for this multiplier, but to check teeny sizes one can
+ // use 1.
+ static constexpr int cubic_size_multiplier = 8;
+ for (int i = 2 * cubic_size_multiplier;
+ i <= (512 * cubic_size_multiplier); i *= 2) {
+ sizes.push_back(i);
+ if (i < (512 * cubic_size_multiplier)) {
+ sizes.push_back(i * 3 / 2);
+ }
+ }
+ }
+ for (int i : sizes) {
+ BenchmarkShape shape;
+ // Even in cubic mode, one may still override an individual dimension
+ // to allow testing a batch of rectangular sizes.
+ shape.rows = explicit_rows ? explicit_rows : i;
+ shape.cols = explicit_cols ? explicit_cols : i;
+ shape.depth = explicit_depth ? explicit_depth : i;
+ shape.symm_lhs = symm_lhs;
+ shape.symm_rhs = symm_rhs;
+ shapes.push_back(shape);
+ }
+ } else {
+ BenchmarkShape shape;
+ shape.rows = explicit_rows;
+ shape.cols = explicit_cols;
+ shape.depth = explicit_depth;
+ if (!shape.rows || !shape.depth || !shape.cols) {
+ fprintf(stderr,
+ "Please specify positive sizes with these env vars: ROWS, DEPTH, "
+ "COLS.\n");
+ exit(1);
+ }
+ shape.symm_lhs = symm_lhs;
+ shape.symm_rhs = symm_rhs;
+ shapes.push_back(shape);
+ }
+
+ for (int i = 0; i < shapes.size(); i++) {
+ const auto& shape = shapes[i];
+ const auto& results = BenchmarkRCC<TestSetType>(shape);
+ if (i == 0) {
+ if (benchmark_cubic) {
+ printf("size");
+ for (const auto& result : results) {
+ if (results.size() > 1) {
+ printf(",%s:Gop/s", PathName(*result).c_str());
+ } else {
+ printf(",Gop/s");
+ }
+ if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) {
+ printf(
+ ",l1_refill,l2_refill,l3_refill,l1tlb_refill,l2tlb_refill,"
+ "mispred,frontend_stall,backend_stall");
+ }
+ }
+ printf("\n");
+ } else {
+ printf("path,shape,Gop/s\n");
+ }
+ fflush(stdout);
+ }
+ if (benchmark_cubic) {
+ printf("%d", shape.rows);
+ for (const auto& result : results) {
+ printf(",%.4g", 2.0e-9 * shape.rows * shape.cols * shape.depth /
+ result->latency);
+ if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) {
+ printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g",
+ result->l1_refill_rate, result->l2_refill_rate,
+ result->l3_refill_rate, result->l1tlb_refill_rate,
+ result->l2tlb_refill_rate, result->mispred_rate,
+ result->frontend_stall_rate, result->backend_stall_rate);
+ }
+ }
+ printf("\n");
+ fflush(stdout);
+ } else {
+ for (const auto& result : results) {
+ printf(
+ "%s,%dx%dx%d,%.4g", PathName(*result).c_str(), shape.rows,
+ shape.depth, shape.cols,
+ 2.0e-9 * shape.rows * shape.cols * shape.depth / result->latency);
+ if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) {
+ printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g",
+ result->l1_refill_rate, result->l2_refill_rate,
+ result->l3_refill_rate, result->l1tlb_refill_rate,
+ result->l2tlb_refill_rate, result->mispred_rate,
+ result->frontend_stall_rate, result->backend_stall_rate);
+ }
+ printf("\n");
+ }
+ fflush(stdout);
+ }
+ }
+}
+
+} // namespace ruy
+
+int main() { ruy::Benchmark(); }
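For reference, when RUY_BENCHMARK_CUBIC is set and RUY_BENCHMARK_CUBIC_LIST is not, the loop above with cubic_size_multiplier = 8 benchmarks square sizes 16, 24, 32, 48, 64, 96, ..., 3072, 4096: successive powers of two interleaved with their 1.5x midpoints. A standalone sketch of that size generation (not part of this commit):

    #include <cstdio>
    #include <vector>

    int main() {
      // Mirrors the size generation in ruy::Benchmark() for
      // cubic_size_multiplier = 8 (i.e. 2*8 = 16 up to 512*8 = 4096).
      std::vector<int> sizes;
      for (int i = 16; i <= 4096; i *= 2) {
        sizes.push_back(i);
        if (i < 4096) sizes.push_back(i * 3 / 2);
      }
      for (int size : sizes) printf("%d ", size);
      printf("\n");  // 16 24 32 48 64 96 128 192 ... 2048 3072 4096
      return 0;
    }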
diff --git a/ruy/block_map.cc b/ruy/block_map.cc
new file mode 100644
index 0000000..e1e6166
--- /dev/null
+++ b/ruy/block_map.cc
@@ -0,0 +1,486 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/block_map.h"
+
+#include <algorithm>
+#include <cstdint>
+
+#ifdef RUY_MAKEBLOCKMAP_DEBUG
+#include <cstdio>
+#include <cstdlib>
+#include <string>
+#endif
+
+#include "ruy/check_macros.h"
+#include "ruy/opt_set.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/size_util.h"
+
+namespace ruy {
+
+namespace {
+
+void DecodeTraversalLinear(int size_log2, std::uint32_t square_index,
+ SidePair<int>* local_pos) {
+ (*local_pos)[Side::kLhs] = square_index & ((1 << size_log2) - 1);
+ (*local_pos)[Side::kRhs] = square_index >> size_log2;
+}
+
+void DecodeTraversalFractalZ(std::uint32_t square_index,
+ SidePair<int>* local_pos) {
+ const std::uint32_t n1 = square_index;
+ const std::uint32_t n2 = (n1 & 0x99999999u) | ((n1 & 0x44444444u) >> 1) |
+ ((n1 & 0x22222222u) << 1);
+ const std::uint32_t n4 = (n2 & 0xc3c3c3c3u) | ((n2 & 0x30303030u) >> 2) |
+ ((n2 & 0x0c0c0c0cu) << 2);
+ const std::uint32_t n8 = (n4 & 0xf00ff00fu) | ((n4 & 0x0f000f00u) >> 4) |
+ ((n4 & 0x00f000f0u) << 4);
+ const std::uint32_t n16 = (n8 & 0xff0000ffu) | ((n8 & 0x00ff0000u) >> 8) |
+ ((n8 & 0x0000ff00u) << 8);
+ (*local_pos)[Side::kLhs] = n16 & 0xffff;
+ (*local_pos)[Side::kRhs] = n16 >> 16;
+}
+
+void DecodeTraversalFractalU(std::uint32_t square_index,
+ SidePair<int>* local_pos) {
+ DecodeTraversalFractalZ(square_index, local_pos);
+ // Change fractal z-order to u-order
+ (*local_pos)[Side::kLhs] ^= (*local_pos)[Side::kRhs];
+}
+
+// Code inspired by the sample code in
+// https://en.wikipedia.org/wiki/Hilbert_curve
+// The main optimization is to avoid hard-to-predict conditional branches
+// based on the bits of the square_index parameter.
+void DecodeTraversalFractalHilbert(int size_log2, std::uint32_t square_index,
+ SidePair<int>* local_pos) {
+ std::uint32_t t = square_index;
+ std::uint32_t x = 0;
+ std::uint32_t y = 0;
+ // Easy-to-predict for loop, the number of iterations is the same for
+ // an entire GEMM.
+ for (int sb = 0; sb < size_log2; sb++) {
+ std::uint32_t s = 1 << sb;
+ bool rx = t & 2;
+ bool ry = (t & 1) ^ rx;
+ std::uint32_t tmp = rx ? (s - 1 - x) : x;
+ x = ry ? x : rx ? (s - 1 - y) : y;
+ y = ry ? (y + s) : tmp;
+ x = rx ? (x + s) : x;
+ t >>= 2;
+ }
+ (*local_pos)[Side::kLhs] = y;
+ (*local_pos)[Side::kRhs] = x;
+}
+
+} // end anonymous namespace
+
+void GetBlockByIndex(const BlockMap& block_map, int index,
+ SidePair<int>* block) {
+ profiler::ScopeLabel label("GetBlockByIndex");
+ const std::uint32_t index_u32 = index;
+
+ const std::uint32_t num_blocks_per_local_curve =
+ 1u << (2 * block_map.num_blocks_base_log2);
+ const std::uint32_t square_index =
+ index_u32 & (num_blocks_per_local_curve - 1);
+
+ const int size_log2 = block_map.num_blocks_base_log2;
+ SidePair<int> local_pos;
+ switch (block_map.traversal_order) {
+ case BlockMapTraversalOrder::kFractalZ:
+ DecodeTraversalFractalZ(square_index, &local_pos);
+ break;
+ case BlockMapTraversalOrder::kFractalU:
+ DecodeTraversalFractalU(square_index, &local_pos);
+ break;
+ case BlockMapTraversalOrder::kFractalHilbert:
+ DecodeTraversalFractalHilbert(size_log2, square_index, &local_pos);
+ break;
+ default:
+ RUY_DCHECK(block_map.traversal_order == BlockMapTraversalOrder::kLinear);
+ DecodeTraversalLinear(size_log2, square_index, &local_pos);
+ break;
+ }
+
+ const std::uint32_t rectangular_index =
+ index_u32 >> 2 * block_map.num_blocks_base_log2;
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ const std::uint32_t mask = (1u << block_map.rectangularness_log2[side]) - 1;
+ const int rectangular_offset = (rectangular_index & mask)
+ << block_map.num_blocks_base_log2;
+ (*block)[side] = local_pos[side] + rectangular_offset;
+ }
+}
+
+BlockMapTraversalOrder GetTraversalOrder(int rows, int cols, int depth,
+ int lhs_scalar_size,
+ int rhs_scalar_size,
+ int local_data_cache_size,
+ int shared_data_cache_size) {
+ const int kFractalOptSets =
+ RUY_OPT_FRACTAL_Z | RUY_OPT_FRACTAL_U | RUY_OPT_FRACTAL_HILBERT;
+ const int working_set_size =
+ (lhs_scalar_size * rows + rhs_scalar_size * cols) * depth;
+ if (RUY_OPT_ENABLED(kFractalOptSets) &&
+ (working_set_size > local_data_cache_size)) {
+ if (RUY_OPT_ENABLED(RUY_OPT_FRACTAL_HILBERT) &&
+ (working_set_size > shared_data_cache_size)) {
+ return BlockMapTraversalOrder::kFractalHilbert;
+ } else if (RUY_OPT_ENABLED(RUY_OPT_FRACTAL_U)) {
+ return BlockMapTraversalOrder::kFractalU;
+ } else {
+ return BlockMapTraversalOrder::kFractalZ;
+ }
+ } else {
+ return BlockMapTraversalOrder::kLinear;
+ }
+}
+
+namespace {
+
+int floor_log2_quotient(int num, int denom) {
+ if (num <= denom) {
+ return 0;
+ }
+ int log2_quotient = floor_log2(num) - ceil_log2(denom);
+ if ((denom << (log2_quotient + 1)) <= num) {
+ log2_quotient++;
+ }
+ return log2_quotient;
+}
+
+// Computes the rectangularness of the matrix shape (rows, cols). This is
+// essentially just the log2 of the quotient (rows / cols). The kernel_rows and
+// kernel_cols only get into the picture for clamping bounds but don't affect
+// the generic computation.
+void GetRectangularness(int rows, int cols, int kernel_rows, int kernel_cols,
+ int* rows_rectangularness_log2,
+ int* cols_rectangularness_log2) {
+ *rows_rectangularness_log2 = 0;
+ *cols_rectangularness_log2 = 0;
+
+ // In GEMV-ish cases, that is when kernel blocks are as narrow as the kernel
+ // itself, we risk having too small kernel blocks for good kernel
+ // amortization. We avoid that by limiting rectangularness so that kernel
+ // blocks are not too tiny at least in that dimension. Specifically, we try to
+ // have at least (2^min_kernel_inner_loop_runs_log2) kernels fitting in each
+ // kernel block along the large dimension.
+ const int min_kernel_inner_loop_runs_log2 = 3;
+ if (rows > cols) {
+ int cols_of_kernel_inner_loop_runs_log2 =
+ ceil_log2(cols) - pot_log2(kernel_cols);
+ int min_rows_of_kernel_inner_loop_runs_log2 =
+ std::max(0, min_kernel_inner_loop_runs_log2 -
+ cols_of_kernel_inner_loop_runs_log2);
+ *rows_rectangularness_log2 =
+ std::min(floor_log2_quotient(rows, cols),
+ std::max(0, floor_log2(rows) - pot_log2(kernel_rows) -
+ min_rows_of_kernel_inner_loop_runs_log2));
+ // Sanity check that we did not over-estimate rows_rectangularness_log2.
+ RUY_DCHECK_GE(rows >> *rows_rectangularness_log2, cols);
+ } else if (cols > rows) {
+ int rows_of_kernel_inner_loop_runs_log2 =
+ ceil_log2(rows) - pot_log2(kernel_rows);
+ int min_cols_of_kernel_inner_loop_runs_log2 =
+ std::max(0, min_kernel_inner_loop_runs_log2 -
+ rows_of_kernel_inner_loop_runs_log2);
+ *cols_rectangularness_log2 =
+ std::min(floor_log2_quotient(cols, rows),
+ std::max(0, floor_log2(cols) - pot_log2(kernel_cols) -
+ min_cols_of_kernel_inner_loop_runs_log2));
+ // Sanity check that we did not over-estimate cols_rectangularness_log2.
+ RUY_DCHECK_GE(cols >> *cols_rectangularness_log2, rows);
+ }
+ RUY_DCHECK(!*rows_rectangularness_log2 || !*cols_rectangularness_log2);
+}
+
+// Computes a 'multithreading score'. When multithreading, we need there to
+// be at least as many tiles as there are threads, and hopefully
+// substantially more than that, so we benefit from ruy's ability to
+// dispatch fine-grained workloads to threads.
+int GetMultithreadingScore(int block_size_log2, int rows, int cols,
+ int tentative_thread_count) {
+ const int num_full_blocks_of_rows = rows >> block_size_log2;
+ const int num_full_blocks_of_cols = cols >> block_size_log2;
+ const int candidate_num_full_blocks_log2 = floor_log2(
+ std::max(1, num_full_blocks_of_rows * num_full_blocks_of_cols));
+
+ // The values here have been tuned on ARM Cortex-A55.
+ // We expect this to have to be tuned differently for other CPUs.
+ if (tentative_thread_count == 1) {
+ return 0;
+ } else {
+ const int blocks_per_thread_log2 =
+ candidate_num_full_blocks_log2 - ceil_log2(tentative_thread_count);
+ if (blocks_per_thread_log2 < 0) {
+ return -64;
+ } else if (blocks_per_thread_log2 == 0) {
+ return -16;
+ } else if (blocks_per_thread_log2 == 1) {
+ return -8;
+ } else if (blocks_per_thread_log2 == 2) {
+ return 0;
+ } else if (blocks_per_thread_log2 == 3) {
+ return 8;
+ } else {
+ return 16;
+ }
+ }
+}
+
+// Computes a 'cache locality score'.
+int GetCacheLocalityScore(int block_size_log2, int rows, int cols, int depth,
+ int kernel_rows_log2, int kernel_cols_log2,
+ int lhs_scalar_size, int rhs_scalar_size, Path path,
+ int local_data_cache_size) {
+ // In the narrow case (e.g. matrix*vector), each byte of the big operand
+ // matrix (either LHS or RHS) is traversed only once, so any notion of data
+ // locality is irrelevant. Ignore the 'cache locality score' by forcing it to
+ // be 0 in that case.
+ if (rows <= (1 << kernel_rows_log2) || cols <= (1 << kernel_cols_log2)) {
+ return 0;
+ }
+ const int block_rows = std::min(1 << block_size_log2, rows);
+ const int block_cols = std::min(1 << block_size_log2, cols);
+ const int total_read_bytes =
+ (lhs_scalar_size * block_rows + rhs_scalar_size * block_cols) * depth;
+ const int total_read_bytes_log2 = ceil_log2(total_read_bytes);
+ const int nonlocality_log2 =
+ total_read_bytes_log2 - floor_log2(local_data_cache_size);
+ // The values here have been tuned on ARM Cortex-A55.
+ // We expect this to have to be tuned differently for other CPUs.
+ if (nonlocality_log2 < -1) {
+ return 64;
+ } else if (nonlocality_log2 == -1) {
+ return 56;
+ } else if (nonlocality_log2 == 0) {
+ return 48;
+ } else if (nonlocality_log2 == 1) {
+ return 32;
+ } else if (nonlocality_log2 == 2) {
+ return 16;
+ } else if (nonlocality_log2 == 3) {
+ return 0;
+ } else {
+ return -64;
+ }
+}
+
+// Compute a 'kernel amortization score'. This is the notion that very small
+// tiles result in more overhead outside of kernels, more complex memory
+// access patterns and less benefit from ruy's fat kernels, so we reward
+// larger blocks more than smaller ones.
+int GetKernelAmortizationScore(int block_size_log2, int rows, int cols,
+ int kernel_rows_log2, int kernel_cols_log2) {
+ const int block_rows = std::min(1 << block_size_log2, rows);
+ const int block_cols = std::min(1 << block_size_log2, cols);
+ const int kernels_per_block_log2 =
+ floor_log2(block_rows * block_cols) - kernel_rows_log2 - kernel_cols_log2;
+ RUY_DCHECK_GE(kernels_per_block_log2, 0);
+ // The values here have been tuned on ARM Cortex-A55.
+ // We expect this to have to be tuned differently for other CPUs.
+ if (kernels_per_block_log2 == 0) {
+ return 0;
+ } else if (kernels_per_block_log2 == 1) {
+ return 8;
+ } else if (kernels_per_block_log2 == 2) {
+ return 16;
+ } else if (kernels_per_block_log2 == 3) {
+ return 24;
+ } else if (kernels_per_block_log2 == 4) {
+ return 32;
+ } else if (kernels_per_block_log2 == 5) {
+ return 40;
+ } else if (kernels_per_block_log2 == 6) {
+ return 48;
+ } else if (kernels_per_block_log2 == 7) {
+ return 56;
+ } else {
+ return 64;
+ }
+}
+
+} // namespace
+
+void MakeBlockMap(int rows, int cols, int depth, int kernel_rows,
+ int kernel_cols, int lhs_scalar_size, int rhs_scalar_size,
+ int tentative_thread_count, Path path,
+ int local_data_cache_size, int shared_data_cache_size,
+ BlockMap* block_map) {
+ profiler::ScopeLabel label("MakeBlockMap");
+
+#ifdef RUY_MAKEBLOCKMAP_DEBUG
+#if RUY_MAKEBLOCKMAP_DEBUG >= 2
+ static constexpr bool debug_everytime = true;
+#else
+ static constexpr bool debug_everytime = false;
+#endif
+ static bool firsttime = true;
+ if (firsttime || debug_everytime) {
+ fprintf(stderr,
+ "MakeBlockMap(rows=%d, cols=%d, depth=%d, kernel_rows=%d, "
+ "kernel_cols=%d, lhs_scalar_size=%d, rhs_scalar_size=%d, "
+ "tentative_thread_count=%d)\n",
+ rows, cols, depth, kernel_rows, kernel_cols, lhs_scalar_size,
+ rhs_scalar_size, tentative_thread_count);
+ }
+#endif
+
+ RUY_DCHECK_GE(rows, kernel_rows);
+ RUY_DCHECK_GE(cols, kernel_cols);
+ RUY_DCHECK_EQ(rows % kernel_rows, 0);
+ RUY_DCHECK_EQ(cols % kernel_cols, 0);
+
+ block_map->traversal_order =
+ GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size,
+ local_data_cache_size, shared_data_cache_size);
+
+ int rows_rectangularness_log2 = 0;
+ int cols_rectangularness_log2 = 0;
+ GetRectangularness(rows, cols, kernel_rows, kernel_cols,
+ &rows_rectangularness_log2, &cols_rectangularness_log2);
+
+ const int kernel_rows_log2 = pot_log2(kernel_rows);
+ const int kernel_cols_log2 = pot_log2(kernel_cols);
+ const int kernel_size_log2 = std::max(kernel_cols_log2, kernel_rows_log2);
+
+ const int size = std::min(rows, cols);
+ const int size_log2 = std::max(kernel_size_log2, floor_log2(size));
+
+ RUY_DCHECK_GE(size_log2, kernel_size_log2);
+
+ // We are going to try candidate values for block_size_log2 ranging from
+ // kernel_size_log2 to (kernel_size_log2 + kMaxKernelsPerBlockLog2).
+ // For each of them we will compute a 'score' by adding individual scores
+ // for a few different considerations, all of which is entirely empirical.
+ // The values (and possibly the logic) around here are all subject to tuning
+ // based on benchmarks on different hardware. The current values are based
+ // on benchmarking on Qualcomm S855 (big and little cores), arm64,
+ // kNeonDotprod, 8bit quantized path. Don't read too much into it, go ahead
+ // and tune this as needed to achieve good performance elsewhere. Use
+ // the unit test, block_map_test, to encode values that should be preserved
+ // on specific architectures. Use RUY_MAKEBLOCKMAP_DEBUG to help tuning this.
+ static constexpr int kMaxKernelsPerBlockLog2 = 6;
+ const int max_block_size_log2 =
+ std::min(size_log2, kernel_size_log2 + kMaxKernelsPerBlockLog2);
+ int best_score = std::numeric_limits<int>::min();
+ int best_score_block_size_log2 = -1;
+ for (int block_size_log2 = kernel_size_log2;
+ block_size_log2 <= max_block_size_log2; block_size_log2++) {
+ const int multithreading_score = GetMultithreadingScore(
+ block_size_log2, rows, cols, tentative_thread_count);
+ const int cache_locality_score = GetCacheLocalityScore(
+ block_size_log2, rows, cols, depth, kernel_rows_log2, kernel_cols_log2,
+ lhs_scalar_size, rhs_scalar_size, path, local_data_cache_size);
+ const int kernel_amortization_score = GetKernelAmortizationScore(
+ block_size_log2, rows, cols, kernel_rows_log2, kernel_cols_log2);
+ const int score =
+ multithreading_score + cache_locality_score + kernel_amortization_score;
+#ifdef RUY_MAKEBLOCKMAP_DEBUG
+ if (firsttime || debug_everytime) {
+ fprintf(stderr,
+ "block_size_log2=%d: score=%d multithreading_score=%d "
+ "cache_locality_score=%d kernel_amortization_score=%d\n",
+ block_size_log2, score, multithreading_score,
+ cache_locality_score, kernel_amortization_score);
+ }
+#endif
+ if (score >= best_score) {
+ best_score = score;
+ best_score_block_size_log2 = block_size_log2;
+ }
+ }
+
+#ifdef RUY_MAKEBLOCKMAP_DEBUG
+ if (firsttime || debug_everytime) {
+ fprintf(stderr, "best_score_block_size_log2=%d\n",
+ best_score_block_size_log2);
+ }
+
+ static const char* explicit_block_size_log2_env =
+ getenv("RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2");
+ if (explicit_block_size_log2_env) {
+ best_score_block_size_log2 = std::stoi(explicit_block_size_log2_env);
+ if (firsttime || debug_everytime) {
+ fprintf(stderr, "Overridden best_score_block_size_log2=%d\n",
+ best_score_block_size_log2);
+ }
+ }
+ firsttime = false;
+#endif
+
+ int num_blocks_base_log2 = size_log2 - best_score_block_size_log2;
+ RUY_DCHECK_GE(num_blocks_base_log2, 0);
+
+ const int num_blocks_of_rows_log2 =
+ num_blocks_base_log2 + rows_rectangularness_log2;
+ const int num_blocks_of_cols_log2 =
+ num_blocks_base_log2 + cols_rectangularness_log2;
+
+ const int smallr =
+ round_down_pot(rows >> num_blocks_of_rows_log2, kernel_rows);
+ const int smallc =
+ round_down_pot(cols >> num_blocks_of_cols_log2, kernel_cols);
+ const int missr =
+ round_up_pot(rows - (smallr << num_blocks_of_rows_log2), kernel_rows) >>
+ pot_log2(kernel_rows);
+ const int missc =
+ round_up_pot(cols - (smallc << num_blocks_of_cols_log2), kernel_cols) >>
+ pot_log2(kernel_cols);
+
+ block_map->dims[Side::kLhs] = rows;
+ block_map->dims[Side::kRhs] = cols;
+ block_map->kernel_dims[Side::kLhs] = kernel_rows;
+ block_map->kernel_dims[Side::kRhs] = kernel_cols;
+ block_map->num_blocks_base_log2 = num_blocks_base_log2;
+ block_map->rectangularness_log2[Side::kLhs] = rows_rectangularness_log2;
+ block_map->rectangularness_log2[Side::kRhs] = cols_rectangularness_log2;
+ block_map->small_block_dims[Side::kLhs] = smallr;
+ block_map->small_block_dims[Side::kRhs] = smallc;
+ block_map->large_blocks[Side::kLhs] = missr;
+ block_map->large_blocks[Side::kRhs] = missc;
+ // Done last: NumBlocks needs some of the block_map fields to be already set.
+ block_map->thread_count =
+ std::min(tentative_thread_count, NumBlocks(*block_map));
+}
+
+void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block,
+ int* start, int* end) {
+ profiler::ScopeLabel label("GetBlockMatrixCoords");
+ *start = block * block_map.small_block_dims[side] +
+ std::min(block, block_map.large_blocks[side]) *
+ block_map.kernel_dims[side];
+ *end =
+ *start + block_map.small_block_dims[side] +
+ (block < block_map.large_blocks[side] ? block_map.kernel_dims[side] : 0);
+
+ RUY_DCHECK_EQ(0, *start % block_map.kernel_dims[side]);
+ RUY_DCHECK_EQ(0, *end % block_map.kernel_dims[side]);
+ RUY_DCHECK_LE(*end, block_map.dims[side]);
+ RUY_DCHECK_LT(*start, *end);
+ RUY_DCHECK_GE(*start, 0);
+}
+
+void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair<int>& block,
+ SidePair<int>* start, SidePair<int>* end) {
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ GetBlockMatrixCoords(side, block_map, block[side], &(*start)[side],
+ &(*end)[side]);
+ }
+}
+
+} // namespace ruy
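To make the bit de-interleaving in DecodeTraversalFractalZ/DecodeTraversalFractalU above concrete: the LHS block coordinate is built from the even bits of square_index, the RHS coordinate from the odd bits, and the U-order variant then XORs the RHS coordinate into the LHS one. A standalone sketch for the smallest non-trivial case, a 2x2 block grid (not part of this commit):

    #include <cstdio>

    int main() {
      // With one bit per axis, Z-order takes the even (low) bit of the index
      // as the LHS coordinate and the odd bit as the RHS coordinate; U-order
      // then applies lhs ^= rhs, turning the Z shape into a U shape.
      for (int index = 0; index < 4; ++index) {
        const int lhs_z = index & 1;
        const int rhs = (index >> 1) & 1;
        const int lhs_u = lhs_z ^ rhs;
        printf("index=%d  Z:(lhs=%d, rhs=%d)  U:(lhs=%d, rhs=%d)\n", index,
               lhs_z, rhs, lhs_u, rhs);
      }
      // Z visits (0,0), (1,0), (0,1), (1,1); U visits (0,0), (1,0), (1,1), (0,1).
      return 0;
    }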
diff --git a/ruy/block_map.h b/ruy/block_map.h
new file mode 100644
index 0000000..5e1cee0
--- /dev/null
+++ b/ruy/block_map.h
@@ -0,0 +1,161 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCK_MAP_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCK_MAP_H_
+
+#include "ruy/path.h"
+#include "ruy/side_pair.h"
+
+namespace ruy {
+
+enum class BlockMapTraversalOrder {
+ // Plain old row-by-row or column-by-column traversal.
+ kLinear,
+ // Fractal Z-order curve, https://en.wikipedia.org/wiki/Z-order_curve
+ kFractalZ,
+ // Variant of Z-order doing a U instead of a Z.
+ kFractalU,
+ // Hilbert curve, https://en.wikipedia.org/wiki/Hilbert_curve
+ kFractalHilbert
+};
+
+// A BlockMap describes a tiling of a matrix, typically the destination matrix
+// of a matrix multiplication computation. As is standard in matrix
+// multiplication, a tile is called a "block".
+//
+// Ruy subdivides work by blocks of the destination matrix: each thread fully
+// computes a block at once, then moves on to another block; each block is
+// produced by a single thread.
+//
+// This ensures that the workloads for each block are mutually independent,
+// which reduces synchronization requirements.
+//
+// Typically, a matrix multiplication will early on create a BlockMap by
+// calling MakeBlockMap. It will then query the number of blocks in that
+// BlockMap by calling NumBlocks. It will then create a single atomic integer
+// counter indexing these blocks, called the 'index', and will distribute
+// work to its N threads by ensuring that each thread works on disjoint sets
+// of index values. For a given index value, the thread will call
+// GetBlockByIndex to get the corresponding block, then GetBlockMatrixCoords
+// to find the actual row and column numbers of this block.
+//
+// There are two nested levels of subdivision. On a local level, the matrix is
+// tiled into a square NxN grid where N is a power of two, specifically:
+// N = 2^num_blocks_base_log2.
+//
+// At a larger scale, around these blocks, there may be one further
+// level of subdivision, in only one dimension: either along rows or along
+// columns. That is used to handle arbitrarily rectangular matrices. The
+// aforementioned high-level block grid is square, so on its own it would not
+// fit very rectangular matrices well.
+//
+// Taking together these two nested levels of subdivision, the effective
+// tiling is by
+// 2^(num_blocks_base_log2 + rows_rectangularness_log2)
+// blocks in the row dimension, and by
+// 2^(num_blocks_base_log2 + cols_rectangularness_log2)
+// blocks in the column dimension. See NumBlocksPerSide.
+//
+// Either rows_rectangularness_log2 or cols_rectangularness_log2 must be zero.
+//
+// Finally, this BlockMap is designed to operate under alignment constraints:
+// two fields, kernel_rows and kernel_cols, describe the requested alignment
+// of the effective grid in both dimensions. The idea is to feed matrix
+// multiplication kernels with tiles that fit their width as much as possible.
+// Of course, if rows (resp. cols) is not a multiple of kernel_rows (resp.
+// kernel_cols) then some tile will have to have unaligned size. BlockMap
+// will only allow that to happen in the last position along each axis, so
+// as to minimize the overhead imposed on the matrix multiplication kernels.
+struct BlockMap {
+ // The number of threads to use (to distribute the blocks to).
+ int thread_count;
+ // The order in which to traverse the matrix of which this BlockMap represents
+ // a tiling (hereafter "the matrix").
+ BlockMapTraversalOrder traversal_order;
+ // The dimensions of the block_map, that is, of the destination
+  // matrix rounded up to the next multiples of kernel_dims.
+ SidePair<int> dims;
+ // Log2 of the minimum number of subdivisions of the grid along either axis.
+ int num_blocks_base_log2;
+ // Log2 of the additional subdivision of the rows/columns axis.
+ SidePair<int> rectangularness_log2;
+ // Requested alignment of the subdivisions of the grid along the rows/columns
+ // axis.
+ SidePair<int> kernel_dims;
+ // Internal helper. Minimum number of rows/columns in each block.
+ SidePair<int> small_block_dims;
+ // Internal helper. Number of blocks along each dimension that need to have
+ // their size in that dimension be given by (small_block_dims + kernel_dims)
+ // instead of just small_block_dims.
+ SidePair<int> large_blocks;
+};
+
+// Returns the traversal order to be used for the given matrix multiplication
+// parameters.
+BlockMapTraversalOrder GetTraversalOrder(int rows, int cols, int depth,
+ int lhs_scalar_size,
+ int rhs_scalar_size,
+ int local_data_cache_size,
+ int shared_data_cache_size);
+
+// Create a BlockMap suitable for tiling the destination matrix in a
+// matrix multiplication with the given parameters.
+void MakeBlockMap(int rows, int cols, int depth, int kernel_rows,
+ int kernel_cols, int lhs_scalar_size, int rhs_scalar_size,
+ int tentative_thread_count, Path path,
+ int local_data_cache_size, int shared_data_cache_size,
+ BlockMap* block_map);
+
+// Maps an integer index to a block position in the grid.
+void GetBlockByIndex(const BlockMap& block_map, int index,
+ SidePair<int>* block);
+
+// Given a block position in the grid, returns its actual
+// position in the matrix that the BlockMap refers to in the dimension
+// referred to by `side`: along rows if side==kLhs, along columns if
+// side==kRhs.
+void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block,
+ int* start, int* end);
+
+// Given a block position in the grid, returns its actual
+// position in the matrix that the BlockMap refers to in terms of
+// actual row/column indices.
+void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair<int>& block,
+ SidePair<int>* start, SidePair<int>* end);
+
+// Returns the number of grid subdivisions along the rows dimension (if
+// side == kLhs) or columns dimension (if side == kRhs).
+inline int NumBlocksPerSide(Side side, const BlockMap& block_map) {
+ return 1 << (block_map.num_blocks_base_log2 +
+ block_map.rectangularness_log2[side]);
+}
+
+// Returns the overall number of blocks in
+// the BlockMap. The valid index values to pass to GetBlockByIndex are the
+// integers from 0 to N-1 where N is the value returned here.
+//
+// Note that it is always true that
+//   NumBlocks == NumBlocksPerSide(kLhs) * NumBlocksPerSide(kRhs)
+// because either rows_rectangularness_log2 or cols_rectangularness_log2 is 0.
+inline int NumBlocks(const BlockMap& block_map) {
+ return 1 << (2 * block_map.num_blocks_base_log2 +
+ block_map.rectangularness_log2[Side::kLhs] +
+ block_map.rectangularness_log2[Side::kRhs]);
+}
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCK_MAP_H_
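As a quick illustration of the usage pattern described in the header comment above, here is a minimal sketch (not part of this change; it assumes an arm64 build where Path::kNeonDotprod is available, and the shapes and cache-size hints are arbitrary example values) of a caller walking every block of a destination matrix:

#include "ruy/block_map.h"
#include "ruy/cpu_cache_size.h"
#include "ruy/path.h"
#include "ruy/side_pair.h"

// Visits every block of a 256x128 destination matrix tiled with 8x8 kernels.
void VisitAllBlocks() {
  ruy::BlockMap block_map;
  ruy::MakeBlockMap(/*rows=*/256, /*cols=*/128, /*depth=*/64,
                    /*kernel_rows=*/8, /*kernel_cols=*/8,
                    /*lhs_scalar_size=*/1, /*rhs_scalar_size=*/1,
                    /*tentative_thread_count=*/4, ruy::Path::kNeonDotprod,
                    ruy::LocalDataCacheSize(ruy::Path::kNeonDotprod),
                    ruy::SharedDataCacheSize(ruy::Path::kNeonDotprod),
                    &block_map);
  for (int index = 0; index < ruy::NumBlocks(block_map); index++) {
    ruy::SidePair<int> block, start, end;
    ruy::GetBlockByIndex(block_map, index, &block);
    ruy::GetBlockMatrixCoords(block_map, block, &start, &end);
    // This block covers rows [start[kLhs], end[kLhs]) and columns
    // [start[kRhs], end[kRhs]) of the destination matrix.
  }
}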
diff --git a/ruy/block_map_test.cc b/ruy/block_map_test.cc
new file mode 100644
index 0000000..24646cf
--- /dev/null
+++ b/ruy/block_map_test.cc
@@ -0,0 +1,263 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/block_map.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <limits>
+#include <vector>
+
+#include "testing/base/public/gunit.h"
+#include "ruy/cpu_cache_size.h"
+#include "ruy/path.h"
+#include "ruy/side_pair.h"
+
+namespace ruy {
+namespace {
+
+#if RUY_PLATFORM(NEON_64)
+
+// Unless otherwise specified, these tests have been tuned on an ARM Cortex-A55.
+void MakeBlockMapTuningTest(int rows, int cols, int depth, int kernel_rows,
+ int kernel_cols, int lhs_scalar_size,
+ int rhs_scalar_size, int tentative_thread_count,
+ Path path, int expected_num_blocks_base_log2,
+ int expected_rectangularness_log2) {
+ BlockMap block_map;
+ MakeBlockMap(rows, cols, depth, kernel_rows, kernel_cols, lhs_scalar_size,
+ rhs_scalar_size, tentative_thread_count, path,
+ LocalDataCacheSize(path), SharedDataCacheSize(path), &block_map);
+ EXPECT_EQ(block_map.num_blocks_base_log2, expected_num_blocks_base_log2);
+ EXPECT_EQ(std::min(block_map.rectangularness_log2[Side::kLhs],
+ block_map.rectangularness_log2[Side::kRhs]),
+ 0);
+ EXPECT_EQ(std::max(block_map.rectangularness_log2[Side::kLhs],
+ block_map.rectangularness_log2[Side::kRhs]),
+ expected_rectangularness_log2);
+}
+
+TEST(BlockMapTest, MakeBlockMapTuningTest8bitCubicShapesOneThreadNeonDotprod) {
+ MakeBlockMapTuningTest(32, 32, 32, 8, 8, 1, 1, /* tentative_thread_count */ 1,
+ Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 0,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(48, 48, 48, 8, 8, 1, 1, /* tentative_thread_count */ 1,
+ Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 0,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(64, 64, 64, 8, 8, 1, 1, /* tentative_thread_count */ 1,
+ Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 0,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(96, 96, 96, 8, 8, 1, 1, /* tentative_thread_count */ 1,
+ Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 0,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(128, 128, 128, 8, 8, 1, 1,
+ /* tentative_thread_count */ 1, Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 0,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(192, 192, 192, 8, 8, 1, 1,
+ /* tentative_thread_count */ 1, Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 0,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(256, 256, 256, 8, 8, 1, 1,
+ /* tentative_thread_count */ 1, Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 1,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(384, 384, 384, 8, 8, 1, 1,
+ /* tentative_thread_count */ 1, Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 1,
+ /* expected_rectangularness_log2 */ 0);
+}
+
+TEST(BlockMapTest,
+ MakeBlockMapTuningTest8bitCubicShapesFourThreadsNeonDotprod) {
+ MakeBlockMapTuningTest(32, 32, 32, 8, 8, 1, 1, /* tentative_thread_count */ 4,
+ Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 1,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(48, 48, 48, 8, 8, 1, 1, /* tentative_thread_count */ 4,
+ Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 1,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(64, 64, 64, 8, 8, 1, 1, /* tentative_thread_count */ 4,
+ Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 1,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(96, 96, 96, 8, 8, 1, 1, /* tentative_thread_count */ 4,
+ Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 1,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(128, 128, 128, 8, 8, 1, 1,
+ /* tentative_thread_count */ 4, Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 1,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(192, 192, 192, 8, 8, 1, 1,
+ /* tentative_thread_count */ 4, Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 1,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(256, 256, 256, 8, 8, 1, 1,
+ /* tentative_thread_count */ 4, Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 2,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(384, 384, 384, 8, 8, 1, 1,
+ /* tentative_thread_count */ 4, Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 2,
+ /* expected_rectangularness_log2 */ 0);
+}
+
+TEST(BlockMapTest, MakeBlockMapTuningTest32bit) {
+ MakeBlockMapTuningTest(256, 256, 256, 8, 8, 4, 4,
+ /* tentative_thread_count */ 4, Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 3,
+ /* expected_rectangularness_log2 */ 0);
+ MakeBlockMapTuningTest(4096, 4096, 4096, 8, 8, 4, 4,
+ /* tentative_thread_count */ 4, Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 7,
+ /* expected_rectangularness_log2 */ 0);
+}
+
+TEST(BlockMapTest, MakeBlockMapTuningTestRectangular) {
+ MakeBlockMapTuningTest(256, 16, 256, 8, 8, 1, 1,
+ /* tentative_thread_count */ 1, Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 0,
+ /* expected_rectangularness_log2 */ 3);
+ MakeBlockMapTuningTest(24, 2400, 256, 8, 8, 1, 1,
+ /* tentative_thread_count */ 1, Path::kNeonDotprod,
+ /* expected_num_blocks_base_log2 */ 0,
+ /* expected_rectangularness_log2 */ 6);
+}
+
+#endif
+
+int L1Distance(const SidePair<int>& a, const SidePair<int>& b) {
+ return std::abs(a[Side::kLhs] - b[Side::kLhs]) +
+ std::abs(a[Side::kRhs] - b[Side::kRhs]);
+}
+
+void GetBlockByIndexSquareTest(int num_blocks_base_log2,
+ BlockMapTraversalOrder traversal_order) {
+ // Arbitrary, does not affect this test. 3 is just a typical value.
+ constexpr int kKernelSizeLog2 = 3;
+
+ const int size_log2 = num_blocks_base_log2 + kKernelSizeLog2;
+ BlockMap block_map;
+ block_map.thread_count = 1;
+ block_map.traversal_order = traversal_order;
+ block_map.num_blocks_base_log2 = num_blocks_base_log2;
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ block_map.dims[side] = 1 << size_log2;
+ block_map.rectangularness_log2[side] = 0;
+ block_map.kernel_dims[side] = 1 << kKernelSizeLog2;
+ block_map.small_block_dims[side] = block_map.kernel_dims[side];
+ block_map.large_blocks[side] = 0;
+ }
+
+ const int num_blocks_per_side = 1 << num_blocks_base_log2;
+ const int num_blocks = num_blocks_per_side * num_blocks_per_side;
+ EXPECT_EQ(num_blocks, NumBlocks(block_map));
+
+ // Perform a full traversal of all blocks, as if computing a whole matrix
+ // multiplication.
+ //
+ // Used to record how many times each block was hit by the traversal.
+ std::vector<int> block_hit_counts(num_blocks);
+  // This relies on the assumption that all traversal orders start at (0, 0).
+ SidePair<int> previous_block_coords(0, 0);
+  // Sum of the L1 norms of the coordinate changes over all traversal steps.
+  std::int64_t total_l1_distance = 0;
+  // Number of jumps, i.e. traversal steps with an L1 norm greater than 1.
+ int discontinuity_count = 0;
+ for (int block_index = 0; block_index < num_blocks; block_index++) {
+ SidePair<int> block_coords;
+ GetBlockByIndex(block_map, block_index, &block_coords);
+ ++block_hit_counts[block_coords[Side::kLhs] +
+ num_blocks_per_side * block_coords[Side::kRhs]];
+ int distance = L1Distance(block_coords, previous_block_coords);
+ total_l1_distance += distance;
+ discontinuity_count += (distance > 1);
+ previous_block_coords = block_coords;
+ }
+
+ // Verify that each block was traversed exactly once.
+ for (int l = 0; l < num_blocks_per_side; l++) {
+ for (int r = 0; r < num_blocks_per_side; r++) {
+ EXPECT_EQ(block_hit_counts[l + num_blocks_per_side * r], 1);
+ }
+ }
+
+ // Verify that the discontinuity_count and total_l1_distance are as expected
+ // for the given traversal_order.
+ switch (traversal_order) {
+ case BlockMapTraversalOrder::kFractalHilbert:
+ // No discontinuity at all with this space-filling continuous curve!
+ EXPECT_EQ(discontinuity_count, 0);
+ // Therefore, total_l1_distance has to be the number of blocks minus one.
+ EXPECT_EQ(total_l1_distance, num_blocks - 1);
+ break;
+ case BlockMapTraversalOrder::kLinear:
+ EXPECT_EQ(discontinuity_count, num_blocks_per_side - 1);
+ EXPECT_EQ(total_l1_distance,
+ 2 * num_blocks_per_side * (num_blocks_per_side - 1));
+ break;
+ case BlockMapTraversalOrder::kFractalZ:
+ EXPECT_EQ(discontinuity_count, num_blocks > 1 ? (num_blocks / 2 - 1) : 0);
+ EXPECT_EQ(total_l1_distance,
+ 2 * num_blocks_per_side * (num_blocks_per_side - 1));
+ break;
+ case BlockMapTraversalOrder::kFractalU: {
+ if (num_blocks_base_log2 == 0) {
+ EXPECT_EQ(discontinuity_count, 0);
+ EXPECT_EQ(total_l1_distance, 0);
+ } else {
+ int expected_discontinuity_count = 0;
+ int expected_total_l1_distance = 3;
+ for (int i = 2; i <= num_blocks_base_log2; i++) {
+ expected_discontinuity_count = 4 * expected_discontinuity_count + 2;
+ expected_total_l1_distance =
+ 4 * expected_total_l1_distance + (1 << (i + 1)) - 1;
+ }
+ EXPECT_EQ(discontinuity_count, expected_discontinuity_count);
+ EXPECT_EQ(total_l1_distance, expected_total_l1_distance);
+ }
+ break;
+ }
+ default:
+ abort();
+ }
+}
+
+TEST(BlockMapTest, GetBlockByIndexSquare) {
+ for (int num_blocks_base_log2 = 0; num_blocks_base_log2 <= 10;
+ num_blocks_base_log2++) {
+ for (BlockMapTraversalOrder traversal_order :
+ {BlockMapTraversalOrder::kLinear, BlockMapTraversalOrder::kFractalZ,
+ BlockMapTraversalOrder::kFractalU,
+ BlockMapTraversalOrder::kFractalHilbert}) {
+ GetBlockByIndexSquareTest(num_blocks_base_log2, traversal_order);
+ }
+ }
+}
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char **argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/blocking_counter.cc b/ruy/blocking_counter.cc
new file mode 100644
index 0000000..ffa7ac0
--- /dev/null
+++ b/ruy/blocking_counter.cc
@@ -0,0 +1,49 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/blocking_counter.h"
+
+#include "ruy/check_macros.h"
+#include "ruy/wait.h"
+
+namespace ruy {
+
+void BlockingCounter::Reset(int initial_count) {
+ int old_count_value = count_.load(std::memory_order_relaxed);
+ RUY_DCHECK_EQ(old_count_value, 0);
+ (void)old_count_value;
+ count_.store(initial_count, std::memory_order_release);
+}
+
+bool BlockingCounter::DecrementCount() {
+ int old_count_value = count_.fetch_sub(1, std::memory_order_acq_rel);
+ RUY_DCHECK_GT(old_count_value, 0);
+ int count_value = old_count_value - 1;
+ bool hit_zero = (count_value == 0);
+ if (hit_zero) {
+ std::lock_guard<std::mutex> lock(count_mutex_);
+ count_cond_.notify_all();
+ }
+ return hit_zero;
+}
+
+void BlockingCounter::Wait() {
+ const auto& condition = [this]() {
+ return count_.load(std::memory_order_acquire) == 0;
+ };
+ ruy::Wait(condition, &count_cond_, &count_mutex_);
+}
+
+} // namespace ruy
diff --git a/ruy/blocking_counter.h b/ruy/blocking_counter.h
new file mode 100644
index 0000000..878f0e7
--- /dev/null
+++ b/ruy/blocking_counter.h
@@ -0,0 +1,62 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCKING_COUNTER_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCKING_COUNTER_H_
+
+#include <atomic>
+#include <condition_variable> // NOLINT(build/c++11) // IWYU pragma: keep
+#include <mutex> // NOLINT(build/c++11) // IWYU pragma: keep
+
+namespace ruy {
+
+// A BlockingCounter lets one thread wait for N events to occur.
+// This is how the master thread waits for all the worker threads
+// to have finished working.
+// The wait is implemented with ruy::Wait on the atomic count_ reaching the
+// value 0. In our usage pattern, BlockingCounter only synchronizes threads
+// after short-lived tasks (performing parts of the same GEMM), not longer
+// waits (resuming work on the next GEMM), so waits are typically short; the
+// condition variable and mutex below handle the occasional longer wait.
+class BlockingCounter {
+ public:
+ BlockingCounter() : count_(0) {}
+
+ // Sets/resets the counter; initial_count is the number of
+ // decrementing events that the Wait() call will be waiting for.
+ void Reset(int initial_count);
+
+ // Decrements the counter; if the counter hits zero, signals
+ // the threads that were waiting for that, and returns true.
+ // Otherwise (if the decremented count is still nonzero),
+ // returns false.
+ bool DecrementCount();
+
+ // Waits for the N other threads (N having been set by Reset())
+ // to hit the BlockingCounter.
+ void Wait();
+
+ private:
+ std::atomic<int> count_;
+
+  // The condition variable and mutex used to passively wait for count_
+  // to reach the value zero, in the case of longer waits.
+ std::condition_variable count_cond_;
+ std::mutex count_mutex_;
+};
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCKING_COUNTER_H_
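For illustration only (not part of this change), a minimal usage sketch with plain std::thread standing in for ruy's worker threads:

#include <thread>
#include <vector>

#include "ruy/blocking_counter.h"

void RunTasksAndWait(int num_tasks) {
  ruy::BlockingCounter counter;
  counter.Reset(num_tasks);  // One decrement expected per task.
  std::vector<std::thread> workers;
  for (int i = 0; i < num_tasks; i++) {
    workers.emplace_back([&counter] {
      // ... perform one short-lived unit of work here ...
      counter.DecrementCount();  // Returns true for the last task to finish.
    });
  }
  counter.Wait();  // Blocks until the count has reached zero.
  for (std::thread& t : workers) {
    t.join();
  }
}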
diff --git a/ruy/build_defs.bzl b/ruy/build_defs.bzl
new file mode 100644
index 0000000..964ede3
--- /dev/null
+++ b/ruy/build_defs.bzl
@@ -0,0 +1,54 @@
+"""Build definitions for Ruy.
+
+In some cases these are used to compile specific targets with extra features for
+specific platforms; dispatch to those targets relies on runtime capability detection.
+"""
+
+# 1. Enable -mfpu=neon unconditionally on ARM32. If it turns out that we need to support
+# ARM32 without NEON then we'll implement runtime detection and dispatch at that point.
+# 2. Explicitly pass -O3 on optimization configs where just "-c opt" means "optimize for code size".
+
+def ruy_copts_base():
+ return select({
+ ":armeabi-v7a": [
+ "-mfpu=neon",
+ ],
+ "//conditions:default": [],
+ }) + select({
+ ":optimized": ["-O3"],
+ "//conditions:default": [],
+ })
+
+# Used for targets that are compiled with extra features that are skipped at runtime if unavailable.
+def ruy_copts_skylake():
+ return select({
+ ":x86_64": ["-march=skylake-avx512"],
+ "//conditions:default": [],
+ })
+
+# Used for targets that are compiled with extra features that are skipped at runtime if unavailable.
+def ruy_copts_avx2():
+ return select({
+ ":x86_64": ["-mavx2", "-mfma"],
+ "//conditions:default": [],
+ })
+
+# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+# Optimization is not finished. In particular the dimensions of the kernel
+# blocks can be changed as desired.
+#
+# Used for targets that are compiled with extra features that are skipped at runtime if unavailable.
+def ruy_copts_sse42():
+ return []
+
+# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+# Optimization is not finished. In particular the dimensions of the kernel
+# blocks can be changed as desired.
+#
+# Used for targets that are compiled with extra features that are skipped at runtime if unavailable.
+def ruy_copts_avxvnni():
+ return select({
+ # TODO(b/146494398): Reinstate flag, something like "-march=cascadelake".
+ ":x86_64": [],
+ "//conditions:default": [],
+ })
diff --git a/ruy/build_defs.bzl.opensource b/ruy/build_defs.bzl.opensource
new file mode 100644
index 0000000..9bccccf
--- /dev/null
+++ b/ruy/build_defs.bzl.opensource
@@ -0,0 +1,40 @@
+"""Build definitions for Ruy."""
+
+# 1. Enable -mfpu=neon unconditionally on ARM32. If it turns out that we need to support
+# ARM32 without NEON then we'll implement runtime detection and dispatch at that point.
+# 2. Explicitly pass -O3 on optimization configs where just "-c opt" means "optimize for code size".
+
+def ruy_copts_base():
+ return select({
+ ":armeabi-v7a": [
+ "-mfpu=neon",
+ ],
+ "//conditions:default": [],
+ }) + select({
+ ":optimized": ["-O3"],
+ "//conditions:default": [],
+ })
+
+# Used for targets that are compiled with extra features that are skipped at runtime if unavailable.
+def ruy_copts_skylake():
+ return []
+
+# Used for targets that are compiled with extra features that are skipped at runtime if unavailable.
+def ruy_copts_avx2():
+ return []
+
+# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+# Optimization is not finished. In particular the dimensions of the kernel
+# blocks can be changed as desired.
+#
+# Used for targets that are compiled with extra features that are skipped at runtime if unavailable.
+def ruy_copts_sse42():
+ return []
+
+# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+# Optimization is not finished. In particular the dimensions of the kernel
+# blocks can be changed as desired.
+#
+# Used for targets that are compiled with extra features that are skipped at runtime if unavailable.
+def ruy_copts_avxvnni():
+ return []
diff --git a/ruy/check_macros.h b/ruy/check_macros.h
new file mode 100644
index 0000000..773f37d
--- /dev/null
+++ b/ruy/check_macros.h
@@ -0,0 +1,138 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CHECK_MACROS_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CHECK_MACROS_H_
+
+#include <cstdio>
+#include <cstdlib>
+#include <type_traits>
+
+namespace ruy {
+namespace check_macros {
+
+constexpr int kValueBufSize = 32;
+
+template <typename T, typename Enable = void>
+struct ToString {
+ static void Run(const T& value, char* buf) {
+ snprintf(buf, kValueBufSize, "(?)");
+ }
+};
+
+template <>
+struct ToString<float, void> {
+ static void Run(float value, char* buf) {
+ snprintf(buf, kValueBufSize, "%.9g", static_cast<double>(value));
+ }
+};
+
+template <>
+struct ToString<double, void> {
+ static void Run(double value, char* buf) {
+ snprintf(buf, kValueBufSize, "%.16g", value);
+ }
+};
+
+template <typename T>
+struct ToString<T, typename std::enable_if<std::is_integral<T>::value>::type> {
+ static void Run(const T& value, char* buf) {
+ snprintf(buf, kValueBufSize, "%lld", static_cast<long long>(value));
+ }
+};
+
+template <typename T>
+struct ToString<T*, void> {
+ static void Run(T* value, char* buf) {
+ snprintf(buf, kValueBufSize, "%p", value);
+ }
+};
+
+template <typename T>
+struct ToString<T, typename std::enable_if<std::is_enum<T>::value>::type> {
+ static void Run(const T& value, char* buf) {
+ snprintf(buf, kValueBufSize, "(enum value %d)", static_cast<int>(value));
+ }
+};
+
+inline void Failure(const char* file, int line, const char* macro,
+ const char* condition) {
+ fprintf(stderr, "%s:%d: %s condition not satisfied: %s\n", file, line, macro,
+ condition);
+ abort();
+}
+
+template <typename LhsType, typename RhsType>
+inline void Failure(const char* file, int line, const char* macro,
+ const char* lhs, const LhsType& lhs_value, const char* op,
+ const char* rhs, const RhsType& rhs_value) {
+ char lhs_value_buf[kValueBufSize];
+ ToString<LhsType>::Run(lhs_value, lhs_value_buf);
+ char rhs_value_buf[kValueBufSize];
+ ToString<RhsType>::Run(rhs_value, rhs_value_buf);
+ fprintf(stderr,
+ "%s:%d: %s condition not satisfied: [ %s %s %s ] with values [ "
+ "%s %s %s ].\n",
+ file, line, macro, lhs, op, rhs, lhs_value_buf, op, rhs_value_buf);
+ abort();
+}
+
+#define RUY_CHECK_IMPL(macro, condition) \
+ do { \
+ if (!(condition)) { \
+ ruy::check_macros::Failure(__FILE__, __LINE__, #macro, #condition); \
+ } \
+ } while (false)
+
+#define RUY_CHECK_OP_IMPL(macro, lhs, op, rhs) \
+ do { \
+ const auto& lhs_value = (lhs); \
+ const auto& rhs_value = (rhs); \
+ if (!(lhs_value op rhs_value)) { \
+ ruy::check_macros::Failure(__FILE__, __LINE__, #macro, #lhs, lhs_value, \
+ #op, #rhs, rhs_value); \
+ } \
+ } while (false)
+
+#define RUY_CHECK(condition) RUY_CHECK_IMPL(RUY_CHECK, condition)
+#define RUY_CHECK_EQ(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_EQ, x, ==, y)
+#define RUY_CHECK_NE(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_NE, x, !=, y)
+#define RUY_CHECK_GE(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_GE, x, >=, y)
+#define RUY_CHECK_GT(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_GT, x, >, y)
+#define RUY_CHECK_LE(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_LE, x, <=, y)
+#define RUY_CHECK_LT(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_LT, x, <, y)
+
+#ifdef NDEBUG
+#define RUY_DCHECK(condition)
+#define RUY_DCHECK_EQ(x, y)
+#define RUY_DCHECK_NE(x, y)
+#define RUY_DCHECK_GE(x, y)
+#define RUY_DCHECK_GT(x, y)
+#define RUY_DCHECK_LE(x, y)
+#define RUY_DCHECK_LT(x, y)
+#else
+#define RUY_DCHECK(condition) RUY_CHECK(condition)
+#define RUY_DCHECK_EQ(x, y) RUY_CHECK_EQ(x, y)
+#define RUY_DCHECK_NE(x, y) RUY_CHECK_NE(x, y)
+#define RUY_DCHECK_GE(x, y) RUY_CHECK_GE(x, y)
+#define RUY_DCHECK_GT(x, y) RUY_CHECK_GT(x, y)
+#define RUY_DCHECK_LE(x, y) RUY_CHECK_LE(x, y)
+#define RUY_DCHECK_LT(x, y) RUY_CHECK_LT(x, y)
+#endif
+
+} // end namespace check_macros
+} // end namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CHECK_MACROS_H_
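A small usage sketch (not part of this change): RUY_CHECK* is enforced in all builds, while RUY_DCHECK* compiles away under NDEBUG:

#include "ruy/check_macros.h"

int DivideExactly(int a, int b) {
  RUY_CHECK_NE(b, 0);       // Always checked; aborts with a diagnostic on failure.
  RUY_DCHECK_EQ(a % b, 0);  // Debug-only consistency check.
  return a / b;
}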
diff --git a/ruy/check_macros_test.cc b/ruy/check_macros_test.cc
new file mode 100644
index 0000000..7e47e7f
--- /dev/null
+++ b/ruy/check_macros_test.cc
@@ -0,0 +1,153 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/check_macros.h"
+
+#include "testing/base/public/gunit.h"
+
+namespace {
+
+#define TEST_CONDITION_FOR_FAMILY(family, vacuously_succeeds, condition) \
+ do { \
+ if (vacuously_succeeds || (condition)) { \
+ RUY_##family(condition); \
+ } \
+ } while (false)
+
+#define TEST_COMPARISON_FOR_FAMILY(family, vacuously_succeeds, op_name, x, op, \
+ y) \
+ do { \
+ if (vacuously_succeeds || ((x)op(y))) { \
+ RUY_##family##_##op_name(x, y); \
+ } \
+ } while (false)
+
+#ifdef NDEBUG
+#define TEST_CONDITION(condition) \
+ do { \
+ TEST_CONDITION_FOR_FAMILY(CHECK, false, condition); \
+ } while (false)
+#define TEST_COMPARISON(op_name, x, op, y) \
+ do { \
+ TEST_COMPARISON_FOR_FAMILY(CHECK, false, op_name, x, op, y); \
+ } while (false)
+#else
+#define TEST_CONDITION(condition) \
+ do { \
+ TEST_CONDITION_FOR_FAMILY(CHECK, false, condition); \
+ TEST_CONDITION_FOR_FAMILY(DCHECK, false, condition); \
+ } while (false)
+#define TEST_COMPARISON(op_name, x, op, y) \
+ do { \
+ TEST_COMPARISON_FOR_FAMILY(CHECK, false, op_name, x, op, y); \
+ TEST_COMPARISON_FOR_FAMILY(DCHECK, false, op_name, x, op, y); \
+ } while (false)
+
+#endif
+
+template <typename LhsType, typename RhsType>
+void TestEqualityComparisons(const LhsType& lhs, const RhsType& rhs) {
+ RUY_CHECK_EQ(lhs, lhs);
+ TEST_COMPARISON(EQ, lhs, ==, lhs);
+ RUY_CHECK_EQ(lhs, lhs);
+ RUY_CHECK_EQ(lhs, lhs);
+ if (lhs == rhs) {
+ RUY_CHECK_EQ(lhs, rhs);
+ }
+ if (lhs != rhs) {
+ RUY_CHECK_NE(lhs, rhs);
+ }
+}
+
+template <typename LhsType, typename RhsType>
+void TestComparisons(const LhsType& lhs, const RhsType& rhs) {
+ TestEqualityComparisons(lhs, rhs);
+ if (lhs > rhs) {
+ RUY_CHECK_GT(lhs, rhs);
+ }
+ if (lhs >= rhs) {
+ RUY_CHECK_GE(lhs, rhs);
+ }
+ if (lhs < rhs) {
+ RUY_CHECK_LT(lhs, rhs);
+ }
+ if (lhs <= rhs) {
+ RUY_CHECK_LE(lhs, rhs);
+ }
+}
+
+TEST(CheckMacrosTest, IntInt) {
+ TestComparisons(0, 0);
+ TestComparisons(0, 1);
+ TestComparisons(1, -1);
+ TestComparisons(-1, 0);
+ TestComparisons(123, -456);
+ TestComparisons(std::numeric_limits<int>::min(),
+ std::numeric_limits<int>::max());
+ TestComparisons(123, std::numeric_limits<int>::max());
+ TestComparisons(123, std::numeric_limits<int>::min());
+}
+
+TEST(CheckMacrosTest, Uint8Uint8) {
+ TestComparisons<std::uint8_t, std::uint8_t>(0, 0);
+ TestComparisons<std::uint8_t, std::uint8_t>(255, 0);
+ TestComparisons<std::uint8_t, std::uint8_t>(0, 255);
+ TestComparisons<std::uint8_t, std::uint8_t>(12, 34);
+}
+
+TEST(CheckMacrosTest, Uint8Int) {
+ TestComparisons<std::uint8_t, int>(0, std::numeric_limits<int>::min());
+ TestComparisons<std::uint8_t, int>(255, std::numeric_limits<int>::min());
+ TestComparisons<std::uint8_t, int>(0, std::numeric_limits<int>::max());
+ TestComparisons<std::uint8_t, int>(255, std::numeric_limits<int>::max());
+}
+
+TEST(CheckMacrosTest, FloatFloat) {
+ TestComparisons(0.f, 0.f);
+ TestComparisons(0.f, 1.f);
+ TestComparisons(1.f, -1.f);
+ TestComparisons(-1.f, 0.f);
+ TestComparisons(123.f, -456.f);
+ TestComparisons(std::numeric_limits<float>::lowest(),
+ std::numeric_limits<float>::max());
+ TestComparisons(123.f, std::numeric_limits<float>::max());
+ TestComparisons(123.f, std::numeric_limits<float>::lowest());
+}
+
+TEST(CheckMacrosTest, IntFloat) {
+ TestComparisons(0, 0.f);
+ TestComparisons(0, 1.f);
+ TestComparisons(1, -1.f);
+ TestComparisons(-1, 0.f);
+ TestComparisons(123, -456.f);
+ TestComparisons(std::numeric_limits<int>::lowest(),
+ std::numeric_limits<float>::max());
+ TestComparisons(123, std::numeric_limits<float>::max());
+ TestComparisons(123, std::numeric_limits<float>::lowest());
+}
+
+TEST(CheckMacrosTest, EnumClass) {
+ enum class SomeEnumClass { kA, kB, kC };
+ TestEqualityComparisons(SomeEnumClass::kA, SomeEnumClass::kA);
+ TestEqualityComparisons(SomeEnumClass::kA, SomeEnumClass::kB);
+ TestEqualityComparisons(SomeEnumClass::kC, SomeEnumClass::kB);
+}
+
+} // namespace
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/common.h b/ruy/common.h
new file mode 100644
index 0000000..1cd40fe
--- /dev/null
+++ b/ruy/common.h
@@ -0,0 +1,73 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Miscellaneous internal library helpers.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_COMMON_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_COMMON_H_
+
+#include <limits>
+#include <type_traits>
+
+#include "ruy/check_macros.h"
+#include "ruy/matrix.h"
+#include "ruy/opt_set.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+
+#if RUY_OPT_ENABLED(RUY_OPT_PREFETCH_LOAD)
+#define RUY_PREFETCH_LOAD(X) X
+#else
+#define RUY_PREFETCH_LOAD(X)
+#endif
+
+#if RUY_OPT_ENABLED(RUY_OPT_PREFETCH_STORE)
+#define RUY_PREFETCH_STORE(X) X
+#else
+#define RUY_PREFETCH_STORE(X)
+#endif
+
+#define RUY_STR(s) RUY_STR_UNEXPANDED(s)
+#define RUY_STR_UNEXPANDED(s) #s
+
+namespace ruy {
+
+// Helper for type-erasing a pointer.
+//
+// Often inside Ruy, a template parameter holds type information statically, but
+// we would like to have a function signature that doesn't depend on the
+// template parameters, so that we can dispatch indirectly across multiple
+// implementations. This helper is at the core of such type-erasure.
+//
+// The opposite of this operation is just `static_cast<T*>(void_ptr)`.
+template <typename T>
+void* ToVoidPtr(T* p) {
+ return const_cast<void*>(static_cast<const void*>(p));
+}
+
+template <typename Scalar>
+Scalar SymmetricZeroPoint() {
+ if (std::is_floating_point<Scalar>::value) {
+ return 0;
+ }
+ if (std::is_signed<Scalar>::value) {
+ return 0;
+ }
+ return std::numeric_limits<Scalar>::max() / 2 + 1;
+}
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_COMMON_H_
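To make the SymmetricZeroPoint formula concrete, a small illustration (not part of this change) of the values it produces for common scalar types:

#include <cstdint>

#include "ruy/common.h"

void SymmetricZeroPointExamples() {
  // Unsigned 8-bit: 255 / 2 + 1 == 128, the midpoint of [0, 255].
  std::uint8_t u8_zero = ruy::SymmetricZeroPoint<std::uint8_t>();
  // Signed and floating-point types already have 0 at the center of their range.
  std::int8_t s8_zero = ruy::SymmetricZeroPoint<std::int8_t>();
  float f32_zero = ruy::SymmetricZeroPoint<float>();
  (void)u8_zero;   // 128
  (void)s8_zero;   // 0
  (void)f32_zero;  // 0.0f
}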
diff --git a/ruy/context.cc b/ruy/context.cc
new file mode 100644
index 0000000..1a70303
--- /dev/null
+++ b/ruy/context.cc
@@ -0,0 +1,109 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/context.h"
+
+#include "ruy/check_macros.h"
+#include "ruy/detect_arm.h"
+#include "ruy/detect_x86.h"
+#include "ruy/have_built_path_for.h"
+#include "ruy/platform.h"
+
+namespace ruy {
+
+void Context::SetRuntimeEnabledPaths(Path paths) {
+ runtime_enabled_paths_ = paths;
+}
+
+Path Context::GetRuntimeEnabledPaths() {
+ // This function should always return the same value on a given machine.
+ // When runtime_enabled_paths_ has its initial value kNone, it performs
+ // some platform detection to resolve it to specific Path values.
+
+ // Fast path: already resolved.
+ if (runtime_enabled_paths_ != Path::kNone) {
+ return runtime_enabled_paths_;
+ }
+
+ // Need to resolve now. Start by considering all paths enabled.
+ runtime_enabled_paths_ = kAllPaths;
+
+ // This mechanism is intended to be used for testing and benchmarking. For
+ // example, one can set RUY_FORCE_DISABLE_PATHS to Path::kAvx512 in order to
+ // evaluate AVX2 performance on an AVX-512 machine.
+#ifdef RUY_FORCE_DISABLE_PATHS
+ runtime_enabled_paths_ = runtime_enabled_paths_ & ~(RUY_FORCE_DISABLE_PATHS);
+#endif
+
+#if RUY_PLATFORM(ARM)
+ // Now selectively disable paths that aren't supported on this machine.
+ if ((runtime_enabled_paths_ & Path::kNeonDotprod) != Path::kNone) {
+ if (!DetectDotprod()) {
+ runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kNeonDotprod;
+ // Sanity check.
+ RUY_DCHECK((runtime_enabled_paths_ & Path::kNeonDotprod) == Path::kNone);
+ }
+ }
+#endif // RUY_PLATFORM(ARM)
+
+#if RUY_PLATFORM(X86)
+ // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete /
+ // placeholder. Optimization is not finished. In particular the dimensions of
+ // the kernel blocks can be changed as desired.
+ //
+ if ((runtime_enabled_paths_ & Path::kSse42) != Path::kNone) {
+ if (!(HaveBuiltPathForSse42() && DetectCpuSse42())) {
+ runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kSse42;
+ // Sanity check.
+ RUY_DCHECK((runtime_enabled_paths_ & Path::kSse42) == Path::kNone);
+ }
+ }
+
+ if ((runtime_enabled_paths_ & Path::kAvx2) != Path::kNone) {
+ if (!(HaveBuiltPathForAvx2() && DetectCpuAvx2())) {
+ runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kAvx2;
+ // Sanity check.
+ RUY_DCHECK((runtime_enabled_paths_ & Path::kAvx2) == Path::kNone);
+ }
+ }
+
+ if ((runtime_enabled_paths_ & Path::kAvx512) != Path::kNone) {
+ if (!(HaveBuiltPathForAvx512() && DetectCpuAvx512())) {
+ runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kAvx512;
+ // Sanity check.
+ RUY_DCHECK((runtime_enabled_paths_ & Path::kAvx512) == Path::kNone);
+ }
+ }
+
+ // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete /
+ // placeholder. Optimization is not finished. In particular the dimensions of
+ // the kernel blocks can be changed as desired.
+ //
+ if ((runtime_enabled_paths_ & Path::kAvxVnni) != Path::kNone) {
+ if (!(HaveBuiltPathForAvxVnni() && DetectCpuAvxVnni())) {
+ runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kAvxVnni;
+ // Sanity check.
+ RUY_DCHECK((runtime_enabled_paths_ & Path::kAvxVnni) == Path::kNone);
+ }
+ }
+#endif // RUY_PLATFORM(X86)
+
+ // Sanity check. We can't possibly have disabled all paths, as some paths
+ // are universally available (kReference, kStandardCpp).
+ RUY_DCHECK_NE(runtime_enabled_paths_, Path::kNone);
+ return runtime_enabled_paths_;
+}
+
+} // namespace ruy
diff --git a/ruy/context.h b/ruy/context.h
new file mode 100644
index 0000000..330a7e7
--- /dev/null
+++ b/ruy/context.h
@@ -0,0 +1,109 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CONTEXT_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CONTEXT_H_
+
+#include <cstddef>
+#include <memory>
+#include <vector>
+
+#include "ruy/allocator.h"
+#include "ruy/path.h"
+#include "ruy/prepacked_cache.h"
+#include "ruy/thread_pool.h"
+#include "ruy/trace.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+// The state private to each Ruy thread.
+struct PerThreadState {
+ // Each thread may be running on a different microarchitecture. For example,
+ // some threads may be on big cores, while others are on little cores. Thus,
+ // it's best for the tuning to be per-thread.
+ TuningResolver tuning_resolver;
+ // Each thread has its own local allocator.
+ Allocator allocator;
+};
+
+// A Context holds runtime information used by Ruy. It holds runtime resources
+// such as the workers thread pool and the allocator (which holds buffers for
+// temporary data), as well as runtime options controlling which Paths are
+// enabled (typically based on which instruction sets are detected) and how
+// many threads to use.
+struct Context final {
+ Path last_taken_path = Path::kNone;
+ Tuning explicit_tuning = Tuning::kAuto;
+ // TODO(benoitjacob) rename that thread_pool. Current name is gemmlowp legacy.
+ ThreadPool workers_pool;
+ int max_num_threads = 1;
+ // State for each thread in the thread pool. Entry 0 is the main thread.
+ std::vector<std::unique_ptr<PerThreadState>> per_thread_states;
+ TracingContext tracing;
+ CachePolicy cache_policy = CachePolicy::kNoCache;
+
+ Allocator* GetMainAllocator() {
+ if (!main_allocator_) {
+ main_allocator_.reset(new Allocator);
+ }
+ return main_allocator_.get();
+ }
+
+ PrepackedCache* GetPrepackedCache() {
+ if (!prepacked_cache_) {
+ prepacked_cache_.reset(new PrepackedCache);
+ }
+ return prepacked_cache_.get();
+ }
+
+ void ClearPrepackedCache() { prepacked_cache_ = nullptr; }
+
+ void EnsureNPerThreadStates(int thread_count) {
+ while (per_thread_states.size() < static_cast<std::size_t>(thread_count)) {
+ per_thread_states.emplace_back(new PerThreadState);
+ }
+ }
+
+ Tuning GetMainThreadTuning() {
+ EnsureNPerThreadStates(1);
+ TuningResolver* tuning_resolver = &per_thread_states[0]->tuning_resolver;
+ tuning_resolver->SetTuning(explicit_tuning);
+ return tuning_resolver->Resolve();
+ }
+
+ template <Path CompiledPaths>
+ Path GetPathToTake() {
+ last_taken_path =
+ GetMostSignificantPath(CompiledPaths & GetRuntimeEnabledPaths());
+ return last_taken_path;
+ }
+
+ void SetRuntimeEnabledPaths(Path paths);
+ Path GetRuntimeEnabledPaths();
+
+ private:
+ // Allocator for main thread work before invoking the threadpool.
+ // Our simple Allocator does not allow reserving/allocating more blocks
+ // while it's already in committed state, so the main thread needs both
+ // this allocator, and its per-thread allocator.
+ std::unique_ptr<Allocator> main_allocator_;
+ std::unique_ptr<PrepackedCache> prepacked_cache_;
+ Path runtime_enabled_paths_ = Path::kNone;
+};
+
+} // end namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CONTEXT_H_
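A minimal sketch (not part of this change) of how a caller might configure a Context before running multiplications; the specific paths are only an example and mirror the ARM case exercised in context_test.cc below:

#include "ruy/context.h"
#include "ruy/path.h"
#include "ruy/platform.h"

void ConfigureContext(ruy::Context* context) {
  // Never use more than two worker threads, regardless of matrix size.
  context->max_num_threads = 2;
#if RUY_PLATFORM(ARM)
  // Restrict the candidate paths. Paths unsupported by the CPU (for example
  // kNeonDotprod without the dotprod extension) are pruned when
  // GetRuntimeEnabledPaths() resolves the set.
  context->SetRuntimeEnabledPaths(ruy::Path::kNeon | ruy::Path::kNeonDotprod);
#endif
}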
diff --git a/ruy/context_test.cc b/ruy/context_test.cc
new file mode 100644
index 0000000..c189030
--- /dev/null
+++ b/ruy/context_test.cc
@@ -0,0 +1,63 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/context.h"
+
+#include "testing/base/public/gunit.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+
+namespace ruy {
+namespace {
+
+TEST(ContextTest, EnabledPathsGeneral) {
+ ruy::Context ruy_context;
+ const auto ruy_paths = ruy_context.GetRuntimeEnabledPaths();
+ const auto ruy_paths_repeat = ruy_context.GetRuntimeEnabledPaths();
+ ASSERT_EQ(ruy_paths, ruy_paths_repeat);
+ EXPECT_NE(ruy_paths, Path::kNone);
+ EXPECT_EQ(ruy_paths & Path::kReference, Path::kReference);
+ EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kStandardCpp);
+}
+
+#if RUY_PLATFORM(X86)
+TEST(ContextTest, EnabledPathsX86) {
+ ruy::Context ruy_context;
+ ruy_context.SetRuntimeEnabledPaths(Path::kSse42 | Path::kAvx2 |
+ Path::kAvx512 | Path::kAvxVnni);
+ const auto ruy_paths = ruy_context.GetRuntimeEnabledPaths();
+ EXPECT_EQ(ruy_paths & Path::kReference, Path::kNone);
+ EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kNone);
+}
+#endif // RUY_PLATFORM(X86)
+
+#if RUY_PLATFORM(ARM)
+TEST(ContextTest, EnabledPathsArm) {
+ ruy::Context ruy_context;
+ ruy_context.SetRuntimeEnabledPaths(Path::kNeon | Path::kNeonDotprod);
+ const auto ruy_paths = ruy_context.GetRuntimeEnabledPaths();
+ EXPECT_EQ(ruy_paths & Path::kReference, Path::kNone);
+ EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kNone);
+ EXPECT_EQ(ruy_paths & Path::kNeon, Path::kNeon);
+}
+#endif // RUY_PLATFORM(ARM)
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/cpu_cache_size.h b/ruy/cpu_cache_size.h
new file mode 100644
index 0000000..82f41cc
--- /dev/null
+++ b/ruy/cpu_cache_size.h
@@ -0,0 +1,81 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CPU_CACHE_SIZE_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CPU_CACHE_SIZE_H_
+
+#include "ruy/path.h"
+#include "ruy/platform.h"
+
+namespace ruy {
+
+// LocalDataCacheSize returns a sane default size for each CPU core's local
+// data cache, i.e. the largest data cache that is local to that CPU core, not
+// shared with other cores. That allows coarse tuning of code that aims for
+// most of its memory accesses to hit such a typically fast data cache.
+//
+// SharedDataCacheSize returns a sane default size of the total data cache
+// accessible to each CPU, including any shared cache.
+//
+// For example, if we tune this code for an ARM Cortex-A55 with a local L1
+// cache of 32k, a local L2 cache of 128k and a shared L3 cache of 1M,
+// LocalDataCacheSize should return 128k and SharedDataCacheSize
+// should return 1M.
+//
+// Ideally these values would be queried at runtime, and we should probably
+// do that on x86, but that is hard to do on ARM.
+#if RUY_PLATFORM(ARM_64)
+inline int LocalDataCacheSize() { return 1 << 15; }
+inline int SharedDataCacheSize() { return 1 << 19; }
+#elif RUY_PLATFORM(ARM_32)
+inline int LocalDataCacheSize() { return 1 << 14; }
+inline int SharedDataCacheSize() { return 1 << 18; }
+#elif RUY_PLATFORM(X86)
+inline int LocalDataCacheSize() { return 1 << 17; }
+inline int SharedDataCacheSize() { return 1 << 21; }
+#else
+inline int LocalDataCacheSize() { return 1 << 14; }
+inline int SharedDataCacheSize() { return 1 << 18; }
+#endif
+// Variants taking a Path argument which acts
+// as a hint telling whether we're targeting more or less recent/powerful CPUs.
+inline int LocalDataCacheSize(Path path) {
+#if RUY_PLATFORM(ARM_64)
+ if (path == Path::kNeonDotprod) {
+ // At the moment, the smallest CPU with dotprod is probably Cortex-A55 with
+ // 128k L2 local cache.
+ return 1 << 17;
+ }
+#else
+ (void)path;
+#endif
+ return LocalDataCacheSize();
+}
+inline int SharedDataCacheSize(Path path) {
+#if RUY_PLATFORM(ARM_64)
+ if (path == Path::kNeonDotprod) {
+ // At the moment, the smallest CPU with dotprod is probably Cortex-A55 with
+ // 1M L3 shared cache.
+ return 1 << 20;
+ }
+#else
+ (void)path;
+#endif
+ return SharedDataCacheSize();
+}
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CPU_CACHE_SIZE_H_
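For concreteness, a small sketch (not part of this change, guarded for an arm64 build) showing how the Path-aware overloads raise the defaults when targeting dotprod-capable cores:

#include "ruy/cpu_cache_size.h"
#include "ruy/path.h"
#include "ruy/platform.h"

void CacheSizeHints() {
#if RUY_PLATFORM(NEON_64)
  // Defaults on arm64: 32 KiB local, 512 KiB shared.
  int local_default = ruy::LocalDataCacheSize();    // 1 << 15
  int shared_default = ruy::SharedDataCacheSize();  // 1 << 19
  // With the dotprod path hint (e.g. Cortex-A55 class cores): 128 KiB / 1 MiB.
  int local_dotprod = ruy::LocalDataCacheSize(ruy::Path::kNeonDotprod);    // 1 << 17
  int shared_dotprod = ruy::SharedDataCacheSize(ruy::Path::kNeonDotprod);  // 1 << 20
  (void)local_default;
  (void)shared_default;
  (void)local_dotprod;
  (void)shared_dotprod;
#endif
}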
diff --git a/ruy/detect_arm.cc b/ruy/detect_arm.cc
new file mode 100644
index 0000000..85f7156
--- /dev/null
+++ b/ruy/detect_arm.cc
@@ -0,0 +1,73 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/* Detection of dotprod instructions on ARM.
+ * The current Linux-specific code relies on sufficiently new Linux kernels:
+ * At least Linux 4.15 in general; on Android, at least Linux 4.14.111 thanks to
+ * a late backport that landed just before the Android 10 release, so this
+ * leaves out pre-release Android 10 builds as well as earlier Android
+ * versions.
+ *
+ * It is possible to detect instructions in other ways that don't rely on
+ * an OS-provided feature identification mechanism:
+ *
+ * (A) We used to have a SIGILL-handler-based method that worked at least
+ * on Linux. Its downsides were (1) crashes on a few devices where
+ * signal handler installation didn't work as intended; (2) additional
+ * complexity to generalize to other Unix-ish operating systems including
+ * iOS; (3) source code complexity and fragility of anything installing
+ * and restoring signal handlers; (4) confusing behavior under a debugger.
+ *
+ * (B) We also experimented with a fork-ing approach where a subprocess
+ * tries the instruction. Compared to (A), this is much simpler and more
+ * reliable and portable, but also much higher latency on Android where
+ * an uncaught signal typically causes a 100 ms latency.
+ *
+ * Should there be interest in either technique again in the future,
+ * code implementing both (A) and (B) can be found in earlier revisions of this
+ * file - in actual code for (A) and in a comment for (B).
+ */
+
+#include "ruy/detect_arm.h"
+
+#if defined __linux__ && defined __aarch64__
+#include <sys/auxv.h>
+#endif
+
+namespace ruy {
+
+namespace {
+
+#if defined __linux__ && defined __aarch64__
+bool DetectDotprodByLinuxAuxvMethod() {
+ // This is the value of HWCAP_ASIMDDP in sufficiently recent Linux headers,
+ // however we need to support building against older headers for the time
+ // being.
+ const int kLocalHwcapAsimddp = 1 << 20;
+ return getauxval(AT_HWCAP) & kLocalHwcapAsimddp;
+}
+#endif
+
+} // namespace
+
+bool DetectDotprod() {
+#if defined __linux__ && defined __aarch64__
+ return DetectDotprodByLinuxAuxvMethod();
+#endif
+
+ return false;
+}
+
+} // namespace ruy
diff --git a/ruy/detect_arm.h b/ruy/detect_arm.h
new file mode 100644
index 0000000..9a1542d
--- /dev/null
+++ b/ruy/detect_arm.h
@@ -0,0 +1,29 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Temporary dotprod-detection code until we can rely on getauxval.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_ARM_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_ARM_H_
+
+namespace ruy {
+
+// On A64, returns true if the dotprod extension is present.
+// On other architectures, returns false unconditionally.
+bool DetectDotprod();
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_ARM_H_
diff --git a/ruy/detect_x86.cc b/ruy/detect_x86.cc
new file mode 100644
index 0000000..ded37b1
--- /dev/null
+++ b/ruy/detect_x86.cc
@@ -0,0 +1,101 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/detect_x86.h"
+
+#include <cstdint>
+
+#if RUY_PLATFORM(X86) && RUY_PLATFORM(X86_ENHANCEMENTS)
+#include <immintrin.h> // IWYU pragma: keep
+
+#endif
+
+namespace ruy {
+#if RUY_PLATFORM(X86) && RUY_PLATFORM(X86_ENHANCEMENTS)
+
+namespace {
+
+// See Intel docs, such as http://goo.gl/c6IkGX.
+inline void RunCpuid(std::uint32_t eax, std::uint32_t ecx,
+ std::uint32_t abcd[4]) {
+ std::uint32_t ebx, edx;
+#if defined(__i386__) && defined(__PIC__)
+ /* in case of PIC under 32-bit EBX cannot be clobbered */
+ asm volatile("movl %%ebx, %%edi \n\t cpuid \n\t xchgl %%ebx, %%edi"
+ : "=D"(ebx),
+#else
+ asm volatile("cpuid"
+ : "+b"(ebx),
+#endif
+ "+a"(eax), "+c"(ecx), "=d"(edx));
+ abcd[0] = eax;
+ abcd[1] = ebx;
+ abcd[2] = ecx;
+ abcd[3] = edx;
+}
+
+} // namespace
+
+bool DetectCpuSse42() {
+ std::uint32_t abcd[4];
+
+ constexpr std::uint32_t kEcxSse42 = 1u << 20;
+ RunCpuid(1, 0, abcd);
+ const bool has_sse4_2_base = (abcd[2] & kEcxSse42) == kEcxSse42;
+
+#ifdef RUY_ENABLE_AMD_CPUID_CHECKS
+ constexpr std::uint32_t kEcxAbm = 1u << 5;
+ RunCpuid(0x80000001, 0, abcd);
+ const bool has_extras = (abcd[2] & kEcxAbm) == kEcxAbm;
+#else
+ constexpr std::uint32_t kEcxPopcnt = 1u << 23;
+ RunCpuid(1, 0, abcd);
+ const bool has_extras = (abcd[2] & kEcxPopcnt) == kEcxPopcnt;
+#endif
+
+ return has_sse4_2_base && has_extras;
+}
+
+bool DetectCpuAvx2() {
+ constexpr std::uint32_t kEbxAvx2 = 1u << 5;
+ constexpr std::uint32_t kEcxFma = 1u << 12;
+
+ std::uint32_t abcd[4];
+
+ RunCpuid(7, 0, abcd);
+ const bool has_avx2 = (abcd[1] & kEbxAvx2) == kEbxAvx2;
+ RunCpuid(1, 0, abcd);
+ const bool has_fma = (abcd[2] & kEcxFma) == kEcxFma;
+
+ return has_avx2 && has_fma;
+}
+
+bool DetectCpuAvx512() {
+ constexpr std::uint32_t kEbxAvx512F = 1u << 16;
+ constexpr std::uint32_t kEbxAvx512Dq = 1u << 17;
+ constexpr std::uint32_t kEbxAvx512Cd = 1u << 28;
+ constexpr std::uint32_t kEbxAvx512Bw = 1u << 30;
+ constexpr std::uint32_t kEbxAvx512Vl = 1u << 31;
+
+ constexpr std::uint32_t kEbxAvx512Mask =
+ kEbxAvx512F | kEbxAvx512Dq | kEbxAvx512Cd | kEbxAvx512Bw | kEbxAvx512Vl;
+ std::uint32_t abcd[4];
+ RunCpuid(7, 0, abcd);
+
+ return (abcd[1] & kEbxAvx512Mask) == kEbxAvx512Mask;
+}
+
+#endif
+} // namespace ruy
diff --git a/ruy/detect_x86.h b/ruy/detect_x86.h
new file mode 100644
index 0000000..fede7c7
--- /dev/null
+++ b/ruy/detect_x86.h
@@ -0,0 +1,49 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_X86_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_X86_H_
+
+#include "ruy/platform.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(X86)
+#if RUY_PLATFORM(X86_ENHANCEMENTS)
+
+// This also checks ABM support, which implies LZCNT and POPCNT.
+bool DetectCpuSse42();
+bool DetectCpuAvx2();
+bool DetectCpuAvx512();
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+// TODO(b/146646451): Introduce and activate.
+inline bool DetectCpuAvxVnni() { return false; }
+
+#else // RUY_PLATFORM(X86_ENHANCEMENTS)
+
+inline bool DetectCpuSse42() { return false; }
+inline bool DetectCpuAvx2() { return false; }
+inline bool DetectCpuAvx512() { return false; }
+inline bool DetectCpuAvxVnni() { return false; }
+
+#endif // !RUY_PLATFORM(X86_ENHANCEMENTS)
+#endif // RUY_PLATFORM(X86)
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_X86_H_
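
A sketch of how these detection routines might be exercised, assuming an x86 build; the program below is hypothetical, not part of this patch, and simply prints what the running CPU reports.

#include <cstdio>

#include "ruy/detect_x86.h"
#include "ruy/platform.h"

int main() {
#if RUY_PLATFORM(X86)
  // Each of these runs CPUID and tests the feature bits listed in
  // detect_x86.cc (e.g. leaf 7 EBX bit 5 for AVX2, leaf 1 ECX bit 12 for FMA).
  std::printf("SSE4.2:  %d\n", ruy::DetectCpuSse42() ? 1 : 0);
  std::printf("AVX2:    %d\n", ruy::DetectCpuAvx2() ? 1 : 0);
  std::printf("AVX-512: %d\n", ruy::DetectCpuAvx512() ? 1 : 0);
#else
  std::printf("not an x86 build\n");
#endif
  return 0;
}
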
diff --git a/ruy/dispatch.h b/ruy/dispatch.h
new file mode 100644
index 0000000..2fd50d0
--- /dev/null
+++ b/ruy/dispatch.h
@@ -0,0 +1,482 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file implements the translation between Ruy's entry point (ruy::Mul) and
+// the internal implementation of matrix multiplication.
+//
+// The primary elements of this dispatch are:
+// - pick suitable gemm kernel and packing routines for the user-specified
+// CompiledPaths based on the current CPU.
+// - decide on the structure of the packed matrices needed by the internal
+// implementation (see pack.h for more information on packing).
+// - translate the Mul operation into TrMul (see trmul.h for why that is
+// useful). This is done by changing the matrix Layout -- no matrix data is
+// actually moved.
+//
+// This file is also factored to serve as a building block for the advanced
+// API.
+//
+// This file also performs some checking of invariants to catch user errors.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_
+
+#include <algorithm>
+#include <cstdint>
+#include <limits> // IWYU pragma: keep
+#include <type_traits>
+
+#include "ruy/check_macros.h"
+#include "ruy/common.h"
+#include "ruy/context.h"
+#include "ruy/internal_matrix.h"
+#include "ruy/kernel.h"
+#include "ruy/kernel_common.h"
+#include "ruy/matrix.h"
+#include "ruy/opt_set.h"
+#include "ruy/pack.h"
+#include "ruy/pack_common.h"
+#include "ruy/path.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/side_pair.h"
+#include "ruy/size_util.h"
+#include "ruy/spec.h"
+#include "ruy/trmul.h"
+#include "ruy/trmul_params.h"
+
+namespace ruy {
+
+// If the Spec's LayoutSupport covers only some special cases,
+// this function enforces that the matrix multiplication at hand falls into
+// that special case.
+template <typename Spec>
+void EnforceLayoutSupport(const Layout& lhs_layout, const Layout& rhs_layout,
+ const Layout& dst_layout) {
+ if (Spec::kLayoutSupport == LayoutSupport::kRCC) {
+ RUY_DCHECK(IsRowMajor(lhs_layout));
+ RUY_DCHECK(IsColMajor(rhs_layout));
+ RUY_DCHECK(IsColMajor(dst_layout));
+ }
+}
+
+template <typename Scalar>
+bool IsSymmetricZeroPoint(Scalar zero_point) {
+ return zero_point == SymmetricZeroPoint<Scalar>();
+}
+
+template <typename Spec, typename Scalar>
+void CheckZeroPoint(Scalar zero_point) {
+ if (std::is_floating_point<Scalar>::value ||
+ Spec::kZeroPointSupport == ZeroPointSupport::kSymmetric) {
+ RUY_DCHECK(IsSymmetricZeroPoint(zero_point));
+ }
+}
+
+template <typename Spec, typename LhsScalar, typename RhsScalar,
+ typename DstScalar>
+void EnforceZeroPointSupport(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point,
+ DstScalar dst_zero_point) {
+ // If the Spec's ZeroPointSupport covers only some special cases,
+ // this function enforces that the matrix multiplication at hand falls into
+ // that special case.
+ CheckZeroPoint<Spec>(lhs_zero_point);
+ CheckZeroPoint<Spec>(rhs_zero_point);
+ CheckZeroPoint<Spec>(dst_zero_point);
+
+ // Guard against the case when both LHS and RHS zero_point's are equal to
+ // the minimum representable value. In that case, padding with zero_point
+ // values will generate the bad case for fast int8 kernels on NEON
+ // (pre-dotprod) which attempt to multiply-accumulate two pairs of int8
+  // into an int16: this is safe except in the bad case -128*-128 + -128*-128,
+  // i.e. 2 * 16384 = 32768, which overflows the int16 maximum of 32767.
+ // See b/131609283. This only affects the kNeon path but we ban this for all
+ // paths in order for ruy to have the same supported parameter space
+ // on all paths.
+ RUY_DCHECK(lhs_zero_point != std::numeric_limits<LhsScalar>::lowest() ||
+ rhs_zero_point != std::numeric_limits<RhsScalar>::lowest());
+}
+
+template <typename Spec, typename DstScalar>
+void EnforceDstSpecSupport(const Spec& spec, DstScalar dst_zero_point) {
+ static_assert(std::is_same<typename Spec::DstScalar, DstScalar>::value, "");
+ if (!std::is_same<typename Spec::DstScalar, std::int32_t>::value) return;
+
+ // If user is looking for the raw accumulator, zero_point and all the other
+ // dequantize fields don't make sense and should not be set.
+ RUY_DCHECK_EQ(dst_zero_point, 0);
+ RUY_DCHECK_EQ(spec.clamp_max, std::numeric_limits<std::int32_t>::max());
+ RUY_DCHECK_EQ(spec.clamp_min, std::numeric_limits<std::int32_t>::min());
+ RUY_DCHECK_EQ(spec.multiplier_fixedpoint, 0);
+ RUY_DCHECK_EQ(spec.multiplier_exponent, 0);
+ RUY_DCHECK_EQ(spec.multiplier_fixedpoint_perchannel, nullptr);
+ RUY_DCHECK_EQ(spec.multiplier_exponent_perchannel, nullptr);
+}
+
+inline bool IsColMajorTrMul(const TrMulParams& params) {
+ return IsColMajor(params.src[Side::kLhs].layout) &&
+ IsColMajor(params.src[Side::kRhs].layout) &&
+ IsColMajor(params.dst.layout);
+}
+
+inline void CreatePackedLayout(const Layout& src, const Type& scalar,
+ const KernelLayout& kernel_layout,
+ PackedLayout* packed) {
+ packed->order = Order::kColMajor;
+ packed->rows = round_up_pot(src.rows, kernel_layout.rows);
+ packed->cols = round_up_pot(src.cols, kernel_layout.cols);
+ packed->kernel = kernel_layout;
+ int inner_size = packed->rows;
+ if (RUY_OPT_ENABLED(RUY_OPT_AVOID_ALIASING)) {
+ packed->stride =
+ (inner_size * scalar.size) % 1024 ? inner_size : inner_size + 64;
+ } else {
+ packed->stride = inner_size;
+ }
+}
+
+template <typename Scalar, typename PackedScalar>
+void CreatePackedMatrix(Side side, const KernelLayout& kernel_layout,
+ TrMulParams* params) {
+ // Ruy always uses 32-bit signed accumulators for quantized
+  // matrix multiplication, so we would like to use std::int32_t
+ // unconditionally for SumsType.
+ // However, for floating point types, we still need a reasonable type here to
+ // avoid tripping assertions elsewhere in the code.
+ using SumsType =
+ typename std::conditional<std::is_floating_point<Scalar>::value, Scalar,
+ std::int32_t>::type;
+
+ const DMatrix& src = params->src[side];
+ PMatrix* packed = &params->packed[side];
+ packed->data_type = Type::Create<PackedScalar>();
+ packed->sums_type = Type::Create<SumsType>();
+ CreatePackedLayout(src.layout, packed->data_type, kernel_layout,
+ &packed->layout);
+ packed->zero_point = Pack<PackedScalar, Scalar>(src.zero_point);
+}
+
+template <Path ThePath, typename LhsScalar, typename RhsScalar,
+ typename DstScalar, typename Spec>
+void PopulateTrMulParams(TrMulParams* params) {
+ static_assert((ThePath & Path::kReference) == Path::kNone,
+ "Path::kReference should not do TrMul");
+ // The optimized code paths don't handle the full generality of Ruy's API.
+ // Fall back to Path::kStandardCpp if necessary.
+ bool fallback_to_standard_cpp = false;
+ if (ThePath != Path::kStandardCpp) {
+ // The optimized code paths currently only handle the case of all matrices
+ // being column major.
+ if (!IsColMajorTrMul(*params)) {
+ fallback_to_standard_cpp = true;
+ }
+ }
+
+ if (fallback_to_standard_cpp) {
+ PopulateTrMulParams<Path::kStandardCpp, LhsScalar, RhsScalar, DstScalar,
+ Spec>(params);
+ return;
+ }
+
+ using PackedLhsScalar = PackedType<ThePath, LhsScalar>;
+ using PackedRhsScalar = PackedType<ThePath, RhsScalar>;
+ using Kernel =
+ Kernel<ThePath, PackedLhsScalar, PackedRhsScalar, DstScalar, Spec>;
+ using LhsKernelLayout = typename Kernel::LhsLayout;
+ using RhsKernelLayout = typename Kernel::RhsLayout;
+
+ params->path = ThePath;
+
+ params->local_data_cache_size = Spec::local_data_cache_size();
+ params->shared_data_cache_size = Spec::shared_data_cache_size();
+
+ CreatePackedMatrix<LhsScalar, PackedLhsScalar>(
+ Side::kLhs, ToKernelLayout<LhsKernelLayout>(), params);
+ CreatePackedMatrix<RhsScalar, PackedRhsScalar>(
+ Side::kRhs, ToKernelLayout<RhsKernelLayout>(), params);
+ params->run_pack[Side::kLhs] =
+ &RunPack<ThePath, LhsKernelLayout, LhsScalar, PackedLhsScalar>;
+ params->run_pack[Side::kRhs] =
+ &RunPack<ThePath, RhsKernelLayout, RhsScalar, PackedRhsScalar>;
+ params->run_kernel =
+ &RunKernel<ThePath, PackedLhsScalar, PackedRhsScalar, DstScalar, Spec>;
+
+ return;
+}
+
+// PopulateTrMulParamsAllCompiledPaths calls into one of multiple
+// instantiations of PopulateTrMulParams. For each bit that is set in
+// CompiledPaths, it statically instantiates PopulateTrMulParams with a Path
+// corresponding to that single bit. The call to PopulateTrMulParams is
+// guarded by a runtime check that it is in fact the dynamically selected path.
+//
+// PopulateTrMulParamsAllCompiledPaths is implemented with template
+// metaprogramming by mutual recursion between PathSearchCountdown and
+// PathSearchCompiledPaths.
+//
+// PopulateTrMulParamsAllCompiledPaths is logically implementing the following
+// computation:
+//
+// template <Path CompiledPaths>
+// void PopulateTrMulParamsAllCompiledPaths(Path the_path,
+// TrMulParams* params) {
+// for (int bit = 8 * sizeof(Path) - 1; bit != -1; bit--) { // [1]
+// Path current_path = static_cast<Path>(1 << bit);
+// if ((CompiledPaths & current_path) != Path::kNone) { // [2]
+// if (current_path == the_path) { // [3]
+// PopulateTrMulParams<current_path, ...>(the_path, params);
+// return;
+// }
+// }
+// }
+// }
+//
+// [1] - Done by the main definition of PathSearchCountdown. The `bit--` is
+// done in the recursion of PathSearchOnlyCompiledPaths.
+// [2] - Done by PathSearchOnlyCompiledPaths's partial template
+// specialization on InCompiledPaths. This is the check which necessitates
+// doing the whole computation at C++ compile time.
+// [3] - Done by the `if` in the main definition of
+// PathSearchOnlyCompiledPaths.
+//
+// The template metaprogramming is necessary because:
+// - In `PopulateTrMulParams<current_path, ...>`, current_path must be a C++
+// compile-time constant.
+// - PopulateTrMulParamsAllCompiledPaths must not instantiate
+// inner loops for paths that are not in CompiledPaths, since that can result in
+// bogus instantiations which cause a compile time failure.
+template <Path CompiledPaths, int BitNumber, typename LhsScalar,
+ typename RhsScalar, typename DstScalar, typename Spec>
+struct PathSearchCountdown;
+
+template <Path CompiledPaths, bool InCompiledPaths, int BitNumber,
+ typename LhsScalar, typename RhsScalar, typename DstScalar,
+ typename Spec>
+struct PathSearchOnlyCompiledPaths {
+ static constexpr Path kCurrentPath = static_cast<Path>(1 << BitNumber);
+ static void Search(Path the_path, TrMulParams* params) {
+ if (kCurrentPath == the_path) {
+ PopulateTrMulParams<kCurrentPath, LhsScalar, RhsScalar, DstScalar, Spec>(
+ params);
+ return;
+ }
+ PathSearchCountdown<CompiledPaths, BitNumber - 1, LhsScalar, RhsScalar,
+ DstScalar, Spec>::Search(the_path, params);
+ }
+};
+
+// Skip this iteration if CompiledPaths doesn't contain the specified path.
+template <Path CompiledPaths, int BitNumber, typename LhsScalar,
+ typename RhsScalar, typename DstScalar, typename Spec>
+struct PathSearchOnlyCompiledPaths<CompiledPaths, false, BitNumber, LhsScalar,
+ RhsScalar, DstScalar, Spec> {
+ static void Search(Path the_path, TrMulParams* params) {
+ PathSearchCountdown<CompiledPaths, BitNumber - 1, LhsScalar, RhsScalar,
+ DstScalar, Spec>::Search(the_path, params);
+ }
+};
+
+template <Path CompiledPaths, int BitNumber, typename LhsScalar,
+ typename RhsScalar, typename DstScalar, typename Spec>
+struct PathSearchCountdown {
+ static constexpr Path kCurrentPath = static_cast<Path>(1 << BitNumber);
+ static void Search(Path the_path, TrMulParams* params) {
+ PathSearchOnlyCompiledPaths<
+ CompiledPaths, (CompiledPaths & kCurrentPath) != Path::kNone, BitNumber,
+ LhsScalar, RhsScalar, DstScalar, Spec>::Search(the_path, params);
+ }
+};
+
+// Termination of the countdown. If the counter reaches -1, then we haven't
+// found the specified path.
+template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
+ typename DstScalar, typename Spec>
+struct PathSearchCountdown<CompiledPaths, -1, LhsScalar, RhsScalar, DstScalar,
+ Spec> {
+ static void Search(Path the_path, TrMulParams* params) { RUY_DCHECK(false); }
+};
+
+template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
+ typename DstScalar, typename Spec>
+void PopulateTrMulParamsAllCompiledPaths(Path the_path, TrMulParams* params) {
+ return PathSearchCountdown<CompiledPaths, 8 * sizeof(Path) - 1, LhsScalar,
+ RhsScalar, DstScalar, Spec>::Search(the_path,
+ params);
+}
+
+template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
+ typename DstScalar, typename Spec>
+void CreateTrMulParams(const Matrix<LhsScalar>& lhs,
+ const Matrix<RhsScalar>& rhs, const Spec& spec,
+ Context* context, Matrix<DstScalar>* dst, Path the_path,
+ TrMulParams* params) {
+ // Fill in the fields we already know.
+ params->src[Side::kLhs] = ToDMatrix(lhs);
+ params->src[Side::kRhs] = ToDMatrix(rhs);
+ params->dst = ToDMatrix(*dst);
+ params->spec = ToVoidPtr(&spec);
+
+ // Create inner loops and packed matrices based on the Path.
+ PopulateTrMulParamsAllCompiledPaths<CompiledPaths, LhsScalar, RhsScalar,
+ DstScalar, Spec>(the_path, params);
+}
+
+template <typename LhsScalar, typename RhsScalar, typename DstScalar,
+ typename Spec>
+void ReferenceMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+ const Spec& spec, Matrix<DstScalar>* dst) {
+ profiler::ScopeLabel label("ReferenceMul");
+ for (int i = 0; i < lhs.layout.rows; i++) {
+ for (int j = 0; j < rhs.layout.cols; j++) {
+ using AccumScalar = typename Spec::AccumScalar;
+ AccumScalar accum = 0;
+ for (int k = 0; k < lhs.layout.cols; k++) {
+ AccumScalar lhs_val = Element(lhs, i, k);
+ AccumScalar rhs_val = Element(rhs, k, j);
+ accum += (lhs_val - lhs.zero_point) * (rhs_val - rhs.zero_point);
+ }
+ if (spec.bias) {
+ accum += spec.bias[i];
+ }
+ ApplyMultiplier(spec, i, &accum);
+ accum += dst->zero_point;
+ accum = std::min<AccumScalar>(accum, spec.clamp_max);
+ accum = std::max<AccumScalar>(accum, spec.clamp_min);
+ *ElementPtr(dst, i, j) = static_cast<DstScalar>(accum);
+ }
+ }
+}
+
+// Compile-time dispatch to ReferenceMul. This allows us to statically ensure
+// that there is no call to ReferenceMul in the user's binary.
+template <bool ReferenceMulIsEnabled>
+struct CompileTimeEnabledReferenceMul {
+ template <typename LhsScalar, typename RhsScalar, typename DstScalar,
+ typename Spec>
+ static void Run(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+ const Spec& spec, Matrix<DstScalar>* dst) {
+ ReferenceMul(lhs, rhs, spec, dst);
+ }
+};
+
+// When this partial specialization is chosen, it ensures that ReferenceMul
+// is never compiled.
+template <>
+struct CompileTimeEnabledReferenceMul</*ReferenceMulIsEnabled=*/false> {
+ template <typename LhsScalar, typename RhsScalar, typename DstScalar,
+ typename Spec>
+ static void Run(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+ const Spec& spec, Matrix<DstScalar>* dst) {
+ RUY_DCHECK(false);
+ }
+};
+
+inline void HandlePrepackedCaching(TrMulParams* params,
+ const SidePair<bool>& cacheable,
+ Context* context) {
+ if (context->cache_policy == CachePolicy::kNoCache) {
+ return;
+ }
+
+ if (context->cache_policy == CachePolicy::kCacheLHSOnNarrowMul) {
+ // TODO(b/149304278) Cache on dst.cols <= selected kernel width.
+ if (!cacheable[Side::kLhs] || params->dst.layout.cols > 4) {
+ return;
+ }
+ PrepackedCache* prepacked_cache = context->GetPrepackedCache();
+ auto cache_key = std::make_pair(reinterpret_cast<void*>(params->run_kernel),
+ params->src[Side::kLhs].data);
+ auto it = prepacked_cache->FindAndUpdate(cache_key);
+ if (it != prepacked_cache->cend()) {
+ params->packed[Side::kLhs].data = it->second.first.data;
+ params->packed[Side::kLhs].sums = it->second.first.sums;
+ params->is_prepacked[Side::kLhs] = true;
+ return;
+ }
+
+ // Allocate the prepacked matrix.
+ PrepackedMatrix prepacked_lhs;
+ prepacked_lhs.data_size = DataSize(params->packed[Side::kLhs]);
+ prepacked_lhs.sums_size = SumsSize(params->packed[Side::kLhs]);
+ prepacked_cache->AllocatePrepackedMatrix(&prepacked_lhs);
+ params->packed[Side::kLhs].data = prepacked_lhs.data;
+ params->packed[Side::kLhs].sums = prepacked_lhs.sums;
+ params->is_prepacked[Side::kLhs] = true;
+ Tuning tuning = context->GetMainThreadTuning();
+ params->RunPack(Side::kLhs, tuning, 0,
+ params->packed[Side::kLhs].layout.cols);
+ prepacked_cache->Insert(cache_key, prepacked_lhs);
+ return;
+ }
+}
+
+template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
+ typename DstScalar, typename Spec>
+void DispatchMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+ const Spec& spec, Context* context, Matrix<DstScalar>* dst) {
+ static_assert(CompiledPaths != Path::kNone, "Must compile at least one Path");
+ static_assert((CompiledPaths & ~kAllPaths) == Path::kNone,
+ "CompiledPaths must be a subset of ruy::kAllPaths");
+
+ profiler::ScopeLabel mul_label("Mul");
+ profiler::ScopeLabel shape_specific_label("matmul shape: %dx%dx%d",
+ lhs.layout.rows, lhs.layout.cols,
+ rhs.layout.cols);
+
+ EnforceLayoutSupport<Spec>(lhs.layout, rhs.layout, dst->layout);
+ EnforceZeroPointSupport<Spec>(lhs.zero_point, rhs.zero_point,
+ dst->zero_point);
+ EnforceDstSpecSupport<Spec>(spec, dst->zero_point);
+
+ // This should be a constant, for a given machine and CompiledPaths.
+ // There is a back door to override it for testing, but in production it will
+ // always be the "best" Path. I.e. the one with the newest SIMD instructions
+ // available on the present machine, and avoiding Path::kReference unless
+ // no other path is compiled.
+ //
+ // Unfortunately, it is not a *static* constant, since it depends on runtime
+ // detection of the available SIMD instructions.
+ Path the_path = context->GetPathToTake<CompiledPaths>();
+
+ // Production code should probably never execute Path::kReference.
+ // Path::kReference implements a Mul, not a TrMul like the rest of Ruy, so if
+ // that's what we need to do, then get it out of the way before going down the
+ // TrMul path.
+ if (the_path == Path::kReference) {
+ constexpr bool ReferenceMulIsEnabled =
+ (CompiledPaths & Path::kReference) != Path::kNone;
+ CompileTimeEnabledReferenceMul<ReferenceMulIsEnabled>::Run(lhs, rhs, spec,
+ dst);
+ return;
+ }
+
+ // As described in the comment at the top of this file, Ruy internally
+ // converts Mul into TrMul. We handle that here.
+ //
+ // This is Ruy's main code path.
+ constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference;
+ Matrix<LhsScalar> transposed_lhs(lhs);
+ Transpose(&transposed_lhs);
+ TrMulParams params;
+ CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, spec, context, dst,
+ the_path, &params);
+ SidePair<bool> cacheable(lhs.cacheable, rhs.cacheable);
+ HandlePrepackedCaching(&params, cacheable, context);
+ TrMul(&params, context);
+}
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_
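
To connect this dispatch machinery to the public entry point: the CompiledPaths template argument of ruy::Mul is what PathSearchCountdown iterates over, so narrowing it prunes, at compile time, which pack and kernel routines get instantiated. Below is a minimal sketch, not part of this patch, built on a hypothetical MulStandardCppOnly function that mirrors the float example in example.cc but restricts compilation to Path::kStandardCpp; the useful path mask in practice depends on the target platform.

#include <iostream>

#include "ruy/ruy.h"

// Restricting CompiledPaths prunes, at compile time, the PathSearchCountdown
// instantiations described in dispatch.h, so only the requested kernels and
// pack routines are emitted for this call site.
void MulStandardCppOnly(ruy::Context* context) {
  const float lhs_data[] = {1, 2, 3, 4};
  const float rhs_data[] = {1, 2, 3, 4};
  float dst_data[4];

  ruy::Matrix<float> lhs;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
  lhs.data = lhs_data;
  ruy::Matrix<float> rhs;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout);
  rhs.data = rhs_data;
  ruy::Matrix<float> dst;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout);
  dst.data = dst_data;

  ruy::BasicSpec<float, float> spec;
  // Only Path::kStandardCpp is compiled in for this Mul call.
  ruy::Mul<ruy::Path::kStandardCpp>(lhs, rhs, spec, context, &dst);
  std::cout << "Result:\n" << dst << "\n";
}
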
diff --git a/ruy/example.cc b/ruy/example.cc
new file mode 100644
index 0000000..3b42c97
--- /dev/null
+++ b/ruy/example.cc
@@ -0,0 +1,136 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <iostream>
+
+#include "ruy/ruy.h"
+
+void ExampleMulFloat(ruy::Context *context) {
+ const float lhs_data[] = {1, 2, 3, 4};
+ const float rhs_data[] = {1, 2, 3, 4};
+ float dst_data[4];
+
+ ruy::Matrix<float> lhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
+ lhs.data = lhs_data;
+ ruy::Matrix<float> rhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout);
+ rhs.data = rhs_data;
+ ruy::Matrix<float> dst;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout);
+ dst.data = dst_data;
+
+ ruy::BasicSpec<float, float> spec;
+ ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, context, &dst);
+
+ std::cout << "Example Mul, float:\n";
+ std::cout << "LHS:\n" << lhs;
+ std::cout << "RHS:\n" << rhs;
+ std::cout << "Result:\n" << dst << "\n";
+}
+
+void ExampleMulFloatWithBiasAddAndClamp(ruy::Context *context) {
+ const float lhs_data[] = {1, 2, 3, 4};
+ const float rhs_data[] = {1, 2, 3, 4};
+ const float bias_data[] = {1, 0};
+ float dst_data[4];
+
+ ruy::Matrix<float> lhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
+ lhs.data = lhs_data;
+ ruy::Matrix<float> rhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout);
+ rhs.data = rhs_data;
+ ruy::Matrix<float> dst;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout);
+ dst.data = dst_data;
+
+ ruy::BasicSpec<float, float> spec;
+ spec.bias = bias_data;
+ spec.clamp_min = 0;
+ spec.clamp_max = 15;
+ ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, context, &dst);
+
+ std::cout << "Example Mul, float with bias addition and clamp:\n";
+ std::cout << "LHS:\n" << lhs;
+ std::cout << "RHS:\n" << rhs;
+ std::cout << "Result:\n" << dst << "\n";
+}
+
+void ExampleMulUint8AsymmetricQuantized(ruy::Context *context) {
+ const std::uint8_t lhs_data[] = {124, 125, 126, 127};
+ const std::uint8_t rhs_data[] = {129, 130, 131, 132};
+ std::uint8_t dst_data[4];
+
+ ruy::Matrix<std::uint8_t> lhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
+ lhs.data = lhs_data;
+ lhs.zero_point = 125;
+ ruy::Matrix<std::uint8_t> rhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout);
+ rhs.data = rhs_data;
+ rhs.zero_point = 132;
+ ruy::Matrix<std::uint8_t> dst;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout);
+ dst.data = dst_data;
+ dst.zero_point = 129;
+
+ ruy::BasicSpec<std::int32_t, std::uint8_t> spec;
+  spec.multiplier_fixedpoint = 1 << 30;
+  spec.multiplier_exponent = 0;
+ ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, context, &dst);
+
+ std::cout << "Example Mul, uint8 quantized with asymmetric zero points:\n";
+ std::cout << "LHS:\n" << lhs;
+ std::cout << "RHS:\n" << rhs;
+ std::cout << "Result:\n" << dst << "\n";
+}
+
+void ExampleMulInt8PerChannelQuantized(ruy::Context *context) {
+ const std::int8_t lhs_data[] = {1, 2, 3, 4};
+ const std::int8_t rhs_data[] = {1, 2, 3, 4};
+ const std::int32_t multiplier_data[] = {3 << 28, 5 << 28};
+ const int exponent_data[] = {1, -2};
+ std::int8_t dst_data[4];
+
+ ruy::Matrix<std::int8_t> lhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
+ lhs.data = lhs_data;
+ ruy::Matrix<std::int8_t> rhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout);
+ rhs.data = rhs_data;
+ ruy::Matrix<std::int8_t> dst;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout);
+ dst.data = dst_data;
+
+ ruy::BasicSpec<std::int32_t, std::int8_t> spec;
+ spec.multiplier_fixedpoint_perchannel = multiplier_data;
+ spec.multiplier_exponent_perchannel = exponent_data;
+ ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, context, &dst);
+
+ std::cout << "Example Mul, int8 quantized with per-channel multipliers\n";
+ std::cout << "LHS:\n" << lhs;
+ std::cout << "RHS:\n" << rhs;
+ std::cout << "Result:\n" << dst << "\n";
+}
+
+int main() {
+ ruy::Context context;
+ ExampleMulFloat(&context);
+ ExampleMulFloatWithBiasAddAndClamp(&context);
+ ExampleMulUint8AsymmetricQuantized(&context);
+ ExampleMulInt8PerChannelQuantized(&context);
+}
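
A side note on the asymmetric-quantized example above: spec.multiplier_fixedpoint = 1 << 30 with multiplier_exponent = 0 determines the output scale. Assuming the usual Q31 fixed-point convention for such multipliers (value / 2^31, further scaled by 2^exponent), that corresponds to an effective factor of 0.5. The sketch below, not part of this patch and using a hypothetical EffectiveScale helper, just makes that arithmetic explicit.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical helper: interprets a fixed-point multiplier plus exponent as a
// real-valued scale, under the assumed Q31 convention described above.
double EffectiveScale(std::int32_t multiplier_fixedpoint, int exponent) {
  return (static_cast<double>(multiplier_fixedpoint) /
          static_cast<double>(1LL << 31)) *
         std::ldexp(1.0, exponent);
}

int main() {
  // (1 << 30) / 2^31 * 2^0 == 0.5
  std::printf("effective scale: %f\n", EffectiveScale(1 << 30, 0));
  return 0;
}
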
diff --git a/ruy/example_advanced.cc b/ruy/example_advanced.cc
new file mode 100644
index 0000000..9041bdb
--- /dev/null
+++ b/ruy/example_advanced.cc
@@ -0,0 +1,83 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstddef>
+#include <iostream>
+#include <memory>
+#include <vector>
+
+#include "ruy/ruy_advanced.h"
+
+// Simple allocator for allocating pre-packed matrices.
+class SimpleAllocator {
+ public:
+ void* AllocateBytes(std::size_t num_bytes) {
+ char* p = new char[num_bytes];
+ buffers_.emplace_back(p);
+ return static_cast<void*>(p);
+ }
+
+ private:
+ std::vector<std::unique_ptr<char[]>> buffers_;
+};
+
+void ExamplePrepack(ruy::Context* context) {
+ const float lhs_data[] = {1, 2, 3, 4};
+ const float rhs_data[] = {1, 2, 3, 4};
+ float dst_data[4];
+
+ // Set up the matrix layouts and spec.
+ ruy::Matrix<float> lhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
+ ruy::Matrix<float> rhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout);
+ ruy::Matrix<float> dst;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout);
+ ruy::BasicSpec<float, float> spec;
+
+ SimpleAllocator allocator;
+ auto alloc_fn = [&allocator](std::size_t num_bytes) -> void* {
+ return allocator.AllocateBytes(num_bytes);
+ };
+
+ // In this example, we pre-pack only the RHS, but either will work.
+ // Note that we only need to set the data pointer for the matrix we are
+ // pre-packing.
+ ruy::PrepackedMatrix prepacked_rhs;
+ rhs.data = rhs_data;
+ ruy::PrePackForMul<ruy::kAllPaths>(lhs, rhs, spec, context, &dst,
+ /*prepacked_lhs=*/nullptr, &prepacked_rhs,
+ alloc_fn);
+
+ // No data will be read from the RHS input matrix when using a pre-packed RHS.
+ rhs.data = nullptr;
+ lhs.data = lhs_data;
+ dst.data = dst_data;
+ ruy::MulWithPrepacked<ruy::kAllPaths>(lhs, rhs, spec, context, &dst,
+ /*prepacked_lhs=*/nullptr,
+ &prepacked_rhs);
+ rhs.data = rhs_data;
+
+ // Print out the results.
+ std::cout << "Example Mul with pre-packing RHS, float:\n";
+ std::cout << "LHS:\n" << lhs;
+ std::cout << "RHS:\n" << rhs;
+ std::cout << "Result:\n" << dst << "\n";
+}
+
+int main() {
+ ruy::Context context;
+ ExamplePrepack(&context);
+}
diff --git a/ruy/have_built_path_for.h b/ruy/have_built_path_for.h
new file mode 100644
index 0000000..8913965
--- /dev/null
+++ b/ruy/have_built_path_for.h
@@ -0,0 +1,32 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_HAVE_BUILT_PATH_FOR_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_HAVE_BUILT_PATH_FOR_H_
+
+#include "ruy/platform.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(X86)
+bool HaveBuiltPathForSse42();
+bool HaveBuiltPathForAvx2();
+bool HaveBuiltPathForAvx512();
+bool HaveBuiltPathForAvxVnni();
+#endif // RUY_PLATFORM(X86)
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_HAVE_BUILT_PATH_FOR_H_
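
These predicates answer a different question from the detect_x86.h routines: whether the corresponding kernels were compiled into the binary at all, as opposed to whether the running CPU supports them. Here is a sketch, not part of this patch, of combining the two; it mirrors, but is not necessarily identical to, the logic ruy's Context applies when choosing a Path.

#include <cstdio>

#include "ruy/detect_x86.h"
#include "ruy/have_built_path_for.h"
#include "ruy/platform.h"

int main() {
#if RUY_PLATFORM(X86)
  // A path is only usable if its kernels were built in AND the CPU has the
  // required instructions.
  const bool avx2_usable = ruy::HaveBuiltPathForAvx2() && ruy::DetectCpuAvx2();
  std::printf("AVX2 path usable: %d\n", avx2_usable ? 1 : 0);
#endif
  return 0;
}
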
diff --git a/ruy/have_built_path_for_avx2.cc b/ruy/have_built_path_for_avx2.cc
new file mode 100644
index 0000000..ceca8a4
--- /dev/null
+++ b/ruy/have_built_path_for_avx2.cc
@@ -0,0 +1,35 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/have_built_path_for.h"
+#include "ruy/opt_set.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(X86)
+// IMPORTANT:
+// These patterns must match those in the pack and kernel cc files.
+#if !(RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+bool HaveBuiltPathForAvx2() { return false; }
+
+#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+bool HaveBuiltPathForAvx2() { return true; }
+
+#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+#endif // RUY_PLATFORM(X86)
+
+} // namespace ruy
diff --git a/ruy/have_built_path_for_avx512.cc b/ruy/have_built_path_for_avx512.cc
new file mode 100644
index 0000000..15fba62
--- /dev/null
+++ b/ruy/have_built_path_for_avx512.cc
@@ -0,0 +1,35 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/have_built_path_for.h"
+#include "ruy/opt_set.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(X86)
+// IMPORTANT:
+// These patterns must match those in the pack and kernel cc files.
+#if !(RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+bool HaveBuiltPathForAvx512() { return false; }
+
+#else // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+bool HaveBuiltPathForAvx512() { return true; }
+
+#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+#endif // RUY_PLATFORM(X86)
+
+} // namespace ruy
diff --git a/ruy/have_built_path_for_avxvnni.cc b/ruy/have_built_path_for_avxvnni.cc
new file mode 100644
index 0000000..68ef2a2
--- /dev/null
+++ b/ruy/have_built_path_for_avxvnni.cc
@@ -0,0 +1,39 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/have_built_path_for.h"
+#include "ruy/opt_set.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(X86)
+// IMPORTANT:
+// These patterns must match those in the pack and kernel cc files.
+#if !(RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+bool HaveBuiltPathForAvxVnni() { return false; }
+
+#else // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+bool HaveBuiltPathForAvxVnni() { return true; }
+
+#endif // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+#endif // RUY_PLATFORM(X86)
+
+} // namespace ruy
diff --git a/ruy/have_built_path_for_sse42.cc b/ruy/have_built_path_for_sse42.cc
new file mode 100644
index 0000000..2141b75
--- /dev/null
+++ b/ruy/have_built_path_for_sse42.cc
@@ -0,0 +1,39 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/have_built_path_for.h"
+#include "ruy/opt_set.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(X86)
+// IMPORTANT:
+// These patterns must match those in the pack and kernel cc files.
+#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+bool HaveBuiltPathForSse42() { return false; }
+
+#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+bool HaveBuiltPathForSse42() { return true; }
+
+#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+#endif // RUY_PLATFORM(X86)
+
+} // namespace ruy
diff --git a/ruy/internal_matrix.h b/ruy/internal_matrix.h
new file mode 100644
index 0000000..7fe13be
--- /dev/null
+++ b/ruy/internal_matrix.h
@@ -0,0 +1,388 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Internal types and helpers for matrices.
+//
+// Ruy has a couple slightly different notions of matrices, besides the
+// Matrix<T> class that we expose to the user-facing API.
+//
+// TODO(silvasean): Put parts of this architecture description somewhere more
+// prominent.
+//
+// The 4 main matrix types are:
+// - Matrix<T>: This is a user-facing type on Ruy's external API boundary. It is
+// also used internally.
+// - DMatrix: This is a type-erased version of Matrix<T>. "D" = "dynamic".
+// - PMatrix: This represents a packed matrix, which requires tracking kernel
+// layout and row/column sums for quantization. It is type-erased.
+// - PackedMatrix<T>: This is a statically typed variant of PMatrix for
+// convenience inside typed routines.
+//
+// Note that Matrix<T> is *not* implemented in terms of the internal types. It
+// is an independent, simple, and user-facing type.
+//
+// The use of type-erasure might seem surprising for a library like Ruy with a
+// heavily-templated entry point, but it is motivated by the desire for most of
+// Ruy's "middle-end" to be non-templated. Ruy can be thought of as having 3
+// main parts:
+// - "front-end" (dispatch.h) - this is the highly templated ruy::Mul entry
+// point, along with routines that select RunKernel and RunPack implementations
+// statically based on those template parameters.
+// - "back-end" (kernel.h, pack.h)- this consists of the implementations of
+// RunKernel and RunPack, often in assembly code, which are the building blocks
+// that Ruy calls to perform matrix multiplication. These are templated so that
+// only the requested types/Path's are actually emitted by the compiler.
+// - "middle-end" (trmul.h) - this is the part of Ruy that orchestrates the
+// calls to the "back-end" optimized building blocks. This layer has to deal
+// with issues like cache locality and low-overhead multi-threading.
+//
+// There is a desire for the "middle-end" to be non-templated in order to
+// simplify the implementation and reduce code-size. We type-erase when going
+// from the "front-end" to the "middle-end", and un-type-erase going from the
+// "middle-end" to the "back-end". The un-type-erasure is possible because the
+// "front-end" is responsible for instantiating the needed "back-end" templates,
+// and thus the static type information is still present.
+//
+// Each layer of Ruy uses matrix types:
+// - "front-end": Matrix<T>
+// - "middle-end": DMatrix, PMatrix
+// - "back-end": Matrix<T>, PackedMatrix<T>
+//
+// The use of separate types for packed matrices is not essential, but makes it
+// obvious at a glance whether a matrix is a packed matrix or not. We would
+// reconsider this decision if there was significant duplication between packed
+// and unpacked matrices, but that doesn't seem to be the case at the moment.
+//
+// Another goal is to keep the user-facing Matrix<T> as simple and
+// understandable as possible. Ideally, a user should be able to read the struct
+// definition for Matrix<T> and see a very simple definition with no internal
+// details like sums and kernel block layout.
+//
+// To present another structured view of our various matrix types, here's a
+// table:
+// Plain matrices Packed matrices
+// +----------------------------------
+// Templated | Matrix<T> PackedMatrix<T>
+// Type-erased | DMatrix PMatrix
+//
+//
+// There is 1 additional matrix type not mentioned above, due to its low
+// importance:
+// - PrepackedMatrix: This is a user-facing version of PMatrix. It has the bare
+// minimum of fields needed for representing the raw data and sums buffers of a
+// packed matrix for the "advanced" explicit pre-packing API. This type plays no
+// role in Ruy's internals and can generally be ignored. The only reason it
+// exists is so that PMatrix is not exposed to users -- we prefer to keep the
+// internal matrix types hidden, even from "advanced" users.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_INTERNAL_MATRIX_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_INTERNAL_MATRIX_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <type_traits>
+#include <utility>
+
+#include "ruy/check_macros.h"
+#include "ruy/common.h"
+#include "ruy/matrix.h"
+#include "ruy/size_util.h"
+
+namespace ruy {
+
+// KernelLayout describes small-scale block structure in a packed matrix layout.
+// It's a runtime (as opposed to compile-time-constant) version of the
+// FixedKernelLayout struct used to declare kernel layouts.
+//
+// This is sometimes known as "tiling" in other contexts.
+//
+// For example, consider a packed matrix in column-major format with a
+// column-major KernelLayout. The matrix logically has a shape of
+// `[cols, rows]`. However, the matrix is laid out as though it were a 4D array
+// of shape `[cols / kcols, rows / krows, kcols, krows]`.
+//
+// Note that in the case of kcols=1, krows=1, this degenerates to
+// `[cols, rows, 1, 1]` which is equivalent to having no small-scale block
+// structure.
+struct KernelLayout {
+ Order order = Order::kColMajor;
+ std::uint8_t rows = 1;
+ std::uint8_t cols = 1;
+};
+
+// A packed matrix has a small-scale block structure that is not present in
+// the input matrices. This block structure is necessary for the kernels to
+// process data efficiently.
+//
+// This struct is very similar to Layout, but has the extra KernelLayout field.
+struct PackedLayout {
+ std::int32_t rows = 0;
+ std::int32_t cols = 0;
+ // Stride is the offset between two adjacent matrix elements
+ // in the non-contiguous direction.
+ std::int32_t stride = 0;
+ Order order = Order::kColMajor;
+ // Small scale layout shuffling, potentially departing from
+ // linear row-major or column-major storage. See KernelLayout.
+ KernelLayout kernel;
+};
+
+// Dynamic representation for a type.
+//
+// The most important field in this struct is the size, which Ruy uses to know
+// how much memory to allocate without having to be templated on a type.
+// Signed-ness and floating-point-ness are mainly present as debugging checks.
+//
+// Note: Ruy does not use this struct to dynamically dispatch between
+// different typed implementations. As described in the comment at the top of
+// this file, Ruy's "front-end", which is templated, instantiates all the
+// necessary "back-end" routines with complete static knowledge of all the
+// types.
+struct Type {
+ template <typename T>
+ static Type Create() {
+ Type ret;
+ ret.is_signed = std::is_signed<T>::value;
+ ret.is_floating_point = std::is_floating_point<T>::value;
+ ret.size = sizeof(T);
+ return ret;
+ }
+
+ template <typename T>
+ void AssertIs() const {
+ RUY_DCHECK_EQ(is_signed, Create<T>().is_signed);
+ RUY_DCHECK_EQ(is_floating_point, Create<T>().is_floating_point);
+ RUY_DCHECK_EQ(size, Create<T>().size);
+ }
+
+ bool is_signed = false;
+ bool is_floating_point = false;
+ std::uint8_t size = 0;
+};
+
+// Type-erased matrix.
+struct DMatrix {
+ Type data_type;
+ void* data = nullptr;
+ Layout layout;
+ std::int32_t zero_point = 0;
+};
+
+// Type-erased packed matrix.
+struct PMatrix {
+ Type data_type;
+ void* data = nullptr;
+ Type sums_type;
+ void* sums = nullptr;
+ PackedLayout layout;
+ std::int32_t zero_point = 0;
+};
+
+// Convenient typed helper for packed matrices.
+template <typename Scalar>
+struct PackedMatrix {
+ // The row/column sums needed for quantized matrix multiplication when
+ // the opposite operand of the multiplication uses a non-symmetric zero
+ // point.
+ // This member is only relevant for packed matrices.
+ // Additionally, Ruy always uses 32-bit signed accumulators for quantized
+ // matrix multiplication.
+ // For floating point types, there is no quantization, so this pointer
+ // will always be null. We still need code referencing it to compile
+ // though, even if it is always branched around. Hence we use Scalar*
+ // itself as the type in that case.
+ using SumsType =
+ typename std::conditional<std::is_floating_point<Scalar>::value, Scalar,
+ std::int32_t>::type;
+
+ Scalar* data = nullptr;
+ SumsType* sums = nullptr;
+ PackedLayout layout;
+ std::int32_t zero_point = 0;
+};
+
+template <typename T>
+DMatrix ToDMatrix(const Matrix<T>& matrix) {
+ DMatrix ret;
+ ret.data_type = Type::Create<T>();
+ ret.data = ToVoidPtr(matrix.data.get());
+ ret.layout = matrix.layout;
+ ret.zero_point = matrix.zero_point;
+ return ret;
+}
+
+template <typename T>
+Matrix<T> ToMatrix(const DMatrix& dmatrix) {
+ dmatrix.data_type.AssertIs<T>();
+ Matrix<T> ret;
+ ret.data = static_cast<T*>(dmatrix.data);
+ ret.layout = dmatrix.layout;
+ ret.zero_point = dmatrix.zero_point;
+ return ret;
+}
+
+template <typename T>
+PackedMatrix<T> ToPackedMatrix(const PMatrix& pmatrix) {
+ using SumsType = typename PackedMatrix<T>::SumsType;
+ pmatrix.data_type.AssertIs<T>();
+ pmatrix.sums_type.AssertIs<SumsType>();
+ PackedMatrix<T> ret;
+ ret.data = static_cast<T*>(pmatrix.data);
+ ret.sums = static_cast<SumsType*>(pmatrix.sums);
+ ret.layout = pmatrix.layout;
+ ret.zero_point = pmatrix.zero_point;
+ return ret;
+}
+
+// Helpers for Layout / PackedLayout.
+
+inline bool IsPacked(const Layout& layout) {
+ if (layout.order == Order::kColMajor) {
+ return layout.stride == layout.rows;
+ } else {
+ return layout.stride == layout.cols;
+ }
+}
+
+inline bool IsRowMajor(const Layout& layout) {
+ return layout.order == Order::kRowMajor;
+}
+
+template <typename LayoutOrPackedLayout>
+inline bool IsColMajor(const LayoutOrPackedLayout& layout) {
+ return layout.order == Order::kColMajor;
+}
+
+template <typename LayoutOrPackedLayout>
+inline int FlatSize(const LayoutOrPackedLayout& layout) {
+ const int outerdim =
+ layout.order == Order::kColMajor ? layout.cols : layout.rows;
+ return layout.stride * outerdim;
+}
+
+// TODO(b/130417400) add a unit test
+inline int Offset(const Layout& layout, int row, int col) {
+  // TODO(benoitjacob) - should check this but this makes the _slow tests take
+ // 5x longer. Find a mitigation like in Eigen with an 'internal' variant
+ // bypassing the check?
+ // RUY_DCHECK_GE(row, 0);
+ // RUY_DCHECK_GE(col, 0);
+ // RUY_DCHECK_LT(row, layout.rows);
+ // RUY_DCHECK_LT(col, layout.cols);
+ int row_stride = layout.order == Order::kColMajor ? 1 : layout.stride;
+ int col_stride = layout.order == Order::kRowMajor ? 1 : layout.stride;
+ return row * row_stride + col * col_stride;
+}
+
+// TODO(b/130417400) add a unit test
+inline int Offset(const PackedLayout& layout, int row, int col) {
+ RUY_DCHECK(is_pot(layout.kernel.rows));
+ RUY_DCHECK(is_pot(layout.kernel.cols));
+ int row_outer = row & ~(layout.kernel.rows - 1);
+ int col_outer = col & ~(layout.kernel.cols - 1);
+ int row_stride_outer =
+ layout.order == Order::kColMajor ? layout.kernel.cols : layout.stride;
+ int col_stride_outer =
+ layout.order == Order::kRowMajor ? layout.kernel.rows : layout.stride;
+ int offset_outer =
+ row_outer * row_stride_outer + col_outer * col_stride_outer;
+ int row_inner = row - row_outer;
+ int col_inner = col - col_outer;
+ int row_stride_inner =
+ layout.kernel.order == Order::kColMajor ? 1 : layout.kernel.cols;
+ int col_stride_inner =
+ layout.kernel.order == Order::kRowMajor ? 1 : layout.kernel.rows;
+ int offset_inner =
+ row_inner * row_stride_inner + col_inner * col_stride_inner;
+ return offset_outer + offset_inner;
+}
+
+// Helpers for Matrix<T>.
+
+template <typename Scalar>
+const Scalar* ElementPtr(const Matrix<Scalar>& mat, int row, int col) {
+ return mat.data.get() + Offset(mat.layout, row, col);
+}
+
+template <typename Scalar>
+Scalar* ElementPtr(Matrix<Scalar>* mat, int row, int col) {
+ return mat->data.get() + Offset(mat->layout, row, col);
+}
+
+template <typename Scalar>
+Scalar Element(const Matrix<Scalar>& mat, int row, int col) {
+ return *ElementPtr(mat, row, col);
+}
+
+// Helpers for PackedMatrix<T>.
+// Duplicated from Matrix<T>, but the duplication seems acceptable.
+
+template <typename Scalar>
+const Scalar* ElementPtr(const PackedMatrix<Scalar>& mat, int row, int col) {
+ return mat.data + Offset(mat.layout, row, col);
+}
+
+template <typename Scalar>
+Scalar* ElementPtr(PackedMatrix<Scalar>* mat, int row, int col) {
+ return mat->data + Offset(mat->layout, row, col);
+}
+
+template <typename Scalar>
+Scalar Element(const PackedMatrix<Scalar>& mat, int row, int col) {
+ return *ElementPtr(mat, row, col);
+}
+
+// Helpers for PMatrix.
+
+inline std::size_t DataSize(const PMatrix& packed) {
+ return FlatSize(packed.layout) * packed.data_type.size;
+}
+
+inline std::size_t SumsSize(const PMatrix& packed) {
+ // Packed matrices are only relevant for Ruy's TrMul implementations. For
+ // TrMul, the number of sums is always equal to the number of columns.
+ return packed.layout.cols * packed.sums_type.size;
+}
+
+// Transpose helpers.
+
+inline void Transpose(Order* order) {
+ *order = *order == Order::kColMajor ? Order::kRowMajor : Order::kColMajor;
+}
+
+inline void Transpose(Layout* layout) {
+ Transpose(&layout->order);
+ std::swap(layout->rows, layout->cols);
+}
+
+template <typename Scalar>
+inline void Transpose(Matrix<Scalar>* matrix) {
+ Transpose(&matrix->layout);
+}
+
+// Helpers for KernelLayout.
+
+template <typename FixedKernelLayout>
+KernelLayout ToKernelLayout() {
+ KernelLayout ret;
+ ret.order = FixedKernelLayout::kOrder;
+ ret.rows = FixedKernelLayout::kRows;
+ ret.cols = FixedKernelLayout::kCols;
+ return ret;
+}
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_INTERNAL_MATRIX_H_
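
To make the PackedLayout addressing above concrete, here is a small standalone sketch, not part of this patch, that builds the 8x8 matrix with column-major 4x4 blocks described in the KernelLayout comment and prints a few offsets; the expected values follow directly from the Offset() overload defined above.

#include <cstdio>

#include "ruy/internal_matrix.h"

int main() {
  // An 8x8 column-major packed matrix tiled into column-major 4x4 blocks,
  // i.e. the `[cols / kcols, rows / krows, kcols, krows]` view described in
  // the KernelLayout comment.
  ruy::PackedLayout layout;
  layout.rows = 8;
  layout.cols = 8;
  layout.stride = 8;
  layout.order = ruy::Order::kColMajor;
  layout.kernel.order = ruy::Order::kColMajor;
  layout.kernel.rows = 4;
  layout.kernel.cols = 4;
  // Whole 4x4 blocks are stored contiguously, in column-major block order:
  // block (0,0) occupies offsets 0..15, then blocks (1,0), (0,1), (1,1).
  std::printf("offset(0,0)=%d offset(4,0)=%d offset(0,4)=%d offset(5,6)=%d\n",
              ruy::Offset(layout, 0, 0), ruy::Offset(layout, 4, 0),
              ruy::Offset(layout, 0, 4), ruy::Offset(layout, 5, 6));
  // Expected under these settings: 0, 16, 32, 57.
  return 0;
}
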
diff --git a/ruy/kernel.h b/ruy/kernel.h
new file mode 100644
index 0000000..d7930b4
--- /dev/null
+++ b/ruy/kernel.h
@@ -0,0 +1,31 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_
+
+#include "ruy/platform.h"
+
+// IWYU pragma: begin_exports
+#if RUY_PLATFORM(NEON)
+#include "ruy/kernel_arm.h"
+#elif RUY_PLATFORM(X86)
+#include "ruy/kernel_x86.h"
+#else
+#include "ruy/kernel_common.h"
+#endif
+// IWYU pragma: end_exports
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_
diff --git a/ruy/kernel_arm.h b/ruy/kernel_arm.h
new file mode 100644
index 0000000..408c23a
--- /dev/null
+++ b/ruy/kernel_arm.h
@@ -0,0 +1,211 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_ARM_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_ARM_H_
+
+#include <cstddef>
+#include <cstdint>
+
+#include "ruy/common.h"
+#include "ruy/internal_matrix.h"
+#include "ruy/kernel_common.h"
+#include "ruy/matrix.h"
+#include "ruy/opt_set.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/side_pair.h"
+#include "ruy/size_util.h"
+#include "ruy/spec.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(NEON) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+#if RUY_PLATFORM(NEON_64)
+void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params);
+void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 4>& params);
+#elif RUY_PLATFORM(NEON_32)
+void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params);
+void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 2>& params);
+#endif
+void Kernel8bitNeonInOrder(const KernelParams8bit<4, 4>& params);
+void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params);
+void Kernel8bitNeonDotprodOutOfOrder1Col(const KernelParams8bit<8, 8>& params);
+void Kernel8bitNeonDotprodInOrder(const KernelParams8bit<8, 8>& params);
+
+#if RUY_PLATFORM(NEON_64)
+template <typename DstScalar>
+struct Kernel<Path::kNeon, std::int8_t, std::int8_t, DstScalar,
+ BasicSpec<std::int32_t, DstScalar>> {
+ using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
+ using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
+ Tuning tuning = Tuning::kAuto;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PackedMatrix<std::int8_t>& lhs,
+ const PackedMatrix<std::int8_t>& rhs,
+ const BasicSpec<std::int32_t, DstScalar>& spec, int start_row,
+ int start_col, int end_row, int end_col,
+ Matrix<DstScalar>* dst) const {
+ KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col,
+ dst, &params);
+ if (dst->layout.cols == 1) {
+ Kernel8bitNeonOutOfOrder1Col(params);
+ return;
+ }
+ if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
+ Kernel8bitNeonInOrder(params);
+ } else {
+ Kernel8bitNeonOutOfOrder(params);
+ }
+ }
+};
+#endif
+
+#if RUY_PLATFORM(NEON_32)
+template <typename DstScalar>
+struct Kernel<Path::kNeon, std::int8_t, std::int8_t, DstScalar,
+ BasicSpec<std::int32_t, DstScalar>> {
+ using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
+ using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 2>;
+ Tuning tuning = Tuning::kAuto;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PackedMatrix<std::int8_t>& lhs,
+ const PackedMatrix<std::int8_t>& rhs,
+ const BasicSpec<std::int32_t, DstScalar>& spec, int start_row,
+ int start_col, int end_row, int end_col,
+ Matrix<DstScalar>* dst) const {
+ KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col,
+ dst, &params);
+ if (dst->layout.cols == 1) {
+ Kernel8bitNeonOutOfOrder1Col(params);
+ return;
+ }
+ Kernel8bitNeonOutOfOrder(params);
+ }
+};
+#endif
+
+#if RUY_PLATFORM(NEON_64)
+template <typename DstScalar>
+struct Kernel<Path::kNeonDotprod, std::int8_t, std::int8_t, DstScalar,
+ BasicSpec<std::int32_t, DstScalar>> {
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+ using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PackedMatrix<std::int8_t>& lhs,
+ const PackedMatrix<std::int8_t>& rhs,
+ const BasicSpec<std::int32_t, DstScalar>& spec, int start_row,
+ int start_col, int end_row, int end_col,
+ Matrix<DstScalar>* dst) const {
+ KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col,
+ dst, &params);
+ if (dst->layout.cols == 1) {
+ Kernel8bitNeonDotprodOutOfOrder1Col(params);
+ } else if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
+ Kernel8bitNeonDotprodInOrder(params);
+ } else {
+ Kernel8bitNeonDotprodOutOfOrder(params);
+ }
+ }
+};
+#endif
+
+void KernelFloatNeonOutOfOrder(const KernelParamsFloat<8, 8>& params);
+void KernelFloatNeonInOrder(const KernelParamsFloat<8, 8>& params);
+void KernelFloat32NeonOutOfOrder(const KernelParamsFloat<8, 4>& params);
+void KernelFloatNeonDotprodInOrder(const KernelParamsFloat<8, 8>& params);
+
+#if RUY_PLATFORM(NEON_64)
+// A Float kernel for ARM64 Neon.
+template <>
+struct Kernel<Path::kNeon, float, float, float, BasicSpec<float, float>> {
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+ const BasicSpec<float, float>& spec, int start_row, int start_col,
+ int end_row, int end_col, Matrix<float>* dst) const {
+ KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
+ end_col, dst, &params);
+ if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
+ KernelFloatNeonInOrder(params);
+ } else {
+ KernelFloatNeonOutOfOrder(params);
+ }
+ }
+};
+#endif
+
+#if RUY_PLATFORM(NEON_32)
+// A Float kernel for ARM32 Neon.
+template <>
+struct Kernel<Path::kNeon, float, float, float, BasicSpec<float, float>> {
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 4>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+ const BasicSpec<float, float>& spec, int start_row, int start_col,
+ int end_row, int end_col, Matrix<float>* dst) const {
+ KernelParamsFloat<8, 4> params;
+
+ MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
+ end_col, dst, &params);
+
+ KernelFloat32NeonOutOfOrder(params);
+ }
+};
+#endif
+
+// While the dotprod NEON extension does not concern floating-point arithmetic,
+// its presence allows us to distinguish, in the in-order tuning case, between
+// A53 and A55r1. TODO: should this be folded into tuning?
+template <>
+struct Kernel<Path::kNeonDotprod, float, float, float,
+ BasicSpec<float, float>> {
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ using Base =
+ Kernel<Path::kNeon, float, float, float, BasicSpec<float, float>>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+ const BasicSpec<float, float>& spec, int start_row, int start_col,
+ int end_row, int end_col, Matrix<float>* dst) const {
+ KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
+ end_col, dst, &params);
+ if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
+ KernelFloatNeonDotprodInOrder(params);
+ } else {
+ KernelFloatNeonOutOfOrder(params);
+ }
+ }
+};
+
+#endif // RUY_PLATFORM(NEON) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_ARM_H_
diff --git a/ruy/kernel_arm32.cc b/ruy/kernel_arm32.cc
new file mode 100644
index 0000000..d537cfe
--- /dev/null
+++ b/ruy/kernel_arm32.cc
@@ -0,0 +1,2499 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/kernel.h"
+#include "ruy/opt_set.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+#define RUY_ASM_LABEL_STORE_UINT8 91
+#define RUY_ASM_LABEL_STORE_INT8 92
+#define RUY_ASM_LABEL_STORE_INT16 93
+#define RUY_ASM_LABEL_STORE_INT32 94
+#define RUY_ASM_LABEL_AFTER_STORE 99
+
+#define RUY_OFFSET_LHS_BASE_PTR 0
+#define RUY_OFFSET_RHS_BASE_PTR 4
+#define RUY_OFFSET_DST_BASE_PTR 8
+#define RUY_OFFSET_BIAS 12
+#define RUY_OFFSET_START_ROW 16
+#define RUY_OFFSET_START_COL 20
+#define RUY_OFFSET_LAST_ROW 24
+#define RUY_OFFSET_LAST_COL 28
+#define RUY_OFFSET_DST_ROWS 32
+#define RUY_OFFSET_DST_COLS 36
+#define RUY_OFFSET_LHS_STRIDE 40
+#define RUY_OFFSET_RHS_STRIDE 44
+#define RUY_OFFSET_DST_STRIDE 48
+#define RUY_OFFSET_DEPTH 52
+#define RUY_OFFSET_CLAMP_MIN 56
+#define RUY_OFFSET_CLAMP_MAX 60
+#define RUY_OFFSET_FLAGS 64
+
+#define RUY_STACK_OFFSET_SIZE 96
+#define RUY_STACK_OFFSET_DST_COL_PTR 0
+#define RUY_STACK_OFFSET_DST_PTR 16
+#define RUY_STACK_OFFSET_ROW 32
+#define RUY_STACK_OFFSET_COL 48
+#define RUY_STACK_OFFSET_LHS_COL_PTR 64
+#define RUY_STACK_OFFSET_RHS_COL_PTR 80
+
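+// The asm kernels below access their Params struct through the hard-coded
+// byte offsets defined above. These static_asserts keep those offsets in
+// sync with the actual struct layout.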
+template <typename Params>
+void CheckOffsetsInKernelParamsFloat32(const Params&) {
+ static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
+ static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, "");
+ static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, "");
+ static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
+ static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
+ static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, "");
+ static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
+ static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
+ static_assert(offsetof(Params, dst_rows) == RUY_OFFSET_DST_ROWS, "");
+ static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
+ static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
+ static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
+ static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
+ static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
+ static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
+ static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
+}
+
+// Float kernel for ARM32 out-of-order cores.
+// Just like the float ARM64 version, except it accumulates into an 8x4 block
+// so that everything fits in the 16 128-bit NEON registers available on
+// ARM32. This is a "first pass" kernel, not yet tuned. It is meant to run on
+// out-of-order CPUs like the Krait 400 or A9.
+void KernelFloat32NeonOutOfOrder(const KernelParamsFloat<8, 4>& params) {
+ CheckOffsetsInKernelParamsFloat32(params);
+ profiler::ScopeLabel label(
+ "Kernel (kNeon, optimized for out-of-order cores)");
+
+ const float* lhs_ptr = params.lhs_base_ptr;
+ const float* rhs_ptr = params.rhs_base_ptr;
+ // In ARM32 NEON, there are 16 128-bit "q" registers. These registers are
+ // each composed of two 64-bit "d" registers. The asm kernel below has the
+ // following NEON register allocation:
+ // Registers q3 -- q10 are accumulators. During accumulation,
+ // q0 -- q2 (d0 -- d5) are used to load data from LHS and RHS. q0 and q1
+ // are used to load a 8x1 block of LHS, and q2 is used to load a 1x4 block
+ // of RHS, like this:
+
+ // Register layout in "q" registers:
+ // RHS 1x4 block
+ // /--------------------------\
+ // |q2.s[0] ... q2.s[3] |
+ // \--------------------------/
+ // LHS 8x1 block
+ // /---------------------\ /--------------------- \
+ // | q0.s[0] | | q3.s[0] ... q9.s[0] |
+ // | ... | | ... ... |
+ // | q0.s[3] | | q3.s[3] q9.s[3] |
+ // | q1.s[0] | | q4.s[0] q10.s[0] |
+ // | ... | | ... ... ... |
+ // | q1.s[3] | | q4.s[3] .. q10.s[3] |
+ // \---------------------/ \--------------------------/
+ // accumulators 8x4 block
+ // q11, q14, q15 currently unused. q12 and q13 are used to load
+ // parameters used for the post-accumulation part of the kernel.
+ // For completeness, here is the register layout in "d" registers:
+ // RHS 1x4 block
+ // /--------------------------\
+ // |d4[0] ... d5[1] |
+ // \--------------------------/
+ // LHS 8x1 block
+ // /---------------------\ /--------------------------\
+ // | d0[0] | | d6[0] ... d18[0] |
+ // | ... | | ... ... |
+ // | d1[1] | | d7[1] d19[1] |
+ // | d2[0] | | d8[0] d20[0] |
+ // | ... | | ... ... ... |
+ // | d3[1] | | d9[1] ... d21[1] |
+ // \---------------------/ \--------------------------/
+ // accumulators 8x4 block
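+  //
+  // In scalar terms, each iteration of the accumulation loop below performs,
+  // for the current level of depth, roughly the following rank-1 update
+  // (a sketch with illustrative names: lhs holds the 8 LHS values and rhs
+  // the 4 RHS values loaded for that depth level):
+  //
+  //   for (int r = 0; r < 8; r++) {        // LHS 8x1 block (q0, q1)
+  //     for (int c = 0; c < 4; c++) {      // RHS 1x4 block (q2)
+  //       acc[r][c] += lhs[r] * rhs[c];    // accumulators q3 -- q10
+  //     }
+  //   }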
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "vmov.f32 " #reg ", #0.0\n"
+
+ // clang-format off
+
+ // Load the first 32 bytes of LHS and RHS data.
+ // Load q0, q1
+ "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n"
+ "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
+ RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
+ // Load q2
+ "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n"
+ RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
+
+ "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
+ "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
+ "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
+ "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ // Clear accumulators.
+ RUY_MAKE_ZERO(q3)
+ RUY_MAKE_ZERO(q4)
+ RUY_MAKE_ZERO(q5)
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+
+ // r1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 1.
+ "mov r1, #1\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ // Accumulation loop
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+ "cmp r1, r2\n"
+ "beq 79f\n"
+
+ "2:\n"
+
+ "vmla.f32 q3, q0, d4[0]\n"
+ "vmla.f32 q5, q0, d4[1]\n"
+ "vmla.f32 q7, q0, d5[0]\n"
+ "vmla.f32 q9, q0, d5[1]\n"
+ "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS
+
+ "vmla.f32 q4, q1, d4[0]\n"
+ "vmla.f32 q6, q1, d4[1]\n"
+ "vmla.f32 q8, q1, d5[0]\n"
+ "vmla.f32 q10, q1, d5[1]\n"
+ "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
+ RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
+ "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" // Reload RHS
+ RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
+
+ "add r1, r1, #1\n"
+ "cmp r1, r2\n"
+
+ "blt 2b\n"
+
+ "79:\n"
+
+ // End of the inner loop on depth. Now perform the remaining
+ // multiply-adds of the last level of depth, for which the LHS
+ // and RHS data is already loaded.
+
+ "vmla.f32 q3, q0, d4[0]\n"
+ "vmla.f32 q5, q0, d4[1]\n"
+ "vmla.f32 q7, q0, d5[0]\n"
+ "vmla.f32 q9, q0, d5[1]\n"
+
+ "vmla.f32 q4, q1, d4[0]\n"
+ "vmla.f32 q6, q1, d4[1]\n"
+ "vmla.f32 q8, q1, d5[0]\n"
+ "vmla.f32 q10, q1, d5[1]\n"
+
+        // End of accumulation. The registers q3 -- q10 contain the final
+        // float32 accumulator values of the current 8x4 destination block.
+        // We now have to compute the final values from these accumulators
+        // and advance to the next 8x4 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "cmp r1, r3\n" // Have we finished the last row?
+
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ "add r4, r4, r1, lsl #3\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ // Go back to first row
+ "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "cmp r8, r4\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ "add r10, r10, r1, lsl #2\n"
+ "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ "mov %[lhs_ptr], r4\n"
+ "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ "mov %[rhs_ptr], r5\n"
+
+ // Load some parameters needed for the end work on current block.
+ "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+
+ // Offset these base pointers as needed given the current row, col.
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "add r5, r1, r8, lsl #2\n"
+
+ "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "it ne\n"
+ "movne r1, r5\n"
+
+ // Load 8 bias values.
+ "vld1.32 {d24, d25, d26, d27}, [r1]\n"
+
+        // Now that we know what LHS and RHS data the next iteration of the
+        // main loop will need to load, we start loading the first 32 bytes of
+        // LHS (into q0, q1) and the first 16 bytes of RHS (into q2), as we
+        // don't need q0 -- q2 anymore in the rest of the work on the current
+        // block.
+ // Load q0, q1
+ "vld1.32 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n"
+ RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
+ // Load q2
+ "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n"
+ RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
+
+        // Perform the bias-addition.
+ "vadd.f32 q3, q3, q12\n"
+ "vadd.f32 q4, q4, q13\n"
+ "vadd.f32 q5, q5, q12\n"
+ "vadd.f32 q6, q6, q13\n"
+ "vadd.f32 q7, q7, q12\n"
+ "vadd.f32 q8, q8, q13\n"
+ "vadd.f32 q9, q9, q12\n"
+ "vadd.f32 q10, q10, q13\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "vdup.32 q12, r2\n" // clamp_min
+ "vdup.32 q13, r3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "vmax.f32 q3, q3, q12\n"
+ "vmax.f32 q4, q4, q12\n"
+ "vmax.f32 q5, q5, q12\n"
+ "vmax.f32 q6, q6, q12\n"
+ "vmax.f32 q7, q7, q12\n"
+ "vmax.f32 q8, q8, q12\n"
+ "vmax.f32 q9, q9, q12\n"
+ "vmax.f32 q10, q10, q12\n"
+
+ // Apply the clamp_max bound
+ "vmin.f32 q3, q3, q13\n"
+ "vmin.f32 q4, q4, q13\n"
+ "vmin.f32 q5, q5, q13\n"
+ "vmin.f32 q6, q6, q13\n"
+ "vmin.f32 q7, q7, q13\n"
+ "vmin.f32 q8, q8, q13\n"
+ "vmin.f32 q9, q9, q13\n"
+ "vmin.f32 q10, q10, q13\n"
+
+ // Compute how much of the 8x4 block of destination values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+        // of 8x4, there are some 8x4 blocks along the boundaries that do
+ // not fit entirely.
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "sub r1, r1, r8\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "sub r2, r2, r4\n"
+ "mov r3, #8\n"
+ "mov r5, #4\n"
+ "cmp r1, #8\n"
+ // Compute r1 = how many rows of the 8x4 block fit
+ "it gt\n"
+ "movgt r1, r3\n"
+ "cmp r2, #4\n"
+ // Compute r2 = how many cols of the 8x4 block fit
+ "it gt\n"
+ "movgt r2, r5\n"
+
+ // Test if r1==8 && r2 == 4, i.e. if all of the 8x4 block fits.
+ "cmp r1, r3\n"
+ "it eq\n"
+ "cmpeq r2, r5\n"
+ // Yes, all of the 8x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 8x4 block fits.
+ // Set (r3 address, r4 stride) to write to dst_tmp_buf
+ "mov r3, %[dst_tmp_buf]\n"
+ "mov r4, #32\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 8x4 block fits.
+ // Set (r3 address, r4 stride) to write directly to destination matrix.
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r4, r5\n"
+ "31:\n"
+
+ // Write our float values to the destination described by
+ // (r3 address, r4 stride)
+ "vst1.32 {d6, d7, d8, d9}, [r3]\n"
+ "add r3, r3, r4\n"
+ RUY_MAKE_ZERO(q3)
+ RUY_MAKE_ZERO(q4)
+ "vst1.32 {d10, d11, d12, d13}, [r3]\n"
+ "add r3, r3, r4\n"
+ RUY_MAKE_ZERO(q5)
+ RUY_MAKE_ZERO(q6)
+ "vst1.32 {d14, d15, d16, d17}, [r3]\n"
+ "add r3, r3, r4\n"
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ "vst1.32 {d18, d19, d20, d21}, [r3]\n"
+ "add r3, r3, r4\n"
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+
+ // If all of the 8x4 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 41f\n"
+        // Not all of the 8x4 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "mov r3, %[dst_tmp_buf]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r6, #0\n"
+ "50:\n"
+ "mov r5, #0\n"
+ "51:\n"
+ "ldr r10, [r3, r5, lsl #2]\n"
+ "str r10, [r4, r5, lsl #2]\n"
+ "add r5, r5, #1\n"
+ "cmp r5, r1\n"
+ "blt 51b\n"
+ "add r6, r6, #1\n"
+ "add r3, r3, #32\n"
+ "add r4, r4, r8\n"
+ // r2 = how many cols of the 8x4 block fit
+ "cmp r6, r2\n"
+ "blt 50b\n"
+ "41:\n"
+ // Load dst_ptr, increment, and write back.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "add r4, r4, #32\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ // Reload some params --- we had used r3, r5, r10 for a few other things
+ // since the last time we had loaded them.
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "cmp r8, r3\n"
+
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add r8, r8, #8\n"
+ // Store new value of row
+ "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ // Move back to first row.
+ "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ // Move to the next column.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "add r4, r4, #4\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+
+ "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+ // Increment dst_col_ptr by 4 * dst_stride (i.e. 4 columns)
+ "add r1, r1, r8, lsl #2\n"
+ // Store dst_col_ptr
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+ // Store dst_ptr
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "cmp r8, r4\n"
+
+ // r1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 1.
+ "mov r1, #1\n"
+
+ "ble 1b\n"
+
+ // Restore stack pointer.
+ "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
+
+ // clang-format on
+ : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
+ : [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
+ // Clobber list must specify q registers (and not their constituent
+ // d registers). There is a (currently unexplained) slowdown if
+ // d registers are listed in the clobbers list.
+ : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
+ "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
+ "q9", "q10", "q12", "q13");
+}
+
+#undef RUY_MAKE_ZERO
+#undef RUY_STACK_OFFSET_SIZE
+#undef RUY_STACK_OFFSET_DST_COL_PTR
+#undef RUY_STACK_OFFSET_DST_PTR
+#undef RUY_STACK_OFFSET_ROW
+#undef RUY_STACK_OFFSET_COL
+#undef RUY_STACK_OFFSET_LHS_COL_PTR
+#undef RUY_STACK_OFFSET_RHS_COL_PTR
+
+#undef RUY_OFFSET_LHS_BASE_PTR
+#undef RUY_OFFSET_RHS_BASE_PTR
+#undef RUY_OFFSET_DST_BASE_PTR
+#undef RUY_OFFSET_BIAS
+#undef RUY_OFFSET_START_ROW
+#undef RUY_OFFSET_START_COL
+#undef RUY_OFFSET_LAST_ROW
+#undef RUY_OFFSET_LAST_COL
+#undef RUY_OFFSET_DST_ROWS
+#undef RUY_OFFSET_DST_COLS
+#undef RUY_OFFSET_LHS_STRIDE
+#undef RUY_OFFSET_RHS_STRIDE
+#undef RUY_OFFSET_DST_STRIDE
+#undef RUY_OFFSET_DEPTH
+#undef RUY_OFFSET_CLAMP_MIN
+#undef RUY_OFFSET_CLAMP_MAX
+#undef RUY_OFFSET_FLAGS
+
+#define RUY_OFFSET_BIAS 0
+#define RUY_OFFSET_LHS_SUMS 4
+#define RUY_OFFSET_RHS_SUMS 8
+#define RUY_OFFSET_LHS_BASE_PTR 12
+#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 16
+#define RUY_OFFSET_MULTIPLIER_EXPONENT 20
+#define RUY_OFFSET_RHS_BASE_PTR 24
+#define RUY_OFFSET_DST_BASE_PTR 28
+#define RUY_OFFSET_LHS_ZERO_POINT 32
+#define RUY_OFFSET_RHS_ZERO_POINT 36
+#define RUY_OFFSET_DST_ZERO_POINT 40
+#define RUY_OFFSET_PROD_ZP_DEPTH 44
+#define RUY_OFFSET_START_ROW 48
+#define RUY_OFFSET_START_COL 52
+#define RUY_OFFSET_LAST_ROW 56
+#define RUY_OFFSET_LAST_COL 60
+#define RUY_OFFSET_DST_ROWS 64
+#define RUY_OFFSET_DST_COLS 68
+#define RUY_OFFSET_LHS_STRIDE 72
+#define RUY_OFFSET_RHS_STRIDE 76
+#define RUY_OFFSET_DST_STRIDE 80
+#define RUY_OFFSET_DEPTH 84
+#define RUY_OFFSET_CLAMP_MIN 88
+#define RUY_OFFSET_CLAMP_MAX 92
+#define RUY_OFFSET_FLAGS 96
+#define RUY_OFFSET_DST_TYPE_ID 97
+
+#define RUY_STACK_OFFSET_SIZE 96
+#define RUY_STACK_OFFSET_DST_COL_PTR 0
+#define RUY_STACK_OFFSET_DST_PTR 16
+#define RUY_STACK_OFFSET_ROW 32
+#define RUY_STACK_OFFSET_COL 48
+#define RUY_STACK_OFFSET_LHS_COL_PTR 64
+#define RUY_STACK_OFFSET_RHS_COL_PTR 80
+
+template <typename Params>
+void CheckOffsetsInKernelParams8bit(const Params&) {
+ static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT,
+ "");
+ static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT,
+ "");
+ static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT,
+ "");
+ static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH,
+ "");
+ static_assert(offsetof(Params, multiplier_fixedpoint) ==
+ RUY_OFFSET_MULTIPLIER_FIXEDPOINT,
+ "");
+ static_assert(
+ offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT,
+ "");
+ static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
+ static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
+ static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
+ static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, "");
+ static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, "");
+ static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
+ static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
+ static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
+ static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
+ static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
+ static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
+ static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
+ static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
+ static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
+}
+
+// Fast-int8 kernel, ported from the ARM64 version.
+// Relevant target CPUs for this kernel include Krait 400 and A9,
+// since these are 32-bit, out-of-order CPUs.
+void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params) {
+ profiler::ScopeLabel label(
+ "Kernel (kNeon, optimized for out-of-order cores)");
+
+ CheckOffsetsInKernelParams8bit(params);
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // q6 - q13 are 128-bit (4x32b) accumulators.
+ // During accumulation, d0 -- d7 are used to load int8 data from LHS and
+ // d8 -- d11 from RHS:
+ // int8 RHS 16x2 block
+ // /-----------------------------\
+ // |d8.b[0-7] ..... d10.b[0-7]|
+ // | ... ... |
+ // |d9.b[0-7] ..... d11.b[0-7]|
+ // \-----------------------------/
+ // int8 LHS 4x16 block
+ // /------------------------\ /-----------------------------\
+ // |d0.b[0-7] ... d1.b[0-7] | | q6 ..... q10 |
+ // |d2.b[0-7] ... d3.b[0-7] | | q7 ..... q11 |
+ // (Reload d0, d1, d2, d3)
+ // |d0.b[0-7] ... d1.b[0-7] | | q8 ..... q12 |
+ // |d2.b[0-7] ... d3.b[0-7] | | q9 ..... q13 |
+ // \------------------------/ \-----------------------------/
+ // 128-bit accumulators 4x2 block
+ //
+  // No attempt has been made so far at implementing the RUY_OPT_MAX_STREAMING
+ // optimization for this kernel.
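+  //
+  // In scalar terms, each iteration of the accumulation loop below advances
+  // by 16 levels of depth and performs roughly (a sketch with illustrative
+  // names; the real data comes from the packed LHS/RHS buffers):
+  //
+  //   for (int r = 0; r < 4; r++) {          // LHS rows
+  //     for (int c = 0; c < 2; c++) {        // RHS columns
+  //       for (int k = 0; k < 16; k++) {     // depth levels per iteration
+  //         acc[r][c] += (std::int32_t)lhs[r][k] * (std::int32_t)rhs[k][c];
+  //       }
+  //     }
+  //   }
+  //
+  // The vmull.s8/vmlal.s8 pairs produce 16-bit products, and the vpadal.s16
+  // instructions accumulate them pairwise into the 32-bit accumulators.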
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n"
+
+ // clang-format off
+
+ // Load the first 64 bytes of LHS and RHS data.
+ "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
+ // Clear accumulators.
+ RUY_MAKE_ZERO(q6)
+ "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
+ RUY_MAKE_ZERO(q11)
+ "vld1.8 {d10, d11}, [%[rhs_ptr]]!\n"
+
+ "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
+ RUY_MAKE_ZERO(q12)
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
+ RUY_MAKE_ZERO(q13)
+ "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ RUY_MAKE_ZERO(q14)
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
+ RUY_MAKE_ZERO(q15)
+ "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
+ "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+
+
+ // r1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 16.
+ "mov r1, #16\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ // r1 is how many levels of depth we have already loaded
+ // data for, r10 is the total depth.
+ "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+ "cmp r1, r10\n"
+ "beq 79f\n"
+
+ "2:\n"
+
+ // Mult, mult-acc in to q14, q15, q2, q3
+ "vmull.s8 q14, d0, d8\n"
+ "vmull.s8 q2, d0, d10\n"
+
+ "vmull.s8 q15, d2, d8\n"
+ "vmull.s8 q3, d2, d10\n"
+
+ "vmlal.s8 q14, d1, d9\n"
+ "vmlal.s8 q2, d1, d11\n"
+ "vmlal.s8 q15, d3, d9\n"
+ "vmlal.s8 q3, d3, d11\n"
+ "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
+
+ // Then pairwise accumulate in to q6, q7, q10, q11
+ "vpadal.s16 q6, q14\n"
+ "vpadal.s16 q7, q15\n"
+ "vpadal.s16 q10, q2\n"
+ "vpadal.s16 q11, q3\n"
+
+ // Mult, mult-acc in to q14, q15, q2, q3
+ "vmull.s8 q14, d0, d8\n"
+ "vmull.s8 q2, d0, d10\n"
+
+ "vmull.s8 q15, d2, d8\n"
+ "vmull.s8 q3, d2, d10\n"
+
+ "vmlal.s8 q14, d1, d9\n"
+ "vmlal.s8 q2, d1, d11\n"
+ "vmlal.s8 q15, d3, d9\n"
+ "vmlal.s8 q3, d3, d11\n"
+ "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
+
+ // Then pairwise accumulate in to q8, q9, q12, q13
+ "vpadal.s16 q8, q14\n"
+ "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n"
+ "vpadal.s16 q9, q15\n"
+ "vpadal.s16 q12, q2\n"
+ "vpadal.s16 q13, q3\n"
+
+ // Prefetch the next 64 bytes of LHS and RHS data.
+ RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
+ RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
+
+ // Each iteration of this loop advances by 16 levels of depth.
+ "add r1, r1, #16\n"
+
+ // Loop termination condition
+ "cmp r1, r10\n"
+
+ "blt 2b\n"
+
+ "79:\n"
+
+ // Mult, mult-acc in to q14, q15, q2, q3
+ "vmull.s8 q14, d0, d8\n"
+ "vmull.s8 q2, d0, d10\n"
+
+ "vmull.s8 q15, d2, d8\n"
+ "vmull.s8 q3, d2, d10\n"
+
+ "vmlal.s8 q14, d1, d9\n"
+ "vmlal.s8 q2, d1, d11\n"
+ "vmlal.s8 q15, d3, d9\n"
+ "vmlal.s8 q3, d3, d11\n"
+ "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
+
+ // Then pairwise accumulate in to q6, q7, q10, q11
+ "vpadal.s16 q6, q14\n"
+ "vpadal.s16 q7, q15\n"
+ "vpadal.s16 q10, q2\n"
+ "vpadal.s16 q11, q3\n"
+
+ // Mult, mult-acc in to q14, q15, q2, q3
+ "vmull.s8 q14, d0, d8\n"
+ "vmull.s8 q2, d0, d10\n"
+
+ "vmull.s8 q15, d2, d8\n"
+ "vmull.s8 q3, d2, d10\n"
+
+ "vmlal.s8 q14, d1, d9\n"
+ "vmlal.s8 q2, d1, d11\n"
+ "vmlal.s8 q15, d3, d9\n"
+ "vmlal.s8 q3, d3, d11\n"
+
+ // Then pairwise accumulate in to q8, q9, q12, q13
+ "vpadal.s16 q8, q14\n"
+ "vpadal.s16 q9, q15\n"
+ "vpadal.s16 q12, q2\n"
+ "vpadal.s16 q13, q3\n"
+
+
+ // All accumulation over depth done. q6 - q13 contain the 4x32b
+ // accumulators for the 4x2 final matrix.
+ // We now have to compute the final 8-bit values from these int32
+ // accumulators, and advance to the next 4x2 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // q6-q13 now contain 4 x 32b
+ "vpadd.i32 d0, d12, d13\n"
+ "vpadd.i32 d1, d14, d15\n"
+ "vpadd.i32 d2, d16, d17\n"
+ "vpadd.i32 d3, d18, d19\n"
+ "vpadd.i32 d4, d20, d21\n"
+ "vpadd.i32 d5, d22, d23\n"
+ "vpadd.i32 d6, d24, d25\n"
+ "vpadd.i32 d7, d26, d27\n"
+
+    // d0 -- d7 each contain 2 x 32b accumulators.
+    // Add them pairwise to get 1 x 32b for each of the 4x2 entries
+    // of the destination (four 'd' registers in total).
+ "vpadd.i32 d28, d0, d1\n"
+ "vpadd.i32 d29, d2, d3\n"
+ "vpadd.i32 d30, d4, d5\n"
+ "vpadd.i32 d31, d6, d7\n"
+
+    // Now d28 -- d31 have the 1 x 32b accumulators for the 4x2 entries.
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "cmp r1, r3\n" // Have we finished the last row?
+
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ "add r4, r4, r1, lsl #2\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ // Go back to first row
+ "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "cmp r8, r4\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ "add r10, r10, r1, lsl #1\n"
+ "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ "mov %[lhs_ptr], r4\n"
+ "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ "mov %[rhs_ptr], r5\n"
+
+ // Now we load: bias data, LHS sums data, RHS sums data.
+
+ // First, load the base pointers from the params.
+ "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+
+ // Offset these base pointers as needed given the current row, col.
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "add r5, r1, r8, lsl #2\n"
+
+ "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "it ne\n"
+ "movne r1, r5\n"
+
+ // Load 4 bias values.
+ "vld1.32 {d24, d25}, [r1]\n"
+
+    // Now that we know what LHS and RHS data the next iteration of the
+    // main loop will need to load, we start loading the first 32 bytes of
+    // each of LHS and RHS, into d0 -- d3 and d8 -- d11, as we don't need
+    // those registers anymore in the rest of the work on the current block.
+ "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n"
+ RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
+ "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n"
+ RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
+
+ // Add to the bias values the product
+ // (depth * lhs_zero_point * rhs_zero_point),
+ // See the term NZ1Z2 in equation (7) in
+ // https://arxiv.org/pdf/1712.05877.pdf
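+    //
+    // For reference, this term and the two "sums" corrections applied below
+    // all come from expanding the quantized product (a sketch; i, j, k index
+    // row, column and depth, zp stands for zero_point):
+    //
+    //   sum_k (lhs[i][k] - lhs_zp) * (rhs[k][j] - rhs_zp)
+    //     =   sum_k lhs[i][k] * rhs[k][j]   // the raw accumulators
+    //       - lhs_zp * rhs_sums[j]          // subtracted below (RHS sums)
+    //       - rhs_zp * lhs_sums[i]          // subtracted below (LHS sums)
+    //       + depth * lhs_zp * rhs_zp       // == prod_zp_depth, added here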
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
+ "vdup.32 q9, r3\n"
+ "vadd.i32 q12, q12, q9\n"
+
+ // Perform the bias-addition (per the above, we have just folded into
+ // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+ "vadd.i32 q14, q14, q12\n"
+ "vadd.i32 q15, q15, q12\n"
+
+ // LHS/RHS zero points
+ // Has RHS sums
+ "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
+ "beq 401f\n"
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ // Offset by current col * number of bytes per value
+ "add r3, r3, r4, lsl #2\n"
+ "vld1.32 { d12 }, [r3]\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
+ "vdup.32 q10, r5\n" // create lhs_zero_point_vec
+ // Subtract rhs_sums * lhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "vmls.i32 q14, q10, d12[0]\n"
+ "vmls.i32 q15, q10, d12[1]\n"
+ "401:\n"
+
+ // Has LHS sums
+ "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
+ "beq 402f\n"
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ // Offset by current row * number of bytes per value
+ "add r2, r2, r4, lsl #2\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
+
+ // Load 4 lhs_sums values.
+ "vld1.32 {d22, d23}, [r2]\n"
+ "vdup.32 d13, r5\n" // rhs_zero_point
+
+ // Compute lhs_sums * rhs_zero_point.
+ "vmul.i32 q11, q11, d13[1]\n"
+ // Subtract lhs_sums * rhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "vsub.s32 q14, q14, q11\n"
+ "vsub.s32 q15, q15, q11\n"
+
+ // If the destination is int32, it means the user asks for the raw
+ // accumulators, no need for us to downquantize the value.
+ "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
+ "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
+
+ "402:\n"
+
+ // At this point we have computed the final int32 values. Now we
+ // start down-quantizing them to obtain the final 8bit values from them.
+
+ // As part of this down-quantization, our int32 values will be
+ // multiplied by a multiplier that has a fixed-point component and an
+ // exponent component.
+
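+    //
+    // In scalar terms, the sequence below computes, for each accumulator
+    // value, roughly (a sketch; "rounding_rshift" stands for the fixed-up
+    // rounding right shift implemented with vrshl further down):
+    //
+    //   shifted = acc << max(multiplier_exponent, 0);
+    //   scaled  = saturating_rounding_doubling_high_mul(
+    //                 shifted, multiplier_fixedpoint);
+    //   result  = rounding_rshift(scaled, -min(multiplier_exponent, 0));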
+    // Load the exponent part of the multiplier.
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+ "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "add r5, r1, r4, lsl #2\n"
+ "it ne\n"
+ "movne r1, r5\n"
+
+ "vld1.32 {q10}, [r1]\n"
+
+ RUY_MAKE_ZERO(q8)
+ "vmax.s32 q12, q10, q8\n"
+
+ "vshl.s32 q14, q14, q12\n"
+ "vshl.s32 q15, q15, q12\n"
+
+ "vmin.s32 q12, q10, q8\n"
+
+ // Load fixed point part of the multiplier
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
+ // r6 has flags, r4 has row
+ "add r5, r1, r4, lsl #2\n"
+ "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "it ne\n"
+ "movne r1, r5\n"
+ "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint
+
+ // Apply the fixed-point part of the multiplier.
+ "vqrdmulh.s32 q14, q14, q10\n"
+ "vqrdmulh.s32 q15, q15, q10\n"
+
+ // We have some rounding division-by-power-of-two to do. This should
+ // always use "round to nearest". We allow for some
+ // freedom in how ties are broken, to strike a good compromise of
+ // performance on given hardware vs. perfect agreement of results
+ // across hardware.
+ //
+ // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation
+ // defined tie-breaks to help performance. On NEON, this means that we
+ // can just use the NEON rounding instructions, such as srshl. They
+ // happen to be breaking ties upward.
+ //
+ // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
+ // break-ties-away-from zero, as described in Appendix B of
+ // https://arxiv.org/pdf/1712.05877.pdf
+ // When we wrote that, we thought that that would be better unbiased
+ // than the NEON upwards tie-breaks, and we had observed some
+ // improvement on some model. However, that is only more unbiased for
+ // data centered at zero, which was likely the case in that model,
+ // but is not always the case. If we wanted something more consistently
+ // unbiased then we should try breaking ties toward-nearest-even.
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
+ // Fix up values to be right-shifted, so that the (round to nearest,
+ // break ties upward) behavior of srshl applied to these fixed-up
+ // values, produces the same result as the desired (round to nearest,
+ // break ties away from zero) behavior on the original values.
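+    //
+    // For example, with a rounding right shift by 1: an accumulator of -5
+    // should become -3 (ties away from zero), but vrshl alone would give -2
+    // (ties upward). The fixup below adds -1 to negative values whenever a
+    // right shift is pending (q12 negative), so vrshl sees -6 and produces
+    // the desired -3; non-negative values are left unchanged.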
+ "vand q8, q14, q12\n"
+ "vand q9, q15, q12\n"
+ "vshr.s32 q8, q8, #31\n"
+ "vshr.s32 q9, q9, #31\n"
+ "vqadd.s32 q14, q14, q8\n"
+    "vqadd.s32 q15, q15, q9\n"
+
+#endif
+ // At this point we have reduced the problem of correctly implementing
+ // rounding divide-by-power-of-two, to what the SRSHL instruction can
+ // do.
+ "vrshl.s32 q14, q14, q12\n"
+ "vrshl.s32 q15, q15, q12\n"
+
+ "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
+ "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
+ "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
+
+ // Store uint8 values:
+ RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ // After this, all values for output are in q14.
+ "vqmovn.s32 d28, q14\n"
+ "vqmovn.s32 d29, q15\n"
+
+    // At this point, d12 -- d27, d30, d31 aren't used anymore for the
+ // current block, so we can start clearing these accumulators for the
+ // next block (next iteration of the main loop).
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+ RUY_MAKE_ZERO(q12)
+ RUY_MAKE_ZERO(q13)
+ RUY_MAKE_ZERO(q15)
+
+ // Load the destination zero point into each of the 8 16-bit slots
+ // in a q register.
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "vdup.16 q13, r4\n" // dst_zero_point
+
+ // Add the destination zero point
+ "vadd.i16 q14, q14, q13\n"
+
+ // Cast-and-saturate from int16 to uint8
+ // Now all 8 1-byte values are in d30.
+ "vqmovun.s16 d30, q14\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "vdup.8 d28, r2\n" // clamp_min
+ "vdup.8 d29, r3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "vmax.u8 d30, d30, d28\n"
+ // Apply the clamp_max bound
+ "vmin.u8 d30, d30, d29\n"
+
+ // Compute how much of the 4x2 block of destination 8bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x2, there are some 4x2 blocks along the boundaries that do
+ // not fit entirely.
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "sub r1, r1, r8\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "sub r2, r2, r4\n"
+ "mov r3, #4\n"
+ "mov r5, #2\n"
+ "cmp r1, #4\n"
+ // Compute r1 = how many rows of the 4x2 block fit
+ "it gt\n"
+ "movgt r1, r3\n"
+
+ "cmp r2, #2\n"
+ // Compute r2 = how many cols of the 4x2 block fit
+ "it gt\n"
+ "movgt r2, r5\n"
+
+ // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
+ "cmp r1, r3\n"
+ "it eq\n"
+ "cmpeq r2, r5\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ // Yes, all of the 4x2 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x2 block fits.
+ // Store to dst_tmp_buf
+ // Set r3 address to write to dst_tmp_buf.
+ "mov r3, %[dst_tmp_buf]\n"
+ "vst1.8 {d30}, [r3]\n"
+
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov r6, #0\n"
+ "50:\n"
+ "mov r8, #0\n"
+ "51:\n"
+ "ldrb r10, [r3, r8]\n"
+ "strb r10, [r4, r8]\n"
+ "add r8, r8, #1\n"
+ "cmp r8, r1\n"
+ "blt 51b\n"
+ "add r6, r6, #1\n"
+ "add r3, r3, #4\n"
+ "add r4, r4, r5\n"
+ "cmp r6, r2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x2 block fits.
+ // r3 address, r5 stride
+ "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r4, r3\n"
+ "mov r6, #1\n"
+
+ "vst1.32 {d30[0]}, [r3]\n"
+ "add r4, r4, r5\n"
+ "mov r3, r4\n"
+ "vst1.32 {d30[1]}, [r3]\n"
+
+ "31:\n"
+
+ // Load dst_ptr, increment, and write back.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "add r4, r4, #4\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ RUY_MAKE_ZERO(q13)
+ RUY_MAKE_ZERO(q14)
+ RUY_MAKE_ZERO(q15)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ // Store int8 values:
+ RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ // After this, all values for output are in q14.
+ "vqmovn.s32 d28, q14\n"
+ "vqmovn.s32 d29, q15\n"
+
+    // At this point, d12 -- d27, d30, d31 aren't used anymore for the
+ // current block, so we can start clearing these accumulators for the
+ // next block (next iteration of the main loop).
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+ RUY_MAKE_ZERO(q12)
+ RUY_MAKE_ZERO(q13)
+ RUY_MAKE_ZERO(q15)
+
+ // Load the destination zero point into each of the 8 16-bit slots
+ // in a q register.
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "vdup.16 q13, r4\n" // dst_zero_point
+
+ // Add the destination zero point
+ "vadd.i16 q14, q14, q13\n"
+
+ // Cast-and-saturate from int16 to int8
+ // Now all 8 1-byte values are in d30.
+ "vqmovn.s16 d30, q14\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "vdup.8 d28, r2\n" // clamp_min
+ "vdup.8 d29, r3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "vmax.s8 d30, d30, d28\n"
+ // Apply the clamp_max bound
+ "vmin.s8 d30, d30, d29\n"
+
+ // Compute how much of the 4x2 block of destination 8bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x2, there are some 4x2 blocks along the boundaries that do
+ // not fit entirely.
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "sub r1, r1, r8\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "sub r2, r2, r4\n"
+ "mov r3, #4\n"
+ "mov r5, #2\n"
+ "cmp r1, #4\n"
+ // Compute r1 = how many rows of the 4x2 block fit
+ "it gt\n"
+ "movgt r1, r3\n"
+
+ "cmp r2, #2\n"
+ // Compute r2 = how many cols of the 4x2 block fit
+ "it gt\n"
+ "movgt r2, r5\n"
+
+ // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
+ "cmp r1, r3\n"
+ "it eq\n"
+ "cmpeq r2, r5\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ // Yes, all of the 4x2 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x2 block fits.
+ // Store to dst_tmp_buf
+ // Set r3 address to write to dst_tmp_buf.
+ "mov r3, %[dst_tmp_buf]\n"
+ "vst1.8 {d30}, [r3]\n"
+
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov r6, #0\n"
+ "50:\n"
+ "mov r8, #0\n"
+ "51:\n"
+ "ldrb r10, [r3, r8]\n"
+ "strb r10, [r4, r8]\n"
+ "add r8, r8, #1\n"
+ "cmp r8, r1\n"
+ "blt 51b\n"
+ "add r6, r6, #1\n"
+ "add r3, r3, #4\n"
+ "add r4, r4, r5\n"
+ "cmp r6, r2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x2 block fits.
+ // r3 address, r5 stride
+ "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r4, r3\n"
+ "mov r6, #1\n"
+
+ "vst1.32 {d30[0]}, [r3]\n"
+ "add r4, r4, r5\n"
+ "mov r3, r4\n"
+ "vst1.32 {d30[1]}, [r3]\n"
+
+ "31:\n"
+
+ // Load dst_ptr, increment, and write back.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "add r4, r4, #4\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ RUY_MAKE_ZERO(q13)
+ RUY_MAKE_ZERO(q14)
+ RUY_MAKE_ZERO(q15)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
+
+ // Load the destination zero point into each of the 4 32-bit slots
+ // in a q register.
+ "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "vdup.32 q13, r4\n" // dst_zero_point
+ // Add the destination zero point
+ "vadd.s32 q14, q14, q13\n"
+ "vadd.s32 q15, q15, q13\n"
+
+ // Cast-and-saturate from int32 to int16
+ // After this, all values for output are in q14.
+ "vqmovn.s32 d28, q14\n"
+ "vqmovn.s32 d29, q15\n"
+
+    // At this point, q6 -- q11 and q15 aren't used anymore for the current
+    // block, so we can start clearing these accumulators for the next block
+    // (next iteration of the main loop).
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+ RUY_MAKE_ZERO(q15)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "vdup.16 q12, r2\n" // clamp_min
+ "vdup.16 q13, r3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "vmax.s16 q14, q14, q12\n"
+ // Apply the clamp_max bound
+ "vmin.s16 q14, q14, q13\n"
+
+ RUY_MAKE_ZERO(q12)
+ RUY_MAKE_ZERO(q13)
+
+ // Compute how much of the 4x2 block of destination 16-bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x2, there are some 4x2 blocks along the boundaries that do
+ // not fit entirely.
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "sub r1, r1, r8\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "sub r2, r2, r4\n"
+ "mov r3, #4\n"
+ "mov r5, #2\n"
+ "cmp r1, #4\n"
+ // Compute r1 = how many rows of the 4x2 block fit
+ "it gt\n"
+ "movgt r1, r3\n"
+
+ "cmp r2, #2\n"
+ // Compute r2 = how many cols of the 4x2 block fit
+ "it gt\n"
+ "movgt r2, r5\n"
+
+ // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
+ "cmp r1, r3\n"
+ "it eq\n"
+ "cmpeq r2, r5\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ // Yes, all of the 4x2 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x2 block fits.
+ // Store to dst_tmp_buf
+ // Set r3 address to write to dst_tmp_buf.
+ "mov r3, %[dst_tmp_buf]\n"
+ "vst1.16 {q14}, [r3]\n"
+
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov r6, #0\n"
+ "50:\n"
+ "mov r8, #0\n"
+ "51:\n"
+ // Shift of offset register for half-word loads not allowed in A32,
+ // so we shift, load/store, then shift back r8.
+ "lsl r8, r8, #1\n"
+ "ldrh r10, [r3, r8]\n"
+ "strh r10, [r4, r8]\n"
+ "lsr r8, r8, #1\n"
+ "add r8, r8, #1\n"
+ "cmp r8, r1\n"
+ "blt 51b\n"
+ "add r6, r6, #1\n"
+ "add r3, r3, #8\n"
+ "add r4, r4, r5\n"
+ "cmp r6, r2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x2 block fits.
+ // r3 address, r5 stride
+ "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r4, r3\n"
+ "mov r6, #2\n"
+
+ "vst1.16 {d28[0]}, [r3], r6\n"
+ "add r4, r4, r5\n"
+ "vst1.16 {d28[1]}, [r3], r6\n"
+ "vst1.16 {d28[2]}, [r3], r6\n"
+ "vst1.16 {d28[3]}, [r3], r6\n"
+ "mov r3, r4\n"
+ "vst1.16 {d29[0]}, [r3], r6\n"
+ "vst1.16 {d29[1]}, [r3], r6\n"
+ "vst1.16 {d29[2]}, [r3], r6\n"
+ "vst1.16 {d29[3]}, [r3], r6\n"
+ "31:\n"
+
+ // Load dst_ptr, increment, and write back.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "add r4, r4, #8\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ RUY_MAKE_ZERO(q14)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+ // Since the store type is the same as the accum type, no need for
+ // downcast. There's also no need for clamp by min/max.
+
+    // At this point, q6 -- q13 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ // Clear accumulators.
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+ RUY_MAKE_ZERO(q12)
+ RUY_MAKE_ZERO(q13)
+
+ // Compute how much of the 4x2 block of destination 32 bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+    // of 4x2, there are some 4x2 blocks along the boundaries that do
+ // not fit entirely.
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "sub r1, r1, r8\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "sub r2, r2, r4\n"
+ "mov r3, #4\n"
+ "mov r5, #2\n"
+ "cmp r1, #4\n"
+ // Compute r1 = how many rows of the 4x2 block fit
+ "it gt\n"
+ "movgt r1, r3\n"
+
+ "cmp r2, #2\n"
+ // Compute r2 = how many cols of the 4x2 block fit
+ "it gt\n"
+ "movgt r2, r5\n"
+
+ // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
+ "cmp r1, r3\n"
+ "it eq\n"
+ "cmpeq r2, r5\n"
+ // Yes, all of the 4x2 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x2 block fits.
+ // Set (r3 address, r4 stride) to write to dst_tmp_buf
+ "mov r3, %[dst_tmp_buf]\n"
+ "mov r4, #16\n"
+ "b 31f\n"
+
+ "30:\n"
+ // Yes, all of the 4x2 block fits.
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ // r3 address, r4 stride
+ "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r4, r5\n"
+
+ "31:\n"
+
+ "vst1.32 {d28, d29}, [r3]\n"
+ "add r3, r3, r4\n"
+ "vst1.32 {d30, d31}, [r3]\n"
+
+ // If all of the 4x2 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 41f\n"
+ // Not all of the 4x2 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "mov r3, %[dst_tmp_buf]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r6, #0\n"
+ "50:\n"
+ "mov r5, #0\n"
+ "51:\n"
+ "ldr r10, [r3, r5, lsl #2]\n"
+ "str r10, [r4, r5, lsl #2]\n"
+ "add r5, r5, #1\n"
+ "cmp r5, r1\n"
+ "blt 51b\n"
+ "add r6, r6, #1\n"
+ "add r3, r3, #16\n"
+ "add r4, r4, r8\n"
+    // r2 = how many cols of the 4x2 block fit
+ "cmp r6, r2\n"
+ "blt 50b\n"
+
+ "41:\n"
+ // Load dst_ptr, increment, and write back.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "add r4, r4, #16\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
+
+    // Reload some params --- we had used r3, r5, r6 for a few other things
+ // since the last time we had loaded them.
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "cmp r8, r3\n"
+
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add r8, r8, #4\n"
+ // Store new value of row
+ "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ // Move back to first row.
+ "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ // Move to the next column.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "add r4, r4, #2\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+
+ "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+ // Increment dst_col_ptr by 2 * dst_stride (i.e. 2 columns)
+ "add r1, r1, r8, lsl #1\n"
+ // Store dst_col_ptr
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+ // Store dst_ptr
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "cmp r8, r4\n"
+
+        // r1 is the number of levels of depth that we have already loaded
+        // LHS and RHS data for. Corresponding to the initial vld1 instructions
+        // above, this is currently 16.
+ "mov r1, #16\n"
+
+ "ble 1b\n"
+
+ // Restore stack pointer.
+ "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
+
+ // clang-format on
+
+ : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
+ : [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
+ : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
+ // Clobber list must specify q registers (and not their constituent
+ // d registers). There is a (currently unexplained) slowdown if
+ // d registers are listed in the clobbers list.
+ "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
+ "q9", "q10", "q12", "q13", "q14", "q15");
+}
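The store paths above all share the same boundary-handling pattern: compute how many rows and columns of the just-computed block actually fall inside the destination matrix, store directly on the fast path, and otherwise go through dst_tmp_buf and copy only the part that fits. Below is a minimal scalar C++ sketch of that pattern for the int32 4x2 case; the function name, the column-major accum layout and the element-based stride are illustrative assumptions, not ruy's actual kernel interface.

#include <algorithm>
#include <cstdint>
#include <cstring>

// Hypothetical illustration of the 4x2 int32 store with boundary handling.
// accum is column-major: accum[c] holds the 4 values of column c.
// dst_stride is in elements here (the asm works with a byte stride).
void StoreInt32Block4x2(const std::int32_t accum[2][4], std::int32_t* dst,
                        int dst_stride, int dst_rows, int dst_cols, int row,
                        int col, std::int32_t* dst_tmp_buf) {
  const int rows_fit = std::min(4, dst_rows - row);
  const int cols_fit = std::min(2, dst_cols - col);
  if (rows_fit == 4 && cols_fit == 2) {
    // Fast path: the whole 4x2 block fits; write each column directly.
    for (int c = 0; c < 2; ++c) {
      std::memcpy(dst + c * dst_stride, accum[c], 4 * sizeof(std::int32_t));
    }
    return;
  }
  // Slow path: write the whole block to a temporary buffer, then copy only
  // the rows_fit x cols_fit part that falls inside the destination matrix.
  std::memcpy(dst_tmp_buf, accum, 2 * 4 * sizeof(std::int32_t));
  for (int c = 0; c < cols_fit; ++c) {
    for (int r = 0; r < rows_fit; ++r) {
      dst[c * dst_stride + r] = dst_tmp_buf[c * 4 + r];
    }
  }
}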
+
+// Fast-int8 true "GEMV" kernel (RHS has 1 column). We assume the RHS is
+// still packed as if it had two columns.
+void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 2>& params) {
+ profiler::ScopeLabel label(
+ "Kernel (kNeon, optimized for out-of-order cores)");
+
+ CheckOffsetsInKernelParams8bit(params);
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+  // q6 -- q9 are 128-bit (4x32b) accumulators.
+  // During accumulation, d0 -- d7 are used to load int8 data from LHS and
+  // d8 -- d9 from RHS:
+ // int8 RHS 16x1 block
+ // /------------\
+ // | d8.b[0] |
+ // | ... |
+ // | d8.b[7] |
+ // | d9.b[0] |
+ // | ... |
+ // | d9.b[7] |
+ // \------------/
+ // int8 LHS 4x16 block
+ // /-----------------------------------------\ /------------\
+ // |d0.b[0] ... d0.b[7] d1.b[0] ... d1.b[7] | | q6 |
+ // |d2.b[0] ... d2.b[7] d3.b[0] ... d3.b[7] | | q7 |
+ // |d4.b[0] ... d4.b[7] d5.b[0] ... d5.b[7] | | q8 |
+ // |d6.b[0] ... d6.b[7] d7.b[0] ... d7.b[7] | | q9 |
+ // \-----------------------------------------/ \------------/
+ // 128-bit accumulators 4x1 block
+ //
+  // No attempt has been made so far at implementing the RUY_OPT_MAX_STREAMING
+  // optimization for this kernel.
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n"
+
+ // clang-format off
+
+        // Load the first 64 bytes of LHS data and 16 bytes of RHS data.
+ "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
+ // Skip the other column and advance the pointer.
+ "add %[rhs_ptr], %[rhs_ptr], #16\n"
+
+ "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
+ "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
+ "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
+ "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+ RUY_MAKE_ZERO(q12)
+ RUY_MAKE_ZERO(q13)
+ RUY_MAKE_ZERO(q14)
+ RUY_MAKE_ZERO(q15)
+
+ // r1 is the number of levels of depth that we have already loaded
+        // LHS and RHS data for. Corresponding to the initial vld1 instructions
+ // above, this is currently 16.
+ "mov r1, #16\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ // r1 is how many levels of depth we have already loaded
+ // data for, r10 is the total depth.
+ "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+ "cmp r1, r10\n"
+ "beq 79f\n"
+
+ "2:\n"
+
+        // Mult, mult-acc into q14, q15
+ "vmull.s8 q14, d0, d8\n"
+ "vmull.s8 q15, d2, d8\n"
+ "vmlal.s8 q14, d1, d9\n"
+ "vmlal.s8 q15, d3, d9\n"
+
+        // Then pairwise accumulate into q6, q7
+ "vpadal.s16 q6, q14\n"
+ "vpadal.s16 q7, q15\n"
+
+        // Mult, mult-acc into q14, q15
+ "vmull.s8 q14, d4, d8\n"
+ "vmull.s8 q15, d6, d8\n"
+ "vmlal.s8 q14, d5, d9\n"
+ "vmlal.s8 q15, d7, d9\n"
+
+        // Then pairwise accumulate into q8, q9
+ "vpadal.s16 q8, q14\n"
+ "vpadal.s16 q9, q15\n"
+
+
+        // Load the next 64 bytes of LHS data and 16 bytes of RHS data.
+ "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
+ RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
+ "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
+ // Skip the other column and advance the pointer.
+ "add %[rhs_ptr], %[rhs_ptr], #16\n"
+ RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
+
+ // Each iteration of this loop advances by 16 levels of depth.
+ "add r1, r1, #16\n"
+
+ // Loop termination condition
+ "cmp r1, r10\n"
+
+ "blt 2b\n"
+
+ "79:\n"
+
+        // Mult, mult-acc into q14, q15
+ "vmull.s8 q14, d0, d8\n"
+ "vmull.s8 q15, d2, d8\n"
+ "vmlal.s8 q14, d1, d9\n"
+ "vmlal.s8 q15, d3, d9\n"
+
+        // Then pairwise accumulate into q6, q7
+ "vpadal.s16 q6, q14\n"
+ "vpadal.s16 q7, q15\n"
+
+        // Mult, mult-acc into q14, q15
+ "vmull.s8 q14, d4, d8\n"
+ "vmull.s8 q15, d6, d8\n"
+ "vmlal.s8 q14, d5, d9\n"
+ "vmlal.s8 q15, d7, d9\n"
+
+        // Then pairwise accumulate into q8, q9
+ "vpadal.s16 q8, q14\n"
+ "vpadal.s16 q9, q15\n"
+
+        // All accumulation over depth is done. q6 -- q9 contain the 4x32b
+        // accumulators for the final 4x1 destination block.
+        // We now have to compute the final 8-bit values from these int32
+        // accumulators, and advance to the next 4x1 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // q6-q9 now contain 4 x 32b
+ "vpadd.i32 d0, d12, d13\n"
+ "vpadd.i32 d1, d14, d15\n"
+ "vpadd.i32 d2, d16, d17\n"
+ "vpadd.i32 d3, d18, d19\n"
+
+        // d0 -- d3 each contain 2 x 32b accumulators.
+        // We need to add pairwise to get 1 x 32b value for each of the 4x1
+        // entries of the destination (four 'd' registers total).
+ "vpadd.i32 d28, d0, d1\n"
+ "vpadd.i32 d29, d2, d3\n"
+
+ // Now d28,d29 have the 1 x 32b accumulators for the 4x1 entries.
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "cmp r1, r3\n" // Have we finished the last row?
+
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ "add r4, r4, r1, lsl #2\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ // Go back to first row
+ "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "cmp r8, r4\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ "add r10, r10, r1, lsl #1\n"
+ "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
+ "mov %[lhs_ptr], r4\n"
+ "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
+ "mov %[rhs_ptr], r5\n"
+
+ // Now we load: bias data, LHS sums data, RHS sums data.
+
+ // First, load the base pointers from the params.
+ "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+
+ // Offset these base pointers as needed given the current row, col.
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "add r5, r1, r8, lsl #2\n"
+
+ "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "it ne\n"
+ "movne r1, r5\n"
+
+ // Load 4 bias values.
+ "vld1.32 {d24, d25}, [r1]\n"
+
+        // Now that we know what LHS and RHS data the next iteration of the
+        // main loop will need to load, we start loading the first 64 bytes of
+        // LHS and 16 bytes of RHS into d0 -- d9, as we don't need those
+        // registers anymore in the rest of the work on the current block.
+ "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
+ "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
+ RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
+ "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
+ // Skip the other column and advance the pointer.
+ "add %[rhs_ptr], %[rhs_ptr], #16\n"
+ RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
+
+        // Add to the bias values the product
+        // (depth * lhs_zero_point * rhs_zero_point).
+        // See the term NZ1Z2 in equation (7) in
+ // https://arxiv.org/pdf/1712.05877.pdf
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
+ "vdup.32 q9, r3\n"
+ "vadd.i32 q12, q12, q9\n"
+
+ // Perform the bias-addition (per the above, we have just folded into
+ // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+ "vadd.i32 q14, q14, q12\n"
+
+ // LHS/RHS zero points
+ // Has RHS sums
+ "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
+ "beq 401f\n"
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ // Offset by current col * number of bytes per value
+ "add r3, r3, r4, lsl #2\n"
+ "vld1.32 { d12 }, [r3]\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
+ "vdup.32 q10, r5\n" // create lhs_zero_point_vec
+ // Subtract rhs_sums * lhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "vmls.i32 q14, q10, d12[0]\n"
+ "401:\n"
+
+ // Has LHS sums
+ "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
+ "beq 402f\n"
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ // Offset by current row * number of bytes per value
+ "add r2, r2, r4, lsl #2\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
+
+ // Load 4 lhs_sums values.
+ "vld1.32 {d22, d23}, [r2]\n"
+ "vdup.32 d13, r5\n" // rhs_zero_point
+
+ // Compute lhs_sums * rhs_zero_point.
+ "vmul.i32 q11, q11, d13[1]\n"
+ // Subtract lhs_sums * rhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "vsub.s32 q14, q14, q11\n"
+
+ // If the destination is int32, it means the user asks for the raw
+ // accumulators, no need for us to downquantize the value.
+ "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
+ "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
+
+ "402:\n"
+
+ // At this point we have computed the final int32 values. Now we
+ // start down-quantizing them to obtain the final 8bit values from them.
+
+ // As part of this down-quantization, our int32 values will be
+ // multiplied by a multiplier that has a fixed-point component and an
+ // exponent component.
+
+        // Load the exponent part of the multiplier.
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+ "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "add r5, r1, r4, lsl #2\n"
+ "it ne\n"
+ "movne r1, r5\n"
+
+ "vld1.32 {q10}, [r1]\n"
+
+ RUY_MAKE_ZERO(q8)
+ "vmax.s32 q12, q10, q8\n"
+
+ "vshl.s32 q14, q14, q12\n"
+
+ "vmin.s32 q12, q10, q8\n"
+
+ // Load fixed point part of the multiplier
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
+ // r6 has flags, r4 has row
+ "add r5, r1, r4, lsl #2\n"
+ "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "it ne\n"
+ "movne r1, r5\n"
+ "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint
+
+ // Apply the fixed-point part of the multiplier.
+ "vqrdmulh.s32 q14, q14, q10\n"
+
+ // We have some rounding division-by-power-of-two to do. This should
+ // always use "round to nearest". We allow for some
+ // freedom in how ties are broken, to strike a good compromise of
+ // performance on given hardware vs. perfect agreement of results
+ // across hardware.
+ //
+ // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation
+ // defined tie-breaks to help performance. On NEON, this means that we
+ // can just use the NEON rounding instructions, such as srshl. They
+ // happen to be breaking ties upward.
+ //
+ // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
+ // break-ties-away-from zero, as described in Appendix B of
+ // https://arxiv.org/pdf/1712.05877.pdf
+        // When we wrote that, we thought that it would be less biased than
+        // the NEON upward tie-breaks, and we had observed some improvement
+        // on some model. However, it is only less biased for data centered
+        // at zero, which was likely the case in that model, but is not
+        // always the case. If we wanted something more consistently
+        // unbiased, then we should try breaking ties toward nearest even.
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
+ // Fix up values to be right-shifted, so that the (round to nearest,
+ // break ties upward) behavior of srshl applied to these fixed-up
+ // values, produces the same result as the desired (round to nearest,
+ // break ties away from zero) behavior on the original values.
+ "vand q8, q14, q12\n"
+ "vshr.s32 q8, q8, #31\n"
+ "vqadd.s32 q14, q14, q8\n"
+
+#endif
+        // At this point we have reduced the problem of correctly implementing
+        // rounding divide-by-power-of-two, to what the VRSHL instruction can
+        // do.
+ "vrshl.s32 q14, q14, q12\n"
+
+ "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
+ "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
+ "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
+
+ // Store uint8 values:
+ RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ // After this, all values for output are in d28.
+ "vqmovn.s32 d28, q14\n"
+
+        // At this point, d12 -- d27, d29, d30, d31 aren't used anymore for the
+        // current block, so we can start clearing these accumulators for the
+        // next block (next iteration of the main loop).
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+ RUY_MAKE_ZERO(q12)
+ RUY_MAKE_ZERO(q13)
+ RUY_MAKE_ZERO(q15)
+
+ // Load the destination zero point into each of the 8 16-bit slots
+ // in a q register.
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "vdup.16 q13, r4\n" // dst_zero_point
+
+ // Add the destination zero point
+ "vadd.i16 q14, q14, q13\n"
+
+ // Cast-and-saturate from int16 to uint8
+ "vqmovun.s16 d30, q14\n"
+ // At this point, we only need 4 8-bit values in the lower half
+ // of d30.
+
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "vdup.8 d28, r2\n" // clamp_min
+ "vdup.8 d29, r3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "vmax.u8 d30, d30, d28\n"
+ // Apply the clamp_max bound
+ "vmin.u8 d30, d30, d29\n"
+
+        // Compute how much of the 4x1 block of destination 8-bit values
+        // that we have computed fits in the destination matrix. Typically,
+        // all of it fits, but when the destination matrix shape is not a
+        // multiple of 4x1, there are some 4x1 blocks along the boundaries
+        // that do not fit entirely.
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "sub r1, r1, r8\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "sub r2, r2, r4\n"
+ "mov r3, #4\n"
+ "mov r5, #2\n"
+ "cmp r1, #4\n"
+ // Compute r1 = how many rows of the 4x1 block fit
+ "it gt\n"
+ "movgt r1, r3\n"
+
+ // Test if r1==4, i.e. if all of the 4x1 block fits.
+ "cmp r1, r3\n"
+
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ // Yes, all of the 4x1 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x1 block fits.
+ // Store to dst_tmp_buf
+ // Set r3 address to write to dst_tmp_buf.
+ "mov r3, %[dst_tmp_buf]\n"
+ "vst1.8 {d30}, [r3]\n"
+
+ // Slow loop copying from dst_tmp_buf to dst.
+ "50:\n"
+ "mov r8, #0\n"
+ "51:\n"
+ "ldrb r10, [r3, r8]\n"
+ "strb r10, [r4, r8]\n"
+ "add r8, r8, #1\n"
+ "cmp r8, r1\n"
+ "blt 51b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x1 block fits.
+ // r3 address, r5 stride
+ "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r4, r3\n"
+ "mov r6, #1\n"
+
+ "vst1.8 {d30[0]}, [r3], r6\n"
+ "vst1.8 {d30[1]}, [r3], r6\n"
+ "vst1.8 {d30[2]}, [r3], r6\n"
+ "vst1.8 {d30[3]}, [r3], r6\n"
+ "31:\n"
+
+ // Load dst_ptr, increment, and write back.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "add r4, r4, #4\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ RUY_MAKE_ZERO(q13)
+ RUY_MAKE_ZERO(q14)
+ RUY_MAKE_ZERO(q15)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ // Store int8 values:
+ RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ // After this, all values for output are in d28.
+ "vqmovn.s32 d28, q14\n"
+
+        // At this point, d12 -- d27, d29, d30, d31 aren't used anymore for the
+        // current block, so we can start clearing these accumulators for the
+        // next block (next iteration of the main loop).
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+ RUY_MAKE_ZERO(q12)
+ RUY_MAKE_ZERO(q13)
+ RUY_MAKE_ZERO(q15)
+
+ // Load the destination zero point into each of the 8 16-bit slots
+ // in a q register.
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "vdup.16 q13, r4\n" // dst_zero_point
+
+ // Add the destination zero point
+ "vadd.i16 q14, q14, q13\n"
+
+ // Cast-and-saturate from int16 to int8
+ "vqmovn.s16 d30, q14\n"
+ // At this point, we only need 4 8-bit values in the lower half
+ // of d30.
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "vdup.8 d28, r2\n" // clamp_min
+ "vdup.8 d29, r3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "vmax.s8 d30, d30, d28\n"
+ // Apply the clamp_max bound
+ "vmin.s8 d30, d30, d29\n"
+
+        // Compute how much of the 4x1 block of destination 8-bit values
+        // that we have computed fits in the destination matrix. Typically,
+        // all of it fits, but when the destination matrix shape is not a
+        // multiple of 4x1, there are some 4x1 blocks along the boundaries
+        // that do not fit entirely.
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "sub r1, r1, r8\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "sub r2, r2, r4\n"
+ "mov r3, #4\n"
+ "mov r5, #2\n"
+ "cmp r1, #4\n"
+        // Compute r1 = how many rows of the 4x1 block fit
+ "it gt\n"
+ "movgt r1, r3\n"
+
+        // Test if r1==4, i.e. if all of the 4x1 block fits.
+ "cmp r1, r3\n"
+
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+        // Yes, all of the 4x1 block fits, go to fast path.
+ "beq 30f\n"
+        // Not all of the 4x1 block fits.
+ // Store to dst_tmp_buf
+ // Set r3 address to write to dst_tmp_buf.
+ "mov r3, %[dst_tmp_buf]\n"
+ "vst1.8 {d30}, [r3]\n"
+
+ // Slow loop copying from dst_tmp_buf to dst.
+ "50:\n"
+ "mov r8, #0\n"
+ "51:\n"
+ "ldrb r10, [r3, r8]\n"
+ "strb r10, [r4, r8]\n"
+ "add r8, r8, #1\n"
+ "cmp r8, r1\n"
+ "blt 51b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x1 block fits.
+ // r3 address, r5 stride
+ "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r4, r3\n"
+ "mov r6, #1\n"
+
+ "vst1.8 {d30[0]}, [r3], r6\n"
+ "vst1.8 {d30[1]}, [r3], r6\n"
+ "vst1.8 {d30[2]}, [r3], r6\n"
+ "vst1.8 {d30[3]}, [r3], r6\n"
+ "31:\n"
+
+ // Load dst_ptr, increment, and write back.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "add r4, r4, #4\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ RUY_MAKE_ZERO(q13)
+ RUY_MAKE_ZERO(q14)
+ RUY_MAKE_ZERO(q15)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
+
+ // Load the destination zero point into each of the 4 32-bit slots
+ // in a q register.
+ "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "vdup.32 q13, r4\n" // dst_zero_point
+ // Add the destination zero point
+ "vadd.s32 q14, q14, q13\n"
+ //"vadd.s32 q15, q15, q13\n"
+
+ // Cast-and-saturate from int32 to int16
+ // After this, all values for output are in d28.
+ "vqmovn.s32 d28, q14\n"
+
+        // At this point, d12 -- d27, d29, d30, d31 aren't used anymore for the
+        // current block, so we can start clearing these accumulators for the
+        // next block (next iteration of the main loop).
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+ RUY_MAKE_ZERO(q15)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "vdup.16 d24, r2\n" // clamp_min
+ "vdup.16 d26, r3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "vmax.s16 d28, d28, d24\n"
+ // Apply the clamp_max bound
+ "vmin.s16 d28, d28, d26\n"
+
+ RUY_MAKE_ZERO(q12)
+ RUY_MAKE_ZERO(q13)
+
+        // Compute how much of the 4x1 block of destination 16-bit values
+        // that we have computed fits in the destination matrix. Typically,
+        // all of it fits, but when the destination matrix shape is not a
+        // multiple of 4x1, there are some 4x1 blocks along the boundaries
+        // that do not fit entirely.
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "sub r1, r1, r8\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "sub r2, r2, r4\n"
+ "mov r3, #4\n"
+ "mov r5, #2\n"
+ "cmp r1, #4\n"
+ // Compute r1 = how many rows of the 4x1 block fit
+ "it gt\n"
+ "movgt r1, r3\n"
+
+ // Test if r1==4, i.e. if all of the 4x1 block fits.
+ "cmp r1, r3\n"
+
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ // Yes, all of the 4x1 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x1 block fits.
+ // Store to dst_tmp_buf
+ // Set r3 address to write to dst_tmp_buf.
+ "mov r3, %[dst_tmp_buf]\n"
+ "vst1.16 {d28}, [r3]\n"
+
+ // Slow loop copying from dst_tmp_buf to dst.
+ "50:\n"
+ "mov r8, #0\n"
+ "51:\n"
+        // A shifted offset register is not allowed for half-word loads/stores
+        // in A32, so we shift r8, load/store, then shift it back.
+ "lsl r8, r8, #1\n"
+ "ldrh r10, [r3, r8]\n"
+ "strh r10, [r4, r8]\n"
+ "lsr r8, r8, #1\n"
+ "add r8, r8, #1\n"
+ "cmp r8, r1\n"
+ "blt 51b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x1 block fits.
+ // r3 address, r5 stride
+ "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r4, r3\n"
+ "mov r6, #2\n"
+
+ "vst1.16 {d28[0]}, [r3], r6\n"
+ "vst1.16 {d28[1]}, [r3], r6\n"
+ "vst1.16 {d28[2]}, [r3], r6\n"
+ "vst1.16 {d28[3]}, [r3], r6\n"
+ "31:\n"
+
+ // Load dst_ptr, increment, and write back.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "add r4, r4, #8\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ RUY_MAKE_ZERO(q14)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+ // Since the store type is the same as the accum type, no need for
+ // downcast. There's also no need for clamp by min/max.
+
+        // At this point, q6 -- q13 aren't used anymore for the current block,
+        // so we can start clearing these accumulators for the next block
+        // (next iteration of the main loop).
+ // Clear accumulators.
+ RUY_MAKE_ZERO(q6)
+ RUY_MAKE_ZERO(q7)
+ RUY_MAKE_ZERO(q8)
+ RUY_MAKE_ZERO(q9)
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+ RUY_MAKE_ZERO(q12)
+ RUY_MAKE_ZERO(q13)
+
+        // Compute how much of the 4x1 block of destination 32-bit values
+        // that we have computed fits in the destination matrix. Typically,
+        // all of it fits, but when the destination matrix shape is not a
+        // multiple of 4x1, there are some 4x1 blocks along the boundaries
+        // that do not fit entirely.
+
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "sub r1, r1, r8\n"
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "sub r2, r2, r4\n"
+ "mov r3, #4\n"
+ "mov r5, #2\n"
+ "cmp r1, #4\n"
+        // Compute r1 = how many rows of the 4x1 block fit
+ "it gt\n"
+ "movgt r1, r3\n"
+
+ // Test if r1==4, i.e. if all of the 4x1 block fits.
+ "cmp r1, r3\n"
+
+ // Yes, all of the 4x1 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x1 block fits.
+ // Set (r3 address, r4 stride) to write to dst_tmp_buf
+ "mov r3, %[dst_tmp_buf]\n"
+ "mov r4, #16\n"
+ "b 31f\n"
+
+ "30:\n"
+ // Yes, all of the 4x1 block fits.
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ // r3 address, r4 stride
+ "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "mov r4, r5\n"
+
+ "31:\n"
+
+ "vst1.32 {d28, d29}, [r3]\n"
+
+ // If all of the 4x1 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 41f\n"
+ // Not all of the 4x1 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "mov r3, %[dst_tmp_buf]\n"
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "50:\n"
+ "mov r5, #0\n"
+ "51:\n"
+ "ldr r10, [r3, r5, lsl #2]\n"
+ "str r10, [r4, r5, lsl #2]\n"
+ "add r5, r5, #1\n"
+ "cmp r5, r1\n"
+ "blt 51b\n"
+
+ "41:\n"
+ // Load dst_ptr, increment, and write back.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "add r4, r4, #16\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+
+ RUY_MAKE_ZERO(q10)
+ RUY_MAKE_ZERO(q11)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
+
+        // Reload some params --- we had used r3, r5 and r6 for a few other
+        // things since the last time we had loaded them.
+ "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ "cmp r8, r3\n"
+
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add r8, r8, #4\n"
+ // Store new value of row
+ "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ // Move back to first row.
+ "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
+ // Move to the next column.
+ "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "add r4, r4, #2\n"
+ "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+
+ "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+ // Increment dst_col_ptr by dst_stride (i.e. 1 column)
+ "add r1, r1, r8\n"
+ // Store dst_col_ptr
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
+ // Store dst_ptr
+ "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
+ "cmp r8, r4\n"
+
+        // r1 is the number of levels of depth that we have already loaded
+        // LHS and RHS data for. Corresponding to the initial vld1 instructions
+        // above, this is currently 16.
+ "mov r1, #16\n"
+
+ "ble 1b\n"
+
+ // Restore stack pointer.
+ "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
+
+ // clang-format on
+
+ : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
+ : [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
+ : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
+ // Clobber list must specify q registers (and not their constituent
+ // d registers). There is a (currently unexplained) slowdown if
+ // d registers are listed in the clobbers list.
+ "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
+ "q9", "q10", "q12", "q13", "q14", "q15");
+}
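The down-quantization comments in the two kernels above describe a per-value pipeline: apply the exponent part of the multiplier, apply the fixed-point part with a saturating rounding doubling high multiply, do a rounding divide by a power of two with ties broken away from zero (the strict variant used when RUY_OPT_NATIVE_ROUNDING is disabled), add the destination zero point, and clamp. Below is a minimal scalar C++ sketch of that pipeline for the uint8 case; the function names and the exact saturation handling are illustrative assumptions, not ruy's public API.

#include <algorithm>
#include <cstdint>
#include <limits>

// Scalar analogue of vqrdmulh.s32 / sqrdmulh: saturating, rounding,
// doubling multiply, returning the high 32 bits of the result.
std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, std::int32_t b) {
  if (a == std::numeric_limits<std::int32_t>::min() &&
      b == std::numeric_limits<std::int32_t>::min()) {
    return std::numeric_limits<std::int32_t>::max();  // the only saturating case
  }
  const std::int64_t ab = static_cast<std::int64_t>(a) * b;
  return static_cast<std::int32_t>((2 * ab + (std::int64_t{1} << 31)) >> 32);
}

// Rounding divide by 2^exponent with ties broken away from zero. The
// subtract-one-from-negative-values fixup mirrors the vand/vshr/vqadd
// sequence applied before vrshl in the asm above.
std::int32_t RoundingDivideByPOT(std::int32_t x, int exponent) {
  if (exponent == 0) return x;
  const std::int64_t fixed = static_cast<std::int64_t>(x) + (x < 0 ? -1 : 0);
  const std::int64_t rounding = std::int64_t{1} << (exponent - 1);
  return static_cast<std::int32_t>((fixed + rounding) >> exponent);  // round half up
}

// One destination value, uint8 case: apply the multiplier (fixed-point part
// and exponent part), add the destination zero point, then clamp.
std::uint8_t RequantizeToUint8(std::int32_t acc, std::int32_t mult_fixedpoint,
                               int mult_exponent, std::int32_t dst_zero_point,
                               std::int32_t clamp_min, std::int32_t clamp_max) {
  const int left_shift = mult_exponent > 0 ? mult_exponent : 0;
  const int right_shift = mult_exponent > 0 ? 0 : -mult_exponent;
  std::int32_t x = acc * (1 << left_shift);
  x = SaturatingRoundingDoublingHighMul(x, mult_fixedpoint);
  x = RoundingDivideByPOT(x, right_shift);
  x += dst_zero_point;
  x = std::min(std::max(x, clamp_min), clamp_max);
  return static_cast<std::uint8_t>(x);
}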
+
+#undef RUY_OFFSET_BIAS
+#undef RUY_OFFSET_LHS_SUMS
+#undef RUY_OFFSET_RHS_SUMS
+#undef RUY_OFFSET_LHS_BASE_PTR
+#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT
+#undef RUY_OFFSET_MULTIPLIER_EXPONENT
+#undef RUY_OFFSET_RHS_BASE_PTR
+#undef RUY_OFFSET_DST_BASE_PTR
+#undef RUY_OFFSET_LHS_ZERO_POINT
+#undef RUY_OFFSET_RHS_ZERO_POINT
+#undef RUY_OFFSET_DST_ZERO_POINT
+#undef RUY_OFFSET_PROD_ZP_DEPTH
+#undef RUY_OFFSET_START_ROW
+#undef RUY_OFFSET_START_COL
+#undef RUY_OFFSET_LAST_ROW
+#undef RUY_OFFSET_LAST_COL
+#undef RUY_OFFSET_DST_ROWS
+#undef RUY_OFFSET_DST_COLS
+#undef RUY_OFFSET_LHS_STRIDE
+#undef RUY_OFFSET_RHS_STRIDE
+#undef RUY_OFFSET_DST_STRIDE
+#undef RUY_OFFSET_DEPTH
+#undef RUY_OFFSET_CLAMP_MIN
+#undef RUY_OFFSET_CLAMP_MAX
+#undef RUY_OFFSET_FLAGS
+#undef RUY_OFFSET_DST_TYPE_ID
+
+#undef RUY_STACK_OFFSET_SIZE
+#undef RUY_STACK_OFFSET_DST_COL_PTR
+#undef RUY_STACK_OFFSET_DST_PTR
+#undef RUY_STACK_OFFSET_ROW
+#undef RUY_STACK_OFFSET_COL
+#undef RUY_STACK_OFFSET_LHS_COL_PTR
+#undef RUY_STACK_OFFSET_RHS_COL_PTR
+
+#endif  // RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+} // namespace ruy
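Both kernels in this file fold the zero-point corrections of equation (7) in https://arxiv.org/pdf/1712.05877.pdf into the int32 accumulators: they add bias and prod_zp_depth, then subtract rhs_sums * lhs_zero_point and lhs_sums * rhs_zero_point. The scalar C++ sketch below spells out that arithmetic for a single destination entry; the function name and signature are illustrative, since the real kernels receive prod_zp_depth and the row/column sums precomputed in KernelParams8bit.

#include <cstdint>

// One int32 accumulator before requantization, with the zero-point
// corrections written out explicitly.
std::int32_t AccumulateWithZeroPoints(const std::int8_t* lhs_row,  // depth values
                                      const std::int8_t* rhs_col,  // depth values
                                      int depth, std::int32_t bias,
                                      std::int32_t lhs_zero_point,
                                      std::int32_t rhs_zero_point) {
  std::int32_t raw = 0;      // sum_k lhs[k] * rhs[k], what the asm accumulates
  std::int32_t lhs_sum = 0;  // sum_k lhs[k], the "lhs_sums" entry for this row
  std::int32_t rhs_sum = 0;  // sum_k rhs[k], the "rhs_sums" entry for this col
  for (int k = 0; k < depth; ++k) {
    raw += static_cast<std::int32_t>(lhs_row[k]) * rhs_col[k];
    lhs_sum += lhs_row[k];
    rhs_sum += rhs_col[k];
  }
  // prod_zp_depth = depth * lhs_zero_point * rhs_zero_point (the NZ1Z2 term).
  const std::int32_t prod_zp_depth = depth * lhs_zero_point * rhs_zero_point;
  // Equation (7): (lhs - Z1) . (rhs - Z2)
  //             = raw - Z2 * lhs_sum - Z1 * rhs_sum + depth * Z1 * Z2.
  return bias + raw + prod_zp_depth - rhs_zero_point * lhs_sum -
         lhs_zero_point * rhs_sum;
}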
diff --git a/ruy/kernel_arm64.cc b/ruy/kernel_arm64.cc
new file mode 100644
index 0000000..38af032
--- /dev/null
+++ b/ruy/kernel_arm64.cc
@@ -0,0 +1,7835 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+
+#include "ruy/common.h"
+#include "ruy/kernel.h"
+#include "ruy/opt_set.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+#define RUY_ASM_LABEL_STORE_UINT8 91
+#define RUY_ASM_LABEL_STORE_INT8 92
+#define RUY_ASM_LABEL_STORE_INT16 93
+#define RUY_ASM_LABEL_STORE_INT32 94
+#define RUY_ASM_LABEL_AFTER_STORE 99
+
+#define RUY_OFFSET_BIAS 0
+#define RUY_OFFSET_LHS_SUMS 8
+#define RUY_OFFSET_RHS_SUMS 16
+#define RUY_OFFSET_LHS_BASE_PTR 24
+#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 32
+#define RUY_OFFSET_MULTIPLIER_EXPONENT 40
+#define RUY_OFFSET_RHS_BASE_PTR 48
+#define RUY_OFFSET_DST_BASE_PTR 56
+#define RUY_OFFSET_LHS_ZERO_POINT 64
+#define RUY_OFFSET_RHS_ZERO_POINT 68
+#define RUY_OFFSET_DST_ZERO_POINT 72
+#define RUY_OFFSET_PROD_ZP_DEPTH 76
+#define RUY_OFFSET_START_ROW 80
+#define RUY_OFFSET_START_COL 84
+#define RUY_OFFSET_LAST_ROW 88
+#define RUY_OFFSET_LAST_COL 92
+#define RUY_OFFSET_DST_ROWS 96
+#define RUY_OFFSET_DST_COLS 100
+#define RUY_OFFSET_LHS_STRIDE 104
+#define RUY_OFFSET_RHS_STRIDE 108
+#define RUY_OFFSET_DST_STRIDE 112
+#define RUY_OFFSET_DEPTH 116
+#define RUY_OFFSET_CLAMP_MIN 120
+#define RUY_OFFSET_CLAMP_MAX 124
+#define RUY_OFFSET_FLAGS 128
+
+template <typename Params>
+void CheckOffsetsInKernelParams8bit(const Params&) {
+ static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT,
+ "");
+ static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT,
+ "");
+ static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT,
+ "");
+ static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH,
+ "");
+ static_assert(offsetof(Params, multiplier_fixedpoint) ==
+ RUY_OFFSET_MULTIPLIER_FIXEDPOINT,
+ "");
+ static_assert(
+ offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT,
+ "");
+ static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
+ static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
+ static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
+ static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, "");
+ static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, "");
+ static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
+ static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
+ static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
+ static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
+ static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
+ static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
+ static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
+ static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
+ static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
+}
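The RUY_OFFSET_* values above exist because the inline asm addresses KernelParams8bit fields with hard-coded immediate byte offsets from the params pointer, and CheckOffsetsInKernelParams8bit keeps those immediates in sync with the actual struct layout. A minimal sketch of the same idiom, using a hypothetical mini struct and offset rather than ruy's definitions:

#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical mini-params struct illustrating the offset-checking idiom.
struct MiniParams {
  std::int32_t start_row;  // byte offset 0
  std::int32_t depth;      // byte offset 4
};
#define MINI_OFFSET_DEPTH 4
static_assert(offsetof(MiniParams, depth) == MINI_OFFSET_DEPTH,
              "hard-coded offset must match the struct layout");

// What the asm effectively does when it executes
//   "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]":
// read 4 bytes at (params base address + fixed byte offset).
std::int32_t LoadDepth(const MiniParams& params) {
  std::int32_t depth;
  std::memcpy(&depth,
              reinterpret_cast<const char*>(&params) + MINI_OFFSET_DEPTH,
              sizeof(depth));
  return depth;
}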
+
+// Fast-int8-trick kernel, similar to this production gemmlowp kernel:
+// NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits
+// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L2296
+//
+// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75,
+// since these are 64-bit, out-of-order and without dotprod support.
+void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params) {
+ profiler::ScopeLabel label(
+ "Kernel (kNeon, optimized for out-of-order cores)");
+
+ CheckOffsetsInKernelParams8bit(params);
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v31 are int32 accumulators.
+ // During accumulation, v0 -- v3 are used to load int8 data from LHS and
+ // v4 -- v7 from RHS:
+ //
+ // int8 RHS 16x4 block
+ // /-----------------------------------------\
+ // |v4.b[0] ... v7.b[0] |
+ // | ... ... |
+ // |v4.b[15] ... v7.b[15] |
+ // \-----------------------------------------/
+ // int8 LHS 4x16 block
+ // /---------------------\ /-----------------------------------------\
+ // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s |
+ // |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s |
+ // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s |
+ // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s |
+ // \---------------------/ \-----------------------------------------/
+ // int32 accumulators 4x4 block
+ //
+  // No attempt has been made so far at implementing the RUY_OPT_MAX_STREAMING
+  // optimization for this kernel.
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+ // Load the first 64 bytes of LHS and RHS data.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 16.
+ "mov w1, #16\n"
+
+ // Perform the first few multiply-adds on the data that we have already
+ // loaded.
+ "smull v8.8h, v0.8b, v4.8b\n"
+ "smull v9.8h, v1.8b, v4.8b\n"
+ "smull v10.8h, v2.8b, v4.8b\n"
+ "smull v11.8h, v3.8b, v4.8b\n"
+ "smull v12.8h, v0.8b, v5.8b\n"
+ "smull v13.8h, v1.8b, v5.8b\n"
+ "smull v14.8h, v2.8b, v5.8b\n"
+ "smull v15.8h, v3.8b, v5.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v4.16b\n"
+ "smlal2 v9.8h, v1.16b, v4.16b\n"
+ "smlal2 v10.8h, v2.16b, v4.16b\n"
+ "smlal2 v11.8h, v3.16b, v4.16b\n"
+ "smlal2 v12.8h, v0.16b, v5.16b\n"
+ "smlal2 v13.8h, v1.16b, v5.16b\n"
+ "smlal2 v14.8h, v2.16b, v5.16b\n"
+ "smlal2 v15.8h, v3.16b, v5.16b\n"
+
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ // Reminder - w1 is how many levels of depth we have already loaded
+ // data for, w12 is the total depth.
+ "cmp w1, w12\n"
+ "beq 79f\n"
+
+ "2:\n"
+
+ // Some multiplications and 16-bit accumulation were already done above,
+ // so we start right away in the middle.
+ "sadalp v16.4s, v8.8h\n"
+ "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
+ "smull v8.8h, v0.8b, v6.8b\n"
+ "sadalp v17.4s, v9.8h\n"
+ "ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
+ "smull v9.8h, v1.8b, v6.8b\n"
+ "sadalp v18.4s, v10.8h\n"
+ "smull v10.8h, v2.8b, v6.8b\n"
+ "sadalp v19.4s, v11.8h\n"
+ "smull v11.8h, v3.8b, v6.8b\n"
+ "sadalp v20.4s, v12.8h\n"
+ "smull v12.8h, v0.8b, v7.8b\n"
+ "sadalp v21.4s, v13.8h\n"
+ "smull v13.8h, v1.8b, v7.8b\n"
+ "sadalp v22.4s, v14.8h\n"
+ "smull v14.8h, v2.8b, v7.8b\n"
+ "sadalp v23.4s, v15.8h\n"
+ "smull v15.8h, v3.8b, v7.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v6.16b\n"
+ "smlal2 v9.8h, v1.16b, v6.16b\n"
+ "smlal2 v10.8h, v2.16b, v6.16b\n"
+ "smlal2 v11.8h, v3.16b, v6.16b\n"
+
+ "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
+
+ "smlal2 v12.8h, v0.16b, v7.16b\n"
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "smlal2 v13.8h, v1.16b, v7.16b\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "smlal2 v14.8h, v2.16b, v7.16b\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "smlal2 v15.8h, v3.16b, v7.16b\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+
+ "sadalp v24.4s, v8.8h\n"
+ "smull v8.8h, v0.8b, v4.8b\n"
+ "sadalp v25.4s, v9.8h\n"
+ "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
+ "smull v9.8h, v1.8b, v4.8b\n"
+ "sadalp v26.4s, v10.8h\n"
+ "smull v10.8h, v2.8b, v4.8b\n"
+ "sadalp v27.4s, v11.8h\n"
+ "smull v11.8h, v3.8b, v4.8b\n"
+ "sadalp v28.4s, v12.8h\n"
+ "smull v12.8h, v0.8b, v5.8b\n"
+ "sadalp v29.4s, v13.8h\n"
+ "smull v13.8h, v1.8b, v5.8b\n"
+ "sadalp v30.4s, v14.8h\n"
+ "smull v14.8h, v2.8b, v5.8b\n"
+ "sadalp v31.4s, v15.8h\n"
+ "smull v15.8h, v3.8b, v5.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v4.16b\n"
+ "smlal2 v9.8h, v1.16b, v4.16b\n"
+ "smlal2 v10.8h, v2.16b, v4.16b\n"
+ "smlal2 v11.8h, v3.16b, v4.16b\n"
+
+ "smlal2 v12.8h, v0.16b, v5.16b\n"
+ "smlal2 v13.8h, v1.16b, v5.16b\n"
+ "smlal2 v14.8h, v2.16b, v5.16b\n"
+ "smlal2 v15.8h, v3.16b, v5.16b\n"
+
+
+
+ // Each iteration of this loop advances by 16 levels of depth.
+ "add w1, w1, #16\n"
+
+ // Loop termination condition
+ "cmp w1, w12\n"
+
+ "blt 2b\n"
+
+ "79:\n"
+
+ "sadalp v16.4s, v8.8h\n"
+ "smull v8.8h, v0.8b, v6.8b\n"
+ "sadalp v17.4s, v9.8h\n"
+ "smull v9.8h, v1.8b, v6.8b\n"
+ "sadalp v18.4s, v10.8h\n"
+ "smull v10.8h, v2.8b, v6.8b\n"
+ "sadalp v19.4s, v11.8h\n"
+ "smull v11.8h, v3.8b, v6.8b\n"
+ "sadalp v20.4s, v12.8h\n"
+ "smull v12.8h, v0.8b, v7.8b\n"
+ "sadalp v21.4s, v13.8h\n"
+ "smull v13.8h, v1.8b, v7.8b\n"
+ "sadalp v22.4s, v14.8h\n"
+ "smull v14.8h, v2.8b, v7.8b\n"
+ "sadalp v23.4s, v15.8h\n"
+ "smull v15.8h, v3.8b, v7.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v6.16b\n"
+ "smlal2 v9.8h, v1.16b, v6.16b\n"
+ "smlal2 v10.8h, v2.16b, v6.16b\n"
+ "smlal2 v11.8h, v3.16b, v6.16b\n"
+
+ "smlal2 v12.8h, v0.16b, v7.16b\n"
+ "smlal2 v13.8h, v1.16b, v7.16b\n"
+ "smlal2 v14.8h, v2.16b, v7.16b\n"
+ "smlal2 v15.8h, v3.16b, v7.16b\n"
+
+ "sadalp v24.4s, v8.8h\n"
+ "sadalp v25.4s, v9.8h\n"
+ "sadalp v26.4s, v10.8h\n"
+ "sadalp v27.4s, v11.8h\n"
+ "sadalp v28.4s, v12.8h\n"
+ "sadalp v29.4s, v13.8h\n"
+ "sadalp v30.4s, v14.8h\n"
+ "sadalp v31.4s, v15.8h\n"
+
+ // End of accumulation. The registers v16 -- v31 contain the final
+ // int32 accumulator values of the current 4x4 destination block.
+ // We now have to compute the final 8-bit values from these int32
+ // accumulators, and advance to the next 4x4 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // Reduce 32bit accumulators horizontally.
+ "addp v16.4s, v16.4s, v17.4s\n"
+ "addp v18.4s, v18.4s, v19.4s\n"
+ "addp v20.4s, v20.4s, v21.4s\n"
+ "addp v22.4s, v22.4s, v23.4s\n"
+ "addp v24.4s, v24.4s, v25.4s\n"
+ "addp v26.4s, v26.4s, v27.4s\n"
+ "addp v28.4s, v28.4s, v29.4s\n"
+ "addp v30.4s, v30.4s, v31.4s\n"
+
+ // Reduce 32bit accumulators horizontally, second pass
+        // (each pass adds pairwise; we need to add 4-wise).
+ "addp v16.4s, v16.4s, v18.4s\n"
+ "addp v17.4s, v20.4s, v22.4s\n"
+ "addp v18.4s, v24.4s, v26.4s\n"
+ "addp v19.4s, v28.4s, v30.4s\n"
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "cmp %w[row], w7\n" // Have we finished the last row?
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "mov %[lhs_col_ptr], x5\n" // Go back to first row
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "cmp %w[col], w8\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+ "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+
+ // Load some parameters needed for the end work on current block.
+ RUY_MAKE_ZERO(v8)
+ "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
+ "ins v13.h[4], w4\n" // dst_zero_point
+ "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
+ "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "dup v9.4s, w3\n" // create prod_zp_depth_vec
+ "add x5, x4, %x[row], lsl #2\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "csel x4, x4, x5, eq\n"
+
+ "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
+
+ // Now we load: bias data, LHS sums data, RHS sums data.
+
+ // First, load the base pointers from the params.
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+
+ "add x5, x1, %x[row], lsl #2\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "csel x1, x1, x5, eq\n"
+
+ // Load 4 bias values.
+ "ld1 {v14.4s}, [x1]\n"
+
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 32 bytes of
+ // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
+ // in the rest of the work on the current block.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
+
+        // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point).
+ // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "add v14.4s, v14.4s, v9.4s\n"
+
+ // Perform the bias-addition (per the above, we have just folded into
+ // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+ "add v16.4s, v16.4s, v14.4s\n"
+ "add v17.4s, v17.4s, v14.4s\n"
+ "add v18.4s, v18.4s, v14.4s\n"
+ "add v19.4s, v19.4s, v14.4s\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
+ "beq 401f\n"
+ "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
+ "add x3, x3, %x[col], lsl #2\n"
+ "ld1 {v14.4s}, [x3]\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
+ "dup v10.4s, w5\n" // create lhs_zero_point_vec
+ // Subtract rhs_sums * lhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "mls v16.4s, v10.4s, v14.s[0]\n"
+ "mls v17.4s, v10.4s, v14.s[1]\n"
+ "mls v18.4s, v10.4s, v14.s[2]\n"
+ "mls v19.4s, v10.4s, v14.s[3]\n"
+ "401:\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
+ "beq 402f\n"
+ "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
+ "add x2, x2, %x[row], lsl #2\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
+ // Load 4 lhs_sums values.
+ "ld1 {v11.4s}, [x2]\n"
+ "ins v13.s[1], w5\n" // rhs_zero_point
+ // Compute lhs_sums * rhs_zero_point.
+ "mul v11.4s, v11.4s, v13.s[1]\n"
+ // Subtract lhs_sums * rhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "sub v16.4s, v16.4s, v11.4s\n"
+ "sub v17.4s, v17.4s, v11.4s\n"
+ "sub v18.4s, v18.4s, v11.4s\n"
+ "sub v19.4s, v19.4s, v11.4s\n"
+
+ // If the destination is int32, it means the user asks for the raw
+ // accumulators, no need for us to downquantize the value.
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
+
+ "402:\n"
+
+ // At this point we have computed the final int32 values. Now we
+ // start down-quantizing them to obtain the final 8bit values from them.
+
+ // As part of this down-quantization, our int32 values will be
+ // multiplied by a multiplier that has a fixed-point component and an
+ // exponent component.
+
+        // Load the exponent part of the multiplier.
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "add x5, x1, %x[row], lsl #2\n"
+ "csel x1, x1, x5, eq\n"
+
+ "ld1 {v14.4s}, [x1]\n"
+
+ "smax v12.4s, v14.4s, v8.4s\n"
+
+ "sshl v16.4s, v16.4s, v12.4s\n"
+ "sshl v17.4s, v17.4s, v12.4s\n"
+ "sshl v18.4s, v18.4s, v12.4s\n"
+ "sshl v19.4s, v19.4s, v12.4s\n"
+
+ "smin v12.4s, v14.4s, v8.4s\n"
+
+ // Apply the fixed-point part of the multiplier.
+ "sqrdmulh v16.4s, v16.4s, v15.4s\n"
+ "sqrdmulh v17.4s, v17.4s, v15.4s\n"
+ "sqrdmulh v18.4s, v18.4s, v15.4s\n"
+ "sqrdmulh v19.4s, v19.4s, v15.4s\n"
+
+ // We have some rounding division-by-power-of-two to do. This should
+ // always use "round to nearest". We allow for some
+ // freedom in how ties are broken, to strike a good compromise of
+ // performance on given hardware vs. perfect agreement of results
+ // across hardware.
+ //
+ // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation
+ // defined tie-breaks to help performance. On NEON, this means that we
+ // can just use the NEON rounding instructions, such as srshl. They
+ // happen to be breaking ties upward.
+ //
+ // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
+ // break-ties-away-from zero, as described in Appendix B of
+ // https://arxiv.org/pdf/1712.05877.pdf
+        // When we wrote that, we thought that it would be less biased than
+        // the NEON upward tie-breaks, and we had observed some improvement
+        // on some model. However, it is only less biased for data centered
+        // at zero, which was likely the case in that model, but is not
+        // always the case. If we wanted something more consistently
+        // unbiased, then we should try breaking ties toward nearest even.
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
+ // Fix up values to be right-shifted, so that the (round to nearest,
+ // break ties upward) behavior of srshl applied to these fixed-up
+ // values, produces the same result as the desired (round to nearest,
+ // break ties away from zero) behavior on the original values.
+ "and v8.16b, v16.16b, v12.16b\n"
+ "and v9.16b, v17.16b, v12.16b\n"
+ "and v14.16b, v18.16b, v12.16b\n"
+ "and v15.16b, v19.16b, v12.16b\n"
+ "sshr v8.4s, v8.4s, #31\n"
+ "sshr v9.4s, v9.4s, #31\n"
+ "sshr v14.4s, v14.4s, #31\n"
+ "sshr v15.4s, v15.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v8.4s\n"
+ "sqadd v17.4s, v17.4s, v9.4s\n"
+ "sqadd v18.4s, v18.4s, v14.4s\n"
+ "sqadd v19.4s, v19.4s, v15.4s\n"
+#endif
+ // At this point we have reduced the problem of correctly implementing
+ // rounding divide-by-power-of-two, to what the SRSHL instruction can
+ // do.
+ "srshl v16.4s, v16.4s, v12.4s\n"
+ "srshl v17.4s, v17.4s, v12.4s\n"
+ "srshl v18.4s, v18.4s, v12.4s\n"
+ "srshl v19.4s, v19.4s, v12.4s\n"
+
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "sqxtn2 v17.8h, v19.4s\n"
+
+ // At this point, v18 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+ "add v17.8h, v17.8h, v14.8h\n"
+
+ // Cast-and-saturate from int16 to uint8
+ "sqxtun v16.8b, v16.8h\n"
+ "sqxtun2 v16.16b, v17.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "umax v16.16b, v16.16b, v14.16b\n"
+ // Apply the clamp_max bound
+ "umin v16.16b, v16.16b, v15.16b\n"
+
+        // Compute how much of the 4x4 block of destination 8-bit values
+        // that we have computed fits in the destination matrix. Typically,
+        // all of it fits, but when the destination matrix shape is not a
+        // multiple of 4x4, there are some 4x4 blocks along the boundaries
+        // that do not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x4 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+ // Compute w2 = how many cols of the 4x4 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x4 block fits.
+ // Store to dst_tmp_buf
+ "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #4\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x4 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[0], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[1], [x3], #1\n"
+ "st1 {v16.b}[2], [x3], #1\n"
+ "st1 {v16.b}[3], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[4], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[5], [x3], #1\n"
+ "st1 {v16.b}[6], [x3], #1\n"
+ "st1 {v16.b}[7], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[8], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[9], [x3], #1\n"
+ "st1 {v16.b}[10], [x3], #1\n"
+ "st1 {v16.b}[11], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[12], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[13], [x3], #1\n"
+ "st1 {v16.b}[14], [x3], #1\n"
+ "st1 {v16.b}[15], [x3], #1\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #4\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "sqxtn2 v17.8h, v19.4s\n"
+
+ // At this point, v18 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+ "add v17.8h, v17.8h, v14.8h\n"
+
+ // Cast-and-saturate from int16 to int8
+ "sqxtn v16.8b, v16.8h\n"
+ "sqxtn2 v16.16b, v17.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.16b, v16.16b, v14.16b\n"
+ // Apply the clamp_max bound
+ "smin v16.16b, v16.16b, v15.16b\n"
+
+ // Compute how much of the 4x4 block of destination 8bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x4, there are some 4x4 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x4 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+ // Compute w2 = how many cols of the 4x4 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x4 block fits.
+ // Store to dst_tmp_buf
+ "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #4\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x4 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[0], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[1], [x3], #1\n"
+ "st1 {v16.b}[2], [x3], #1\n"
+ "st1 {v16.b}[3], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[4], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[5], [x3], #1\n"
+ "st1 {v16.b}[6], [x3], #1\n"
+ "st1 {v16.b}[7], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[8], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[9], [x3], #1\n"
+ "st1 {v16.b}[10], [x3], #1\n"
+ "st1 {v16.b}[11], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[12], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[13], [x3], #1\n"
+ "st1 {v16.b}[14], [x3], #1\n"
+ "st1 {v16.b}[15], [x3], #1\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #4\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
+
+ // Add the destination zero point
+ "dup v14.4h, v13.h[4]\n"
+ "saddw v16.4s, v16.4s, v14.4h\n"
+ "saddw v17.4s, v17.4s, v14.4h\n"
+ "saddw v18.4s, v18.4s, v14.4h\n"
+ "saddw v19.4s, v19.4s, v14.4h\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "sqxtn2 v17.8h, v19.4s\n"
+
+ // At this point, v18 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.8h, w2\n" // clamp_min
+ "dup v15.8h, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.8h, v16.8h, v14.8h\n"
+ "smax v17.8h, v17.8h, v14.8h\n"
+ // Apply the clamp_max bound
+ "smin v16.8h, v16.8h, v15.8h\n"
+ "smin v17.8h, v17.8h, v15.8h\n"
+
+ // Compute how much of the 4x4 block of destination 16bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x4, there are some 4x4 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x4 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+ // Compute w2 = how many cols of the 4x4 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x4 block fits.
+ // Store to dst_tmp_buf
+ "str q16, [%[dst_tmp_buf], #0]\n"
+ "str q17, [%[dst_tmp_buf], #16]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrh w7, [x3, x5, lsl #1]\n"
+ "strh w7, [x4, x5, lsl #1]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #8\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x4 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.h}[0], [x3], #2\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.h}[1], [x3], #2\n"
+ "st1 {v16.h}[2], [x3], #2\n"
+ "st1 {v16.h}[3], [x3], #2\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.h}[4], [x3], #2\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.h}[5], [x3], #2\n"
+ "st1 {v16.h}[6], [x3], #2\n"
+ "st1 {v16.h}[7], [x3], #2\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v17.h}[0], [x3], #2\n"
+ "add x4, x4, x11\n"
+ "st1 {v17.h}[1], [x3], #2\n"
+ "st1 {v17.h}[2], [x3], #2\n"
+ "st1 {v17.h}[3], [x3], #2\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v17.h}[4], [x3], #2\n"
+ "add x4, x4, x11\n"
+ "st1 {v17.h}[5], [x3], #2\n"
+ "st1 {v17.h}[6], [x3], #2\n"
+ "st1 {v17.h}[7], [x3], #2\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+ // Since the store type is the same as the accum type, no need for
+ // downcast. There's also no need for clamp by min/max.
+
+ // At this point, v20 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Compute how much of the 4x4 block of destination 32bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x4, there are some 4x4 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x4 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+ // Compute w2 = how many cols of the 4x4 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x4 block fits.
+ // Store to dst_tmp_buf
+ "str q16, [%[dst_tmp_buf], #0]\n"
+ "str q17, [%[dst_tmp_buf], #16]\n"
+ "str q18, [%[dst_tmp_buf], #32]\n"
+ "str q19, [%[dst_tmp_buf], #48]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldr w7, [x3, x5, lsl #2]\n"
+ "str w7, [x4, x5, lsl #2]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #16\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x4 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.s}[0], [x3], #4\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.s}[1], [x3], #4\n"
+ "st1 {v16.s}[2], [x3], #4\n"
+ "st1 {v16.s}[3], [x3], #4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v17.s}[0], [x3], #4\n"
+ "add x4, x4, x11\n"
+ "st1 {v17.s}[1], [x3], #4\n"
+ "st1 {v17.s}[2], [x3], #4\n"
+ "st1 {v17.s}[3], [x3], #4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v18.s}[0], [x3], #4\n"
+ "add x4, x4, x11\n"
+ "st1 {v18.s}[1], [x3], #4\n"
+ "st1 {v18.s}[2], [x3], #4\n"
+ "st1 {v18.s}[3], [x3], #4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v19.s}[0], [x3], #4\n"
+ "add x4, x4, x11\n"
+ "st1 {v19.s}[1], [x3], #4\n"
+ "st1 {v19.s}[2], [x3], #4\n"
+ "st1 {v19.s}[3], [x3], #4\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #16\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+
+ RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ "smull v8.8h, v0.8b, v4.8b\n"
+ "smull v9.8h, v1.8b, v4.8b\n"
+ "smull v10.8h, v2.8b, v4.8b\n"
+ "smull v11.8h, v3.8b, v4.8b\n"
+ "smull v12.8h, v0.8b, v5.8b\n"
+ "smull v13.8h, v1.8b, v5.8b\n"
+ "smull v14.8h, v2.8b, v5.8b\n"
+ "smull v15.8h, v3.8b, v5.8b\n"
+ "smlal2 v8.8h, v0.16b, v4.16b\n"
+ "smlal2 v9.8h, v1.16b, v4.16b\n"
+ "smlal2 v10.8h, v2.16b, v4.16b\n"
+ "smlal2 v11.8h, v3.16b, v4.16b\n"
+ "smlal2 v12.8h, v0.16b, v5.16b\n"
+ "smlal2 v13.8h, v1.16b, v5.16b\n"
+ "smlal2 v14.8h, v2.16b, v5.16b\n"
+ "smlal2 v15.8h, v3.16b, v5.16b\n"
+
+ // Reload some params --- we had used x5 -- x7 for a few other things
+ // since the last time we had loaded them.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "cmp %w[row], w7\n"
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add %w[row], %w[row], #4\n"
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ "mov %w[row], w6\n" // Move back to first row.
+ "add %w[col], %w[col], #4\n" // Move to the next column.
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
+ "mov %[dst_ptr], %[dst_col_ptr]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "cmp %w[col], w8\n"
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 16.
+ "mov w1, #16\n"
+
+ "ble 1b\n"
+
+ // clang-format on
+
+ : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
+ [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
+ [dst_type_id] "r"(params.dst_type_id)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+ "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
+ "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
+// Similar to existing Kernel8bitNeonOutOfOrder but specialized for the case of
+// RHS cols == 1.
+// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75,
+// since these are 64-bit, out-of-order and without dotprod support.
+void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 4>& params) {
+ profiler::ScopeLabel label(
+ "Kernel (kNeon, optimized for out-of-order cores)");
+
+ CheckOffsetsInKernelParams8bit(params);
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v19 are int32 accumulators.
+ // During accumulation, v0 -- v3 are used to load int8 data from LHS and
+ // v4 from RHS:
+ //
+ // int8 RHS 16x1 block
+ // /-----------\
+ // |v4.b[0] |
+ // | ... |
+ // |v4.b[15] |
+ // \-----------/
+ // int8 LHS 4x16 block
+ // /---------------------\ /-----------\
+ // |v0.b[0] ... v0.b[15] | |v16.4s |
+ // |v1.b[0] ... v1.b[15] | |v17.4s |
+ // |v2.b[0] ... v2.b[15] | |v18.4s |
+ // |v3.b[0] ... v3.b[15] | |v19.4s |
+ // \---------------------/ \-----------/
+ // int32 accumulators 4x1 block
+ //
+ // No attempt has been made so far at implementing the RUY_OPT_MAX_STREAMING
+ // optimization for this kernel.
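+ //
+ // For reference (exposition only), before the quantized post-processing the
+ // 4x1 block computed here is, in scalar terms,
+ //   acc[r] = sum over d in [0, depth) of int32(lhs[r][d]) * int32(rhs[d])
+ // for r in 0..3; the asm below covers 16 levels of depth per iteration,
+ // pairing two int8 products into int16 (smull/smlal2) and then widening
+ // into the int32 accumulators v16 -- v19 (sadalp).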
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+ // Load the first 64 bytes of LHS data and the first 16 bytes of RHS data.
+ // (The RHS is packed in blocks of 4 columns, so the remaining 48 bytes are
+ // skipped just below.)
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
+ "add %[rhs_ptr], %[rhs_ptr], #48\n"
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 16.
+ "mov w1, #16\n"
+
+ // Perform the first few multiply-adds on the data that we have already
+ // loaded.
+ "smull v8.8h, v0.8b, v4.8b\n"
+ "smull v9.8h, v1.8b, v4.8b\n"
+ "smull v10.8h, v2.8b, v4.8b\n"
+ "smull v11.8h, v3.8b, v4.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v4.16b\n"
+ "smlal2 v9.8h, v1.16b, v4.16b\n"
+ "smlal2 v10.8h, v2.16b, v4.16b\n"
+ "smlal2 v11.8h, v3.16b, v4.16b\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ // Reminder - w1 is how many levels of depth we have already loaded
+ // data for, w12 is the total depth.
+ "cmp w1, w12\n"
+ "beq 79f\n"
+
+ "2:\n"
+
+ // Some multiplications and 16-bit accumulation were already done above,
+ // so we start right away in the middle.
+ "sadalp v16.4s, v8.8h\n"
+ "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
+ "add %[rhs_ptr], %[rhs_ptr], #48\n"
+ "sadalp v17.4s, v9.8h\n"
+ "sadalp v18.4s, v10.8h\n"
+ "sadalp v19.4s, v11.8h\n"
+
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+
+ "smull v8.8h, v0.8b, v4.8b\n"
+ "smull v9.8h, v1.8b, v4.8b\n"
+ "smull v10.8h, v2.8b, v4.8b\n"
+ "smull v11.8h, v3.8b, v4.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v4.16b\n"
+ "smlal2 v9.8h, v1.16b, v4.16b\n"
+ "smlal2 v10.8h, v2.16b, v4.16b\n"
+ "smlal2 v11.8h, v3.16b, v4.16b\n"
+
+ // Each iteration of this loop advances by 16 levels of depth.
+ "add w1, w1, #16\n"
+
+ // Loop termination condition
+ "cmp w1, w12\n"
+
+ "blt 2b\n"
+
+ "79:\n"
+
+ "sadalp v16.4s, v8.8h\n"
+ "sadalp v17.4s, v9.8h\n"
+ "sadalp v18.4s, v10.8h\n"
+ "sadalp v19.4s, v11.8h\n"
+
+ // End of accumulation. The registers v16 -- v19 contain the final
+ // int32 accumulator values of the current 4x1 destination block.
+ // We now have to compute the final 8-bit values from these int32
+ // accumulators, and advance to the next 4x1 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // Reduce 32bit accumulators horizontally.
+ "addp v16.4s, v16.4s, v17.4s\n"
+ "addp v18.4s, v18.4s, v19.4s\n"
+
+ // Reduce 32bit accumulators horizontally, second pass
+ // (each pass adds pairwise. we need to add 4-wise).
+ "addp v16.4s, v16.4s, v18.4s\n"
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "cmp %w[row], w7\n" // Have we finished the last row?
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "mov %[lhs_col_ptr], x5\n" // Go back to first row
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "cmp %w[col], w8\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ // (still multiply column stride by 4 due to packing)
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+ "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+
+ // Load some parameters needed for the end work on current block.
+ RUY_MAKE_ZERO(v8)
+ "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
+ "ins v13.h[4], w4\n" // dst_zero_point
+ "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
+ "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "dup v9.4s, w3\n" // create prod_zp_depth_vec
+ "add x5, x4, %x[row], lsl #2\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "csel x4, x4, x5, eq\n"
+
+ "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
+
+ // Now we load: bias data, LHS sums data, RHS sums data.
+
+ // First, load the base pointers from the params.
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+
+ "add x5, x1, %x[row], lsl #2\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "csel x1, x1, x5, eq\n"
+
+ // Load 4 bias values.
+ "ld1 {v14.4s}, [x1]\n"
+
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 64 bytes of
+ // LHS into v0 -- v3 and the first 16 bytes of RHS into v4, as we don't
+ // need their current contents anymore in the rest of the work on the
+ // current block.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
+ "add %[rhs_ptr], %[rhs_ptr], #48\n"
+
+ // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
+ // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "add v14.4s, v14.4s, v9.4s\n"
+
+ // Perform the bias-addition (per the above, we have just folded into
+ // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+ // (all four 32-bit accumulators are in v16 at this point)
+ "add v16.4s, v16.4s, v14.4s\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
+ "beq 401f\n"
+ "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
+ "add x3, x3, %x[col], lsl #2\n"
+ "ld1 {v14.4s}, [x3]\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
+ "dup v10.4s, w5\n" // create lhs_zero_point_vec
+ // Subtract rhs_sums * lhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "mls v16.4s, v10.4s, v14.s[0]\n"
+ "401:\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
+ "beq 402f\n"
+ "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
+ "add x2, x2, %x[row], lsl #2\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
+ // Load 4 lhs_sums values.
+ "ld1 {v11.4s}, [x2]\n"
+ "ins v13.s[1], w5\n" // rhs_zero_point
+ // Compute lhs_sums * rhs_zero_point.
+ "mul v11.4s, v11.4s, v13.s[1]\n"
+ // Subtract lhs_sums * rhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "sub v16.4s, v16.4s, v11.4s\n"
+
+ // If the destination is int32, it means the user asks for the raw
+ // accumulators, no need for us to downquantize the value.
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
+
+ "402:\n"
+
+ // At this point we have computed the final int32 values. Now we
+ // start down-quantizing them to obtain the final 8bit values from them.
+
+ // As part of this down-quantization, our int32 values will be
+ // multiplied by a multiplier that has a fixed-point component and an
+ // exponent component.
+
+ // Load the exponent part of the multiplier.
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "add x5, x1, %x[row], lsl #2\n"
+ "csel x1, x1, x5, eq\n"
+
+ "ld1 {v14.4s}, [x1]\n"
+
+ "smax v12.4s, v14.4s, v8.4s\n"
+
+ "sshl v16.4s, v16.4s, v12.4s\n"
+
+ "smin v12.4s, v14.4s, v8.4s\n"
+
+ // Apply the fixed-point part of the multiplier.
+ "sqrdmulh v16.4s, v16.4s, v15.4s\n"
+
+ // We have some rounding division-by-power-of-two to do. This should
+ // always use "round to nearest". We allow for some
+ // freedom in how ties are broken, to strike a good compromise of
+ // performance on given hardware vs. perfect agreement of results
+ // across hardware.
+ //
+ // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation
+ // defined tie-breaks to help performance. On NEON, this means that we
+ // can just use the NEON rounding instructions, such as srshl. They
+ // happen to be breaking ties upward.
+ //
+ // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
+ // break-ties-away-from zero, as described in Appendix B of
+ // https://arxiv.org/pdf/1712.05877.pdf
+ // When we wrote that, we thought that it would be less biased
+ // than the NEON upwards tie-breaks, and we had observed some
+ // improvement on some model. However, it is only less biased for
+ // data centered at zero, which was likely the case in that model,
+ // but is not always the case. If we wanted something more consistently
+ // unbiased, we should try breaking ties to nearest even.
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
+ // Fix up values to be right-shifted, so that the (round to nearest,
+ // break ties upward) behavior of srshl applied to these fixed-up
+ // values, produces the same result as the desired (round to nearest,
+ // break ties away from zero) behavior on the original values.
+ "and v8.16b, v16.16b, v12.16b\n"
+ "sshr v8.4s, v8.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v8.4s\n"
+#endif
+ // At this point we have reduced the problem of correctly implementing
+ // rounding divide-by-power-of-two, to what the SRSHL instruction can
+ // do.
+ "srshl v16.4s, v16.4s, v12.4s\n"
+
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ // After this instruction, all data is in lower half (64-bits) of v16
+ "sqxtn v16.4h, v16.4s\n"
+
+ // At this point, v18 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+
+ // Cast-and-saturate from int16 to uint8
+ // Now all data is in the first 32-bits of v16
+ "sqxtun v16.8b, v16.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "umax v16.16b, v16.16b, v14.16b\n"
+ // Apply the clamp_max bound
+ "umin v16.16b, v16.16b, v15.16b\n"
+
+ // Compute how much of the 4x1 block of destination 8bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x1, there are some 4x1 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x1 block fit
+ "csel w1, w1, w3, le\n"
+
+ // Test if w1==4, i.e. if all of the 4x1 block fits.
+ "cmp w1, w3\n"
+
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x1 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x1 block fits.
+ // Store to dst_tmp_buf
+ "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x1 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[0], [x3], #1\n"
+ "st1 {v16.b}[1], [x3], #1\n"
+ "st1 {v16.b}[2], [x3], #1\n"
+ "st1 {v16.b}[3], [x3], #1\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #4\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ // After this, all values for output are in the lower half (64 bits) of v16.
+ "sqxtn v16.4h, v16.4s\n"
+
+ // At this point, v18 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+
+ // Cast-and-saturate from int16 to int8
+ "sqxtn v16.8b, v16.8h\n"
+ // At this point, we only need 4 lowest 8-bit values in v16.
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.16b, v16.16b, v14.16b\n"
+ // Apply the clamp_max bound
+ "smin v16.16b, v16.16b, v15.16b\n"
+
+ // Compute how much of the 4x1 block of destination 8bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x1, there are some 4x1 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x1 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+
+ // Test if w1==4, i.e. if all of the 4x1 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x1 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x1 block fits.
+ // Store to dst_tmp_buf
+ "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x1 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[0], [x3], #1\n"
+ "st1 {v16.b}[1], [x3], #1\n"
+ "st1 {v16.b}[2], [x3], #1\n"
+ "st1 {v16.b}[3], [x3], #1\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #4\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
+
+ // Add the destination zero point
+ "dup v14.4h, v13.h[4]\n"
+ "saddw v16.4s, v16.4s, v14.4h\n"
+
+ // Cast-and-saturate from int32 to int16
+ // After this instruction, all data is in lower half of v16.
+ "sqxtn v16.4h, v16.4s\n"
+
+ // At this point, v18 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.8h, w2\n" // clamp_min
+ "dup v15.8h, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.8h, v16.8h, v14.8h\n"
+ // Apply the clamp_max bound
+ "smin v16.8h, v16.8h, v15.8h\n"
+
+ // Compute how much of the 4x1 block of destination 16bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x1, there are some 4x1 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x1 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+
+ // Test if w1==4, i.e. if all of the 4x1 block fits.
+ "cmp w1, w3\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x1 block fits.
+ // Store to dst_tmp_buf
+ "str q16, [%[dst_tmp_buf], #0]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrh w7, [x3, x5, lsl #1]\n"
+ "strh w7, [x4, x5, lsl #1]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x1 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.h}[0], [x3], #2\n"
+ "st1 {v16.h}[1], [x3], #2\n"
+ "st1 {v16.h}[2], [x3], #2\n"
+ "st1 {v16.h}[3], [x3], #2\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+ // Since the store type is the same as the accum type, no need for
+ // downcast. There's also no need for clamp by min/max.
+
+ // Compute how much of the 4x1 block of destination 32bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x1, there are some 4x1 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x1 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+
+ // Test if w1==4 i.e. if all of the 4x1 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x1 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x1 block fits.
+ // Store to dst_tmp_buf
+ "str q16, [%[dst_tmp_buf], #0]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldr w7, [x3, x5, lsl #2]\n"
+ "str w7, [x4, x5, lsl #2]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x1 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.s}[0], [x3], #4\n"
+ "st1 {v16.s}[1], [x3], #4\n"
+ "st1 {v16.s}[2], [x3], #4\n"
+ "st1 {v16.s}[3], [x3], #4\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #16\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+
+ RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ "smull v8.8h, v0.8b, v4.8b\n"
+ "smull v9.8h, v1.8b, v4.8b\n"
+ "smull v10.8h, v2.8b, v4.8b\n"
+ "smull v11.8h, v3.8b, v4.8b\n"
+ "smlal2 v8.8h, v0.16b, v4.16b\n"
+ "smlal2 v9.8h, v1.16b, v4.16b\n"
+ "smlal2 v10.8h, v2.16b, v4.16b\n"
+ "smlal2 v11.8h, v3.16b, v4.16b\n"
+
+ // Reload some params --- we had used x5 -- x7 for a few other things
+ // since the last time we had loaded them.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "cmp %w[row], w7\n"
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add %w[row], %w[row], #4\n"
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ "mov %w[row], w6\n" // Move back to first row.
+ "add %w[col], %w[col], #4\n" // Move to the next column.
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
+ "mov %[dst_ptr], %[dst_col_ptr]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "cmp %w[col], w8\n"
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 16.
+ "mov w1, #16\n"
+
+ "ble 1b\n"
+
+ // clang-format on
+
+ : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
+ [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
+ [dst_type_id] "r"(params.dst_type_id)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+ "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19");
+}
+
+// Variant of the above Kernel8bitNeonOutOfOrder, tuned for in-order CPUs.
+// Specifically here, the relevant in-order CPUs are ARM Cortex-A53 and
+// the original Cortex-A55, since these are 64-bit and do not support dotprod.
+//
+// While this kernel does not have a direct equivalent in gemmlowp, it was
+// developed based on insights that David Mansell at ARM shared with their
+// contribution of gemmlowp kernels tuned for Cortex-A53, with very helpful
+// comments. Specifically, see this comment about tuning for Cortex-A53:
+// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215
+void Kernel8bitNeonInOrder(const KernelParams8bit<4, 4>& params) {
+ profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)");
+
+ CheckOffsetsInKernelParams8bit(params);
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v31 are int32 accumulators.
+ // During accumulation, v0 -- v3 are used to load int8 data from LHS and
+ // v4 -- v7 from RHS:
+ //
+ // int8 RHS 16x4 block
+ // /-----------------------------------------\
+ // |v4.b[0] ... v7.b[0] |
+ // | ... ... |
+ // |v4.b[15] ... v7.b[15] |
+ // \-----------------------------------------/
+ // int8 LHS 4x16 block
+ // /---------------------\ /-----------------------------------------\
+ // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s |
+ // |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s |
+ // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s |
+ // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s |
+ // \---------------------/ \-----------------------------------------/
+ // int32 accumulators 4x4 block
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ RUY_MAKE_ZERO(v16)
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ RUY_MAKE_ZERO(v17)
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ RUY_MAKE_ZERO(v18)
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ RUY_MAKE_ZERO(v19)
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ RUY_MAKE_ZERO(v20)
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ RUY_MAKE_ZERO(v21)
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ RUY_MAKE_ZERO(v22)
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+ RUY_MAKE_ZERO(v23)
+
+ // Load the first 64 bytes of LHS and RHS data.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v24)
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v25)
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v26)
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v27)
+ "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v28)
+ "ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v29)
+ "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v30)
+ "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v31)
+
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 16.
+ "mov w1, #16\n"
+
+ // Perform the first few multiply-adds on the data that we have already
+ // loaded.
+ "smull v8.8h, v0.8b, v4.8b\n"
+ "smull v9.8h, v1.8b, v4.8b\n"
+ "smull v10.8h, v2.8b, v4.8b\n"
+ "smull v11.8h, v3.8b, v4.8b\n"
+ "smull v12.8h, v0.8b, v5.8b\n"
+ "smull v13.8h, v1.8b, v5.8b\n"
+ "smull v14.8h, v2.8b, v5.8b\n"
+ "smull v15.8h, v3.8b, v5.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v4.16b\n"
+ "smlal2 v9.8h, v1.16b, v4.16b\n"
+ "smlal2 v10.8h, v2.16b, v4.16b\n"
+ "smlal2 v11.8h, v3.16b, v4.16b\n"
+ "smlal2 v12.8h, v0.16b, v5.16b\n"
+ "smlal2 v13.8h, v1.16b, v5.16b\n"
+ "smlal2 v14.8h, v2.16b, v5.16b\n"
+ "smlal2 v15.8h, v3.16b, v5.16b\n"
+
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ // Reminder - w1 is how many levels of depth we have already loaded
+ // data for, w12 is the total depth.
+ "cmp w1, w12\n"
+ "beq 79f\n"
+
+ "2:\n"
+
+ // Some multiplications and 16-bit accumulation were already done above,
+ // so we start right away in the middle.
+ "sadalp v16.4s, v8.8h\n"
+ "ldr d4, [%[rhs_ptr], #0]\n"
+ "smull v8.8h, v0.8b, v6.8b\n"
+ "ldr x7, [%[rhs_ptr], #8]\n"
+ "sadalp v17.4s, v9.8h\n"
+ "ldr d5, [%[rhs_ptr], #16]\n"
+ "smull v9.8h, v1.8b, v6.8b\n"
+ "ldr x8, [%[rhs_ptr], #24]\n"
+ "sadalp v18.4s, v10.8h\n"
+ "smull v10.8h, v2.8b, v6.8b\n"
+ "sadalp v19.4s, v11.8h\n"
+ "add %[lhs_ptr], %[lhs_ptr], #64\n"
+ "smull v11.8h, v3.8b, v6.8b\n"
+ "add %[rhs_ptr], %[rhs_ptr], #64\n"
+ "sadalp v20.4s, v12.8h\n"
+ // Each iteration of this loop advances by 16 levels of depth.
+ "add w1, w1, #16\n"
+ "smull v12.8h, v0.8b, v7.8b\n"
+ // Loop termination condition
+ "cmp w1, w12\n"
+ "sadalp v21.4s, v13.8h\n"
+ "ldr x3, [%[lhs_ptr], #-56]\n"
+ "smull v13.8h, v1.8b, v7.8b\n"
+ "ldr x4, [%[lhs_ptr], #-40]\n"
+ "sadalp v22.4s, v14.8h\n"
+ "ldr x5, [%[lhs_ptr], #-24]\n"
+ "smull v14.8h, v2.8b, v7.8b\n"
+ "ldr x6, [%[lhs_ptr], #-8]\n"
+ "sadalp v23.4s, v15.8h\n"
+ "smull v15.8h, v3.8b, v7.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v6.16b\n"
+ "smlal2 v9.8h, v1.16b, v6.16b\n"
+ "smlal2 v10.8h, v2.16b, v6.16b\n"
+ "ldr x9, [%[rhs_ptr], #-24]\n"
+ "smlal2 v11.8h, v3.16b, v6.16b\n"
+ "ldr d6, [%[rhs_ptr], #-32]\n"
+ "smlal2 v12.8h, v0.16b, v7.16b\n"
+ "ldr d0, [%[lhs_ptr], #-64]\n"
+ "smlal2 v13.8h, v1.16b, v7.16b\n"
+ "ldr d1, [%[lhs_ptr], #-48]\n"
+ "smlal2 v14.8h, v2.16b, v7.16b\n"
+ "ins v4.d[1], x7\n"
+ "smlal2 v15.8h, v3.16b, v7.16b\n"
+ "ins v5.d[1], x8\n"
+
+ "ldr d2, [%[lhs_ptr], #-32]\n"
+ "ins v0.d[1], x3\n"
+ "sadalp v24.4s, v8.8h\n"
+ "ldr d3, [%[lhs_ptr], #-16]\n"
+ "ins v1.d[1], x4\n"
+ "smull v8.8h, v0.8b, v4.8b\n"
+ "ins v2.d[1], x5\n"
+ "sadalp v25.4s, v9.8h\n"
+ "ins v3.d[1], x6\n"
+ "smull v9.8h, v1.8b, v4.8b\n"
+ "ldr d7, [%[rhs_ptr], #-16]\n"
+ "sadalp v26.4s, v10.8h\n"
+ "ldr x10, [%[rhs_ptr], #-8]\n"
+ "smull v10.8h, v2.8b, v4.8b\n"
+ "sadalp v27.4s, v11.8h\n"
+ "smull v11.8h, v3.8b, v4.8b\n"
+ "sadalp v28.4s, v12.8h\n"
+ "smull v12.8h, v0.8b, v5.8b\n"
+ "sadalp v29.4s, v13.8h\n"
+ "smull v13.8h, v1.8b, v5.8b\n"
+ "sadalp v30.4s, v14.8h\n"
+ "smull v14.8h, v2.8b, v5.8b\n"
+ "sadalp v31.4s, v15.8h\n"
+ "smull v15.8h, v3.8b, v5.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v4.16b\n"
+ "smlal2 v9.8h, v1.16b, v4.16b\n"
+ "smlal2 v10.8h, v2.16b, v4.16b\n"
+ "smlal2 v11.8h, v3.16b, v4.16b\n"
+
+ "smlal2 v12.8h, v0.16b, v5.16b\n"
+ "smlal2 v13.8h, v1.16b, v5.16b\n"
+ "ins v6.d[1], x9\n"
+ "smlal2 v14.8h, v2.16b, v5.16b\n"
+ "ins v7.d[1], x10\n"
+ "smlal2 v15.8h, v3.16b, v5.16b\n"
+
+ "blt 2b\n"
+
+ "79:\n"
+
+ "sadalp v16.4s, v8.8h\n"
+ "smull v8.8h, v0.8b, v6.8b\n"
+ "sadalp v17.4s, v9.8h\n"
+ "smull v9.8h, v1.8b, v6.8b\n"
+ "sadalp v18.4s, v10.8h\n"
+ "smull v10.8h, v2.8b, v6.8b\n"
+ "sadalp v19.4s, v11.8h\n"
+ "smull v11.8h, v3.8b, v6.8b\n"
+ "sadalp v20.4s, v12.8h\n"
+ "smull v12.8h, v0.8b, v7.8b\n"
+ "sadalp v21.4s, v13.8h\n"
+ "smull v13.8h, v1.8b, v7.8b\n"
+ "sadalp v22.4s, v14.8h\n"
+ "smull v14.8h, v2.8b, v7.8b\n"
+ "sadalp v23.4s, v15.8h\n"
+ "smull v15.8h, v3.8b, v7.8b\n"
+
+ // Multiply-accumulate second-half, again into the same
+ // 16bit local accumulator registers. This is where we
+ // take advantage of having int8 instead of uint8 and therefore
+ // being able to accumulate two products into int16.
+ "smlal2 v8.8h, v0.16b, v6.16b\n"
+ "smlal2 v9.8h, v1.16b, v6.16b\n"
+ "smlal2 v10.8h, v2.16b, v6.16b\n"
+ "smlal2 v11.8h, v3.16b, v6.16b\n"
+
+ "smlal2 v12.8h, v0.16b, v7.16b\n"
+ "smlal2 v13.8h, v1.16b, v7.16b\n"
+ "smlal2 v14.8h, v2.16b, v7.16b\n"
+ "smlal2 v15.8h, v3.16b, v7.16b\n"
+
+ "sadalp v24.4s, v8.8h\n"
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "sadalp v25.4s, v9.8h\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "sadalp v26.4s, v10.8h\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "sadalp v27.4s, v11.8h\n"
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "sadalp v28.4s, v12.8h\n"
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "sadalp v29.4s, v13.8h\n"
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "sadalp v30.4s, v14.8h\n"
+ "sadalp v31.4s, v15.8h\n"
+
+ // End of accumulation. The registers v16 -- v31 contain the final
+ // int32 accumulator values of the current 4x4 destination block.
+ // We now have to compute the final 8-bit values from these int32
+ // accumulators, and advance to the next 4x4 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // Reduce 32bit accumulators horizontally.
+ "addp v16.4s, v16.4s, v17.4s\n"
+ "addp v18.4s, v18.4s, v19.4s\n"
+ "addp v20.4s, v20.4s, v21.4s\n"
+ "addp v22.4s, v22.4s, v23.4s\n"
+ "addp v24.4s, v24.4s, v25.4s\n"
+ "addp v26.4s, v26.4s, v27.4s\n"
+ "addp v28.4s, v28.4s, v29.4s\n"
+ "addp v30.4s, v30.4s, v31.4s\n"
+
+ // Reduce 32bit accumulators horizontally, second pass
+ // (each pass adds pairwise. we need to add 4-wise).
+ "addp v16.4s, v16.4s, v18.4s\n"
+ "addp v17.4s, v20.4s, v22.4s\n"
+ "addp v18.4s, v24.4s, v26.4s\n"
+ "addp v19.4s, v28.4s, v30.4s\n"
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "cmp %w[row], w7\n" // Have we finished the last row?
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "mov %[lhs_col_ptr], x5\n" // Go back to first row
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "cmp %w[col], w8\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+ "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+
+ // Load some parameters needed for the end work on current block.
+ RUY_MAKE_ZERO(v8)
+ "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
+ "ins v13.h[4], w4\n" // dst_zero_point
+ "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
+ "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "dup v9.4s, w3\n" // create prod_zp_depth_vec
+ "add x5, x4, %x[row], lsl #2\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "csel x4, x4, x5, eq\n"
+
+ "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
+
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+ "add x5, x1, %x[row], lsl #2\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "csel x1, x1, x5, eq\n"
+
+ // Load 4 bias values.
+ "ld1 {v14.4s}, [x1]\n"
+
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 32 bytes of
+ // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
+ // in the rest of the work on the current block.
+
+ // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
+ // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "add v14.4s, v14.4s, v9.4s\n"
+ "ldr d0, [%[lhs_ptr], #0]\n"
+
+ // Perform the bias-addition (per the above, we have just folded into
+ // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+ "add v16.4s, v16.4s, v14.4s\n"
+ "ldr d1, [%[lhs_ptr], #16]\n"
+ "add v17.4s, v17.4s, v14.4s\n"
+ "ldr d2, [%[lhs_ptr], #32]\n"
+ "add v18.4s, v18.4s, v14.4s\n"
+ "ldr d3, [%[lhs_ptr], #48]\n"
+ "add v19.4s, v19.4s, v14.4s\n"
+ "ldr d4, [%[rhs_ptr], #0]\n"
+ "ldr d5, [%[rhs_ptr], #16]\n"
+ "ldr d6, [%[rhs_ptr], #32]\n"
+ "ldr d7, [%[rhs_ptr], #48]\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
+ "beq 401f\n"
+ "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
+ "add x3, x3, %x[col], lsl #2\n"
+ "ld1 {v14.4s}, [x3]\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
+ "dup v10.4s, w5\n" // create lhs_zero_point_vec
+ // Subtract rhs_sums * lhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "mls v16.4s, v10.4s, v14.s[0]\n"
+ "mls v17.4s, v10.4s, v14.s[1]\n"
+ "mls v18.4s, v10.4s, v14.s[2]\n"
+ "mls v19.4s, v10.4s, v14.s[3]\n"
+ "401:\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
+ "beq 402f\n"
+ "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
+ "add x2, x2, %x[row], lsl #2\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
+ // Load 4 lhs_sums values.
+ "ld1 {v11.4s}, [x2]\n"
+ "ins v13.s[1], w5\n" // rhs_zero_point
+ // Compute lhs_sums * rhs_zero_point.
+ "mul v11.4s, v11.4s, v13.s[1]\n"
+ // Subtract lhs_sums * rhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "sub v16.4s, v16.4s, v11.4s\n"
+ "sub v17.4s, v17.4s, v11.4s\n"
+ "sub v18.4s, v18.4s, v11.4s\n"
+ "sub v19.4s, v19.4s, v11.4s\n"
+
+ // If the destination is int32, it means the user asks for the raw
+ // accumulators, no need for us to downquantize the value.
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
+
+ "402:\n"
+
+ // At this point we have computed the final int32 values. Now we
+ // start down-quantizing them to obtain the final 8bit values from them.
+
+ // As part of this down-quantization, our int32 values will be
+ // multiplied by a multiplier that has a fixed-point component and an
+ // exponent component.
+
+
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "add x5, x1, %x[row], lsl #2\n"
+ "csel x1, x1, x5, eq\n"
+
+ "ld1 {v14.4s}, [x1]\n"
+
+ "smax v12.4s, v14.4s, v8.4s\n"
+ "ldr x1, [%[lhs_ptr], #8]\n"
+
+ "sshl v16.4s, v16.4s, v12.4s\n"
+ "ldr x2, [%[lhs_ptr], #24]\n"
+ "sshl v17.4s, v17.4s, v12.4s\n"
+ "ldr x3, [%[lhs_ptr], #40]\n"
+ "sshl v18.4s, v18.4s, v12.4s\n"
+ "ldr x4, [%[lhs_ptr], #56]\n"
+ "sshl v19.4s, v19.4s, v12.4s\n"
+
+ "smin v12.4s, v14.4s, v8.4s\n"
+
+ // Apply the fixed-point part of the multiplier.
+ "ins v0.d[1], x1\n"
+ "ldr x1, [%[rhs_ptr], #8]\n"
+ "sqrdmulh v16.4s, v16.4s, v15.4s\n"
+ "ins v1.d[1], x2\n"
+ "ldr x2, [%[rhs_ptr], #24]\n"
+ "sqrdmulh v17.4s, v17.4s, v15.4s\n"
+ "ins v2.d[1], x3\n"
+ "ldr x3, [%[rhs_ptr], #40]\n"
+ "sqrdmulh v18.4s, v18.4s, v15.4s\n"
+ "ins v3.d[1], x4\n"
+ "ldr x4, [%[rhs_ptr], #56]\n"
+ "sqrdmulh v19.4s, v19.4s, v15.4s\n"
+
+ // We have some rounding division-by-power-of-two to do. This should
+ // always use "round to nearest". We allow for some
+ // freedom in how ties are broken, to strike a good compromise of
+ // performance on given hardware vs. perfect agreement of results
+ // across hardware.
+ //
+ // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation
+ // defined tie-breaks to help performance. On NEON, this means that we
+ // can just use the NEON rounding instructions, such as srshl. They
+ // happen to be breaking ties upward.
+ //
+ // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
+ // break-ties-away-from zero, as described in Appendix B of
+ // https://arxiv.org/pdf/1712.05877.pdf
+ // When we wrote that, we thought that it would be less biased
+ // than the NEON upwards tie-breaks, and we had observed some
+ // improvement on some model. However, it is only less biased for
+ // data centered at zero, which was likely the case in that model,
+ // but is not always the case. If we wanted something more consistently
+ // unbiased, we should try breaking ties to nearest even.
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
+ // Fix up values to be right-shifted, so that the (round to nearest,
+ // break ties upward) behavior of srshl applied to these fixed-up
+ // values, produces the same result as the desired (round to nearest,
+ // break ties away from zero) behavior on the original values.
+ "and v8.16b, v16.16b, v12.16b\n"
+ "and v9.16b, v17.16b, v12.16b\n"
+ "and v14.16b, v18.16b, v12.16b\n"
+ "and v15.16b, v19.16b, v12.16b\n"
+ "sshr v8.4s, v8.4s, #31\n"
+ "sshr v9.4s, v9.4s, #31\n"
+ "sshr v14.4s, v14.4s, #31\n"
+ "sshr v15.4s, v15.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v8.4s\n"
+ "sqadd v17.4s, v17.4s, v9.4s\n"
+ "sqadd v18.4s, v18.4s, v14.4s\n"
+ "sqadd v19.4s, v19.4s, v15.4s\n"
+#endif
+ // At this point we have reduced the problem of correctly implementing
+ // rounding divide-by-power-of-two, to what the SRSHL instruction can
+ // do.
+ "srshl v16.4s, v16.4s, v12.4s\n"
+ "srshl v17.4s, v17.4s, v12.4s\n"
+ "srshl v18.4s, v18.4s, v12.4s\n"
+ "srshl v19.4s, v19.4s, v12.4s\n"
+
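+ // Dispatch on the destination type: the int16 and int8 cases branch to
+ // their dedicated store paths below, while the uint8 case falls through
+ // to the store path that immediately follows.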
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
+
+ "ins v4.d[1], x1\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "ins v5.d[1], x2\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "ins v6.d[1], x3\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "ins v7.d[1], x4\n"
+ RUY_MAKE_ZERO(v18)
+ "sqxtn2 v17.8h, v19.4s\n"
+
+ // At this point, v18 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v19)
+
+ // Add the destination zero point
+ "add %[lhs_ptr], %[lhs_ptr], #64\n"
+ "dup v14.8h, v13.h[4]\n"
+ RUY_MAKE_ZERO(v20)
+ "add %[rhs_ptr], %[rhs_ptr], #64\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+ RUY_MAKE_ZERO(v21)
+ "add v17.8h, v17.8h, v14.8h\n"
+ RUY_MAKE_ZERO(v22)
+
+ // Cast-and-saturate from int16 to uint8
+ "sqxtun v16.8b, v16.8h\n"
+ RUY_MAKE_ZERO(v23)
+ "sqxtun2 v16.16b, v17.8h\n"
+ RUY_MAKE_ZERO(v24)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ RUY_MAKE_ZERO(v25)
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ RUY_MAKE_ZERO(v26)
+ "dup v14.16b, w2\n" // clamp_min
+ RUY_MAKE_ZERO(v27)
+ "dup v15.16b, w3\n" // clamp_max
+ RUY_MAKE_ZERO(v28)
+
+ // Apply the clamp_min bound
+ "umax v16.16b, v16.16b, v14.16b\n"
+ RUY_MAKE_ZERO(v29)
+ // Apply the clamp_max bound
+ "umin v16.16b, v16.16b, v15.16b\n"
+ RUY_MAKE_ZERO(v30)
+
+ // Compute how much of the 4x4 block of destination 8bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x4, some 4x4 blocks along the boundaries do not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ RUY_MAKE_ZERO(v31)
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x4 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+ // Compute w2 = how many cols of the 4x4 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1 == 4 && w2 == 4, i.e. if all of the 4x4 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x4 block fits.
+ // Store to dst_tmp_buf
+ "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #4\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x4 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[0], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[1], [x3], #1\n"
+ "st1 {v16.b}[2], [x3], #1\n"
+ "st1 {v16.b}[3], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[4], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[5], [x3], #1\n"
+ "st1 {v16.b}[6], [x3], #1\n"
+ "st1 {v16.b}[7], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[8], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[9], [x3], #1\n"
+ "st1 {v16.b}[10], [x3], #1\n"
+ "st1 {v16.b}[11], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[12], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[13], [x3], #1\n"
+ "st1 {v16.b}[14], [x3], #1\n"
+ "st1 {v16.b}[15], [x3], #1\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #4\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
+
+ "ins v4.d[1], x1\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "ins v5.d[1], x2\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "ins v6.d[1], x3\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "ins v7.d[1], x4\n"
+ RUY_MAKE_ZERO(v18)
+ "sqxtn2 v17.8h, v19.4s\n"
+
+ // At this point, v18 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v19)
+
+ // Add the destination zero point
+ "add %[lhs_ptr], %[lhs_ptr], #64\n"
+ "dup v14.8h, v13.h[4]\n"
+ RUY_MAKE_ZERO(v20)
+ "add %[rhs_ptr], %[rhs_ptr], #64\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+ RUY_MAKE_ZERO(v21)
+ "add v17.8h, v17.8h, v14.8h\n"
+ RUY_MAKE_ZERO(v22)
+
+ // Cast-and-saturate from int16 to int8
+ "sqxtn v16.8b, v16.8h\n"
+ RUY_MAKE_ZERO(v23)
+ "sqxtn2 v16.16b, v17.8h\n"
+ RUY_MAKE_ZERO(v24)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ RUY_MAKE_ZERO(v25)
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ RUY_MAKE_ZERO(v26)
+ "dup v14.16b, w2\n" // clamp_min
+ RUY_MAKE_ZERO(v27)
+ "dup v15.16b, w3\n" // clamp_max
+ RUY_MAKE_ZERO(v28)
+
+ // Apply the clamp_min bound
+ "smax v16.16b, v16.16b, v14.16b\n"
+ RUY_MAKE_ZERO(v29)
+ // Apply the clamp_max bound
+ "smin v16.16b, v16.16b, v15.16b\n"
+ RUY_MAKE_ZERO(v30)
+
+ // Compute how much of the 4x4 block of destination 8bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x4, some 4x4 blocks along the boundaries do not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ RUY_MAKE_ZERO(v31)
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x4 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+ // Compute w2 = how many cols of the 4x4 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1 == 4 && w2 == 4, i.e. if all of the 4x4 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x4 block fits.
+ // Store to dst_tmp_buf
+ "st1 {v16.16b}, [%[dst_tmp_buf]]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #4\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x4 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[0], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[1], [x3], #1\n"
+ "st1 {v16.b}[2], [x3], #1\n"
+ "st1 {v16.b}[3], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[4], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[5], [x3], #1\n"
+ "st1 {v16.b}[6], [x3], #1\n"
+ "st1 {v16.b}[7], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[8], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[9], [x3], #1\n"
+ "st1 {v16.b}[10], [x3], #1\n"
+ "st1 {v16.b}[11], [x3], #1\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.b}[12], [x3], #1\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.b}[13], [x3], #1\n"
+ "st1 {v16.b}[14], [x3], #1\n"
+ "st1 {v16.b}[15], [x3], #1\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #4\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
+
+ // Add the destination zero point
+ "dup v14.4h, v13.h[4]\n"
+ "saddw v16.4s, v16.4s, v14.4h\n"
+ "saddw v17.4s, v17.4s, v14.4h\n"
+ "saddw v18.4s, v18.4s, v14.4h\n"
+ "saddw v19.4s, v19.4s, v14.4h\n"
+
+ // Cast-and-saturate from int32 to int16
+ "ins v4.d[1], x1\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "ins v5.d[1], x2\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "ins v6.d[1], x3\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "ins v7.d[1], x4\n"
+ RUY_MAKE_ZERO(v18)
+ "sqxtn2 v17.8h, v19.4s\n"
+
+ // At this point, v18 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v19)
+
+ "add %[lhs_ptr], %[lhs_ptr], #64\n"
+ RUY_MAKE_ZERO(v20)
+ "add %[rhs_ptr], %[rhs_ptr], #64\n"
+ RUY_MAKE_ZERO(v21)
+ RUY_MAKE_ZERO(v22)
+
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ RUY_MAKE_ZERO(v25)
+ "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ RUY_MAKE_ZERO(v26)
+ "dup v14.8h, w2\n" // clamp_min
+ RUY_MAKE_ZERO(v27)
+ "dup v15.8h, w3\n" // clamp_max
+ RUY_MAKE_ZERO(v28)
+
+ // Apply the clamp_min bound
+ "smax v16.8h, v16.8h, v14.8h\n"
+ "smax v17.8h, v17.8h, v14.8h\n"
+ RUY_MAKE_ZERO(v29)
+ // Apply the clamp_max bound
+ "smin v16.8h, v16.8h, v15.8h\n"
+ "smin v17.8h, v17.8h, v15.8h\n"
+ RUY_MAKE_ZERO(v30)
+
+ // Compute how much of the 4x4 block of destination 16bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x4, some 4x4 blocks along the boundaries do not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ RUY_MAKE_ZERO(v31)
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x4 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+ // Compute w2 = how many cols of the 4x4 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x4 block fits.
+ // Store to dst_tmp_buf
+ "str q16, [%[dst_tmp_buf], #0]\n"
+ "str q17, [%[dst_tmp_buf], #16]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrh w7, [x3, x5, lsl #1]\n"
+ "strh w7, [x4, x5, lsl #1]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #8\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x4 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.h}[0], [x3], #2\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.h}[1], [x3], #2\n"
+ "st1 {v16.h}[2], [x3], #2\n"
+ "st1 {v16.h}[3], [x3], #2\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.h}[4], [x3], #2\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.h}[5], [x3], #2\n"
+ "st1 {v16.h}[6], [x3], #2\n"
+ "st1 {v16.h}[7], [x3], #2\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v17.h}[0], [x3], #2\n"
+ "add x4, x4, x11\n"
+ "st1 {v17.h}[1], [x3], #2\n"
+ "st1 {v17.h}[2], [x3], #2\n"
+ "st1 {v17.h}[3], [x3], #2\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v17.h}[4], [x3], #2\n"
+ "add x4, x4, x11\n"
+ "st1 {v17.h}[5], [x3], #2\n"
+ "st1 {v17.h}[6], [x3], #2\n"
+ "st1 {v17.h}[7], [x3], #2\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+ "ldr x1, [%[lhs_ptr], #8]\n"
+ "ldr x2, [%[lhs_ptr], #24]\n"
+ "ldr x3, [%[lhs_ptr], #40]\n"
+ "ldr x4, [%[lhs_ptr], #56]\n"
+
+ "ins v0.d[1], x1\n"
+ "ldr x1, [%[rhs_ptr], #8]\n"
+ "ins v1.d[1], x2\n"
+ "ldr x2, [%[rhs_ptr], #24]\n"
+ "ins v2.d[1], x3\n"
+ "ldr x3, [%[rhs_ptr], #40]\n"
+ "ins v3.d[1], x4\n"
+ "ldr x4, [%[rhs_ptr], #56]\n"
+ "ins v4.d[1], x1\n"
+ "ins v5.d[1], x2\n"
+ "ins v6.d[1], x3\n"
+ "ins v7.d[1], x4\n"
+
+ // Since the store type is the same as the accum type, no need for
+ // downcast. There's also no need for clamp by min/max.
+
+ // At this point, v20 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+
+ RUY_MAKE_ZERO(v20)
+ "add %[lhs_ptr], %[lhs_ptr], #64\n"
+ RUY_MAKE_ZERO(v21)
+ "add %[rhs_ptr], %[rhs_ptr], #64\n"
+ RUY_MAKE_ZERO(v22)
+
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+
+ // Compute how much of the 4x4 block of destination 32bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 4x4, some 4x4 blocks along the boundaries do not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ RUY_MAKE_ZERO(v31)
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #4\n"
+ "cmp w1, #4\n"
+ // Compute w1 = how many rows of the 4x4 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #4\n"
+ // Compute w2 = how many cols of the 4x4 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1 == 4 && w2 == 4, i.e. if all of the 4x4 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ "mov x4, %[dst_ptr]\n"
+ // Yes, all of the 4x4 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 4x4 block fits.
+ // Store to dst_tmp_buf
+ "str q16, [%[dst_tmp_buf], #0]\n"
+ "str q17, [%[dst_tmp_buf], #16]\n"
+ "str q18, [%[dst_tmp_buf], #32]\n"
+ "str q19, [%[dst_tmp_buf], #48]\n"
+ // Slow loop copying from dst_tmp_buf to dst.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldr w7, [x3, x5, lsl #2]\n"
+ "str w7, [x4, x5, lsl #2]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #16\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 4x4 block fits.
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.s}[0], [x3], #4\n"
+ "add x4, x4, x11\n"
+ "st1 {v16.s}[1], [x3], #4\n"
+ "st1 {v16.s}[2], [x3], #4\n"
+ "st1 {v16.s}[3], [x3], #4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v17.s}[0], [x3], #4\n"
+ "add x4, x4, x11\n"
+ "st1 {v17.s}[1], [x3], #4\n"
+ "st1 {v17.s}[2], [x3], #4\n"
+ "st1 {v17.s}[3], [x3], #4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v18.s}[0], [x3], #4\n"
+ "add x4, x4, x11\n"
+ "st1 {v18.s}[1], [x3], #4\n"
+ "st1 {v18.s}[2], [x3], #4\n"
+ "st1 {v18.s}[3], [x3], #4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v19.s}[0], [x3], #4\n"
+ "add x4, x4, x11\n"
+ "st1 {v19.s}[1], [x3], #4\n"
+ "st1 {v19.s}[2], [x3], #4\n"
+ "st1 {v19.s}[3], [x3], #4\n"
+ "31:\n"
+
+ "add %[dst_ptr], %[dst_ptr], #16\n"
+
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+
+ RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ "smull v8.8h, v0.8b, v4.8b\n"
+ "smull v9.8h, v1.8b, v4.8b\n"
+ "smull v10.8h, v2.8b, v4.8b\n"
+ // Reload some params --- we had used x5 -- x7 for a few other things
+ // since the last time we had loaded them.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "smull v11.8h, v3.8b, v4.8b\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "smull v12.8h, v0.8b, v5.8b\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "smull v13.8h, v1.8b, v5.8b\n"
+ "smull v14.8h, v2.8b, v5.8b\n"
+ "smull v15.8h, v3.8b, v5.8b\n"
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "cmp %w[row], w7\n"
+ "smlal2 v8.8h, v0.16b, v4.16b\n"
+ "smlal2 v9.8h, v1.16b, v4.16b\n"
+ "smlal2 v10.8h, v2.16b, v4.16b\n"
+ "smlal2 v11.8h, v3.16b, v4.16b\n"
+ "smlal2 v12.8h, v0.16b, v5.16b\n"
+ "smlal2 v13.8h, v1.16b, v5.16b\n"
+ "smlal2 v14.8h, v2.16b, v5.16b\n"
+ "smlal2 v15.8h, v3.16b, v5.16b\n"
+
+
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add %w[row], %w[row], #4\n"
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ "mov %w[row], w6\n" // Move back to first row.
+ "add %w[col], %w[col], #4\n" // Move to the next column.
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
+ "mov %[dst_ptr], %[dst_col_ptr]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "cmp %w[col], w8\n"
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 16.
+ "mov w1, #16\n"
+
+ "ble 1b\n"
+
+ // clang-format on
+
+ : [lhs_col_ptr] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [params] "r"(&params), [dst_rows] "r"(params.dst_rows),
+ [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
+ [dst_type_id] "r"(params.dst_type_id)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+ "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
+ "v26", "v27", "v28", "v29", "v30", "v31");
+}
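+
+// For reference, here is a scalar C++ sketch of the down-quantization that
+// the kernel above implements with sqrdmulh + srshl, including the fix-up
+// that turns srshl's break-ties-upward rounding into
+// break-ties-away-from-zero. This is only an illustrative model, not code
+// used by the kernels; the helper names below are made up for this sketch,
+// and it assumes <cstdint> has been included.
+inline std::int32_t SketchSaturatingRoundingDoublingHighMul(std::int32_t a,
+                                                            std::int32_t b) {
+  // Scalar model of sqrdmulh: rounding (2 * a * b) >> 31, saturating the
+  // single overflow case INT32_MIN * INT32_MIN.
+  if (a == b && a == INT32_MIN) {
+    return INT32_MAX;
+  }
+  const std::int64_t ab = static_cast<std::int64_t>(a) * b;
+  const std::int64_t nudge = ab >= 0 ? (1ll << 30) : (1 - (1ll << 30));
+  return static_cast<std::int32_t>((ab + nudge) >> 31);
+}
+
+inline std::int32_t SketchRoundingDivideByPOT(std::int32_t x, int exponent) {
+  // Rounding right shift by `exponent` (assumed to be in [0, 31]), breaking
+  // ties away from zero -- the non-RUY_OPT_NATIVE_ROUNDING behavior above.
+  const std::int32_t mask = static_cast<std::int32_t>((1ll << exponent) - 1);
+  const std::int32_t remainder = x & mask;
+  const std::int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
+  return (x >> exponent) + (remainder > threshold ? 1 : 0);
+}
+
+// Down-quantizing one int32 accumulator then amounts to
+//   SketchRoundingDivideByPOT(
+//       SketchSaturatingRoundingDoublingHighMul(acc << left_shift,
+//                                               multiplier_fixedpoint),
+//       right_shift)
+// where left_shift = max(exponent, 0) and right_shift = -min(exponent, 0),
+// mirroring the smax/sshl, sqrdmulh, and smin/srshl sequence in the asm.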
+
+// Kernel taking advantage of the optional dotprod instruction.
+// This is very similar to (and directly inspired by) this gemmlowp kernel
+// which was contributed by David Mansell at ARM:
+// NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct
+// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3391
+//
+// Besides the ruy-ification, the main difference here is that we use an 8x8
+// instead of a 12x8 block width, so as to stick to power-of-two widths. This
+// slightly narrower kernel layout is still wide enough to achieve high
+// performance, although we haven't actually performed a direct comparison
+// against ARM's aforementioned kernel.
+//
+// Relevant target CPUs for this kernel include the ARM Cortex-A76, which is
+// 64-bit, out-of-order, and has dotprod support.
+void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params) {
+ profiler::ScopeLabel label(
+ "Kernel (kNeonDotprod, optimized for out-of-order cores)");
+
+ CheckOffsetsInKernelParams8bit(params);
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v31 are int32 accumulators.
+ // During accumulation, v0 -- v15 are used to load int8 data from LHS and
+ // RHS. At least v0 and v1 are used to load an 8x4 block of LHS, and v2 and
+ // v3 are used to load a 4x8 block of RHS, like this:
+ //
+ // int8 RHS 4x8 block
+ // /-----------------------------------------\
+ // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]|
+ // | ... ... |
+ // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]|
+ // \-----------------------------------------/
+ // int8 LHS 8x4 block
+ // /---------------------\ /-----------------------------------------\
+ // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]|
+ // | ... ... | | ... ... |
+ // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]|
+ // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]|
+ // | ... ... | | ... ... |
+ // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]|
+ // \---------------------/ \-----------------------------------------/
+ // int32 accumulators 8x8 block
+ //
+ // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step
+ // is repeated 4 times, using 4x more registers for LHS and RHS; there,
+ // instead of using only v0 -- v3 for LHS and RHS, we use v0 -- v15.
+ //
+ // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are
+ // unused, and v8 -- v15 are used for loading parameters used for the
+ // post-accumulation part of the kernel.
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+ // Load the first 32 bytes of LHS and RHS data.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 4.
+ "mov w1, #4\n"
+
+ // Perform the first few multiply-adds on the data that we have already
+ // loaded.
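+ // Note: the sdot instructions throughout this kernel are emitted as raw
+ // .word encodings, presumably so that this file still assembles with
+ // toolchains whose assemblers do not know the dotprod extension; the
+ // intended mnemonic is given in each trailing comment.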
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ // Optional, maximally-streaming, partial-unrolling (4x unrolled)
+ // optimization of the kernel inner loop (over depth). For more
+ // comments, see the non-unrolled loop below after the #endif.
+#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING)
+ "cmp w12, #32\n"
+ "blt 78f\n"
+
+ "ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v8.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v9.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v10.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v11.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v12.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v13.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v14.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v15.16b}, [%[rhs_ptr]], #16\n"
+ "mov w1, #16\n"
+
+ "and w3, w12, #-16\n"
+ "81:\n"
+ "add w1, w1, #16\n"
+
+ ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
+ ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
+ ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
+ ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
+ "ldr q0, [%[lhs_ptr], #0]\n"
+ ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
+ ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
+ ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
+ ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
+ "ldr q2, [%[rhs_ptr], #0]\n"
+ ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
+ ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
+ ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
+ ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
+ "ldr q1, [%[lhs_ptr], #16]\n"
+
+ ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n"
+ ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n"
+ "ldr q3, [%[rhs_ptr], #16]\n"
+ ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n"
+ ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n"
+ ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n"
+ ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n"
+ ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n"
+ ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n"
+ ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n"
+ ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n"
+ ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n"
+ ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n"
+ "ldr q5, [%[lhs_ptr], #48]\n"
+ ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n"
+ ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n"
+ "ldr q7, [%[rhs_ptr], #48]\n"
+ ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n"
+ ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n"
+ "ldr q4, [%[lhs_ptr], #32]\n"
+
+ ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n"
+ ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n"
+ "ldr q6, [%[rhs_ptr], #32]\n"
+ ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n"
+ ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n"
+ ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n"
+ ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n"
+ ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n"
+ ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n"
+ ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n"
+ ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n"
+ ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n"
+ ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n"
+ "ldr q9, [%[lhs_ptr], #80]\n"
+ ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n"
+ ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n"
+ "ldr q11, [%[rhs_ptr], #80]\n"
+ ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n"
+ ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n"
+ "ldr q8, [%[lhs_ptr], #64]\n"
+
+ ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n"
+ ".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n"
+ "ldr q10, [%[rhs_ptr], #64]\n"
+ ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n"
+ ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n"
+ "add %[lhs_ptr], %[lhs_ptr], #128\n"
+ ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n"
+ ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n"
+ "add %[rhs_ptr], %[rhs_ptr], #128\n"
+ ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n"
+ ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n"
+ ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n"
+ ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n"
+ "cmp w1, w3\n"
+ ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n"
+ ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n"
+ "ldr q13, [%[lhs_ptr], #-16]\n"
+ ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n"
+ ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n"
+ "ldr q15, [%[rhs_ptr], #-16]\n"
+ ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n"
+ ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n"
+ "ldr q12, [%[lhs_ptr], #-32]\n"
+
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ "ldr q14, [%[rhs_ptr], #-32]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+ "blt 81b\n"
+
+ ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n"
+ ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n"
+ ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n"
+ ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n"
+ ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n"
+ ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n"
+ ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n"
+ ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n"
+ ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n"
+ ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n"
+ ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n"
+ ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n"
+ ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n"
+ ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n"
+ ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n"
+ ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n"
+
+ ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n"
+ ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n"
+ ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n"
+ ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n"
+ ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n"
+ ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n"
+ ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n"
+ ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n"
+ ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n"
+ ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n"
+ ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n"
+ ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n"
+ ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n"
+ ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n"
+ ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n"
+ ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n"
+
+ ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n"
+ ".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n"
+ ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n"
+ ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n"
+ ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n"
+ ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n"
+ ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n"
+ ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n"
+ ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n"
+ ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n"
+ ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n"
+ ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n"
+ ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n"
+ ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n"
+ ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n"
+ ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n"
+
+ "78:\n"
+
+#endif // #if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING)
+
+ // Ordinary kernel inner loop (over depth), the simpler loop that the
+ // above was an equivalent 4x-partially-unrolled version of.
+
+ // Reminder: w1 is how many levels of depth we have already loaded
+ // data for; w12 is the total depth.
+ "cmp w1, w12\n"
+ "beq 79f\n"
+
+ "2:\n"
+
+ // Because of the data that we have already loaded, we can start the
+ // loop body right away with some multiply-adds.
+ ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
+ ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
+ // Each iteration of this loop advances by 4 levels of depth.
+ "add w1, w1, #4\n"
+ ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
+ ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
+ ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
+ // Loop termination condition.
+ "cmp w1, w12\n"
+ ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
+ ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
+ "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
+ ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
+ ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
+ ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
+ ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
+ "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+
+ "blt 2b\n"
+
+ "79:\n"
+ // End of the inner loop on depth. Now perform the remaining
+ // multiply-adds of the last 4 levels of depth, for which the LHS
+ // and RHS data is already loaded.
+
+ ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
+ ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
+ ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
+ ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
+ ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
+ ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
+ ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
+ ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
+ ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
+ ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
+ ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
+ ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
+
+ // End of accumulation. The registers v16 -- v31 contain the final
+ // int32 accumulator values of the current 8x8 destination block.
+ // We now have to compute the final 8-bit values from these int32
+ // accumulators, and advance to the next 8x8 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "cmp %w[row], w7\n" // Have we finished the last row?
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "mov %[lhs_col_ptr], x5\n" // Go back to first row
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "cmp %w[col], w8\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+ "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+
+ // Load some parameters needed for the end work on current block.
+ RUY_MAKE_ZERO(v8)
+ "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
+ "ins v13.h[4], w4\n" // dst_zero_point
+ "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
+ "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "dup v9.4s, w3\n" // create prod_zp_depth_vec
+ "add x5, x4, %x[row], lsl #2\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "csel x4, x4, x5, eq\n"
+
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+ "add x5, x1, %x[row], lsl #2\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "csel x1, x1, x5, eq\n"
+
+ // Load 8 bias values.
+ "ld1 {v14.4s}, [x1], #16\n"
+ "ld1 {v15.4s}, [x1]\n"
+
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 32 bytes of
+ // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
+ // in the rest of the work on the current block.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
+
+ // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point).
+ // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
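+ // Taken together with the rhs_sums and lhs_sums corrections below, this
+ // implements, for each accumulator,
+ //   acc[r][c] += depth * lhs_zero_point * rhs_zero_point
+ //                - lhs_zero_point * rhs_sums[c]
+ //                - rhs_zero_point * lhs_sums[r]
+ // which is the expansion of equation (7) in that paper (the ordinary bias
+ // term is folded in at the same time).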
+ "add v14.4s, v14.4s, v9.4s\n"
+ "add v15.4s, v15.4s, v9.4s\n"
+
+ // Perform the bias-addition (per the above, we have just folded into
+ // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+ "add v16.4s, v16.4s, v14.4s\n"
+ "add v17.4s, v17.4s, v15.4s\n"
+ "add v18.4s, v18.4s, v14.4s\n"
+ "add v19.4s, v19.4s, v15.4s\n"
+ "add v20.4s, v20.4s, v14.4s\n"
+ "add v21.4s, v21.4s, v15.4s\n"
+ "add v22.4s, v22.4s, v14.4s\n"
+ "add v23.4s, v23.4s, v15.4s\n"
+ "add v24.4s, v24.4s, v14.4s\n"
+ "add v25.4s, v25.4s, v15.4s\n"
+ "add v26.4s, v26.4s, v14.4s\n"
+ "add v27.4s, v27.4s, v15.4s\n"
+ "add v28.4s, v28.4s, v14.4s\n"
+ "add v29.4s, v29.4s, v15.4s\n"
+ "add v30.4s, v30.4s, v14.4s\n"
+ "add v31.4s, v31.4s, v15.4s\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
+ "beq 401f\n"
+ "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
+ "add x3, x3, %x[col], lsl #2\n"
+ "ld1 {v14.4s}, [x3], #16\n"
+ "ld1 {v15.4s}, [x3]\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
+ "dup v10.4s, w5\n" // create lhs_zero_point_vec
+ // Subtract rhs_sums * lhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "mls v16.4s, v10.4s, v14.s[0]\n"
+ "mls v17.4s, v10.4s, v14.s[0]\n"
+ "mls v18.4s, v10.4s, v14.s[1]\n"
+ "mls v19.4s, v10.4s, v14.s[1]\n"
+ "mls v20.4s, v10.4s, v14.s[2]\n"
+ "mls v21.4s, v10.4s, v14.s[2]\n"
+ "mls v22.4s, v10.4s, v14.s[3]\n"
+ "mls v23.4s, v10.4s, v14.s[3]\n"
+ "mls v24.4s, v10.4s, v15.s[0]\n"
+ "mls v25.4s, v10.4s, v15.s[0]\n"
+ "mls v26.4s, v10.4s, v15.s[1]\n"
+ "mls v27.4s, v10.4s, v15.s[1]\n"
+ "mls v28.4s, v10.4s, v15.s[2]\n"
+ "mls v29.4s, v10.4s, v15.s[2]\n"
+ "mls v30.4s, v10.4s, v15.s[3]\n"
+ "mls v31.4s, v10.4s, v15.s[3]\n"
+ "401:\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
+ "beq 402f\n"
+ "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
+ "add x2, x2, %x[row], lsl #2\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
+ // Load 4 lhs_sums values.
+ "ld1 {v11.4s}, [x2], #16\n"
+ "ld1 {v12.4s}, [x2]\n"
+ "ins v13.s[1], w5\n" // rhs_zero_point
+ // Compute lhs_sums * rhs_zero_point.
+ "mul v11.4s, v11.4s, v13.s[1]\n"
+ "mul v12.4s, v12.4s, v13.s[1]\n"
+ // Subtract lhs_sums * rhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "sub v16.4s, v16.4s, v11.4s\n"
+ "sub v17.4s, v17.4s, v12.4s\n"
+ "sub v18.4s, v18.4s, v11.4s\n"
+ "sub v19.4s, v19.4s, v12.4s\n"
+ "sub v20.4s, v20.4s, v11.4s\n"
+ "sub v21.4s, v21.4s, v12.4s\n"
+ "sub v22.4s, v22.4s, v11.4s\n"
+ "sub v23.4s, v23.4s, v12.4s\n"
+ "sub v24.4s, v24.4s, v11.4s\n"
+ "sub v25.4s, v25.4s, v12.4s\n"
+ "sub v26.4s, v26.4s, v11.4s\n"
+ "sub v27.4s, v27.4s, v12.4s\n"
+ "sub v28.4s, v28.4s, v11.4s\n"
+ "sub v29.4s, v29.4s, v12.4s\n"
+ "sub v30.4s, v30.4s, v11.4s\n"
+ "sub v31.4s, v31.4s, v12.4s\n"
+
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
+
+ "402:\n"
+
+ // At this point we have computed the final int32 values. Now we
+ // start down-quantizing them to obtain the final 8bit values from them.
+
+ // As part of this down-quantization, our int32 values will be
+ // multiplied by a multiplier that has a fixed-point component and an
+ // exponent component.
+
+ // Load the exponent part of the multiplier.
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "add x5, x1, %x[row], lsl #2\n"
+ "csel x1, x1, x5, eq\n"
+
+ "ldr q9, [x1]\n"
+ "ldr q10, [x1, #16]\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n"
+ "beq 403f\n"
+ "smax v11.4s, v9.4s, v8.4s\n"
+ "smax v12.4s, v10.4s, v8.4s\n"
+ "sshl v16.4s, v16.4s, v11.4s\n"
+ "sshl v17.4s, v17.4s, v12.4s\n"
+ "sshl v18.4s, v18.4s, v11.4s\n"
+ "sshl v19.4s, v19.4s, v12.4s\n"
+ "sshl v20.4s, v20.4s, v11.4s\n"
+ "sshl v21.4s, v21.4s, v12.4s\n"
+ "sshl v22.4s, v22.4s, v11.4s\n"
+ "sshl v23.4s, v23.4s, v12.4s\n"
+ "sshl v24.4s, v24.4s, v11.4s\n"
+ "sshl v25.4s, v25.4s, v12.4s\n"
+ "sshl v26.4s, v26.4s, v11.4s\n"
+ "sshl v27.4s, v27.4s, v12.4s\n"
+ "sshl v28.4s, v28.4s, v11.4s\n"
+ "sshl v29.4s, v29.4s, v12.4s\n"
+ "sshl v30.4s, v30.4s, v11.4s\n"
+ "sshl v31.4s, v31.4s, v12.4s\n"
+ "403:\n"
+
+ "ldr q14, [x4]\n" // multiplier_fixedpoint
+ "ldr q15, [x4, #16]\n" // multiplier_fixedpoint
+
+ "smin v11.4s, v9.4s, v8.4s\n"
+ "smin v12.4s, v10.4s, v8.4s\n"
+
+ // Apply the fixed-point part of the multiplier.
+ "sqrdmulh v16.4s, v16.4s, v14.4s\n"
+ "sqrdmulh v17.4s, v17.4s, v15.4s\n"
+ "sqrdmulh v18.4s, v18.4s, v14.4s\n"
+ "sqrdmulh v19.4s, v19.4s, v15.4s\n"
+ "sqrdmulh v20.4s, v20.4s, v14.4s\n"
+ "sqrdmulh v21.4s, v21.4s, v15.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v14.4s\n"
+ "sqrdmulh v23.4s, v23.4s, v15.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v14.4s\n"
+ "sqrdmulh v25.4s, v25.4s, v15.4s\n"
+ "sqrdmulh v26.4s, v26.4s, v14.4s\n"
+ "sqrdmulh v27.4s, v27.4s, v15.4s\n"
+ "sqrdmulh v28.4s, v28.4s, v14.4s\n"
+ "sqrdmulh v29.4s, v29.4s, v15.4s\n"
+ "sqrdmulh v30.4s, v30.4s, v14.4s\n"
+ "sqrdmulh v31.4s, v31.4s, v15.4s\n"
+
+ // We have some rounding division-by-power-of-two to do. This should
+ // always use "round to nearest". We allow for some freedom in how ties
+ // are broken, to strike a good compromise between performance on given
+ // hardware and perfect agreement of results across hardware.
+ //
+ // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for
+ // implementation-defined tie-breaks to help performance. On NEON, this
+ // means that we can just use the NEON rounding instructions, such as
+ // srshl. They happen to break ties upward.
+ //
+ // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
+ // break-ties-away-from-zero, as described in Appendix B of
+ // https://arxiv.org/pdf/1712.05877.pdf
+ // When we wrote that, we thought it would be less biased than the NEON
+ // upward tie-breaks, and we had observed some improvement on some model.
+ // However, that is only less biased for data centered at zero, which was
+ // likely the case in that model, but is not always the case. If we wanted
+ // something more consistently unbiased, we should try breaking ties
+ // toward nearest-even.
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
+ // Fix up the values to be right-shifted, so that the (round to nearest,
+ // break ties upward) behavior of srshl applied to these fixed-up values
+ // produces the same result as the desired (round to nearest, break ties
+ // away from zero) behavior on the original values.
+ "and v8.16b, v16.16b, v11.16b\n"
+ "and v9.16b, v17.16b, v12.16b\n"
+ "and v14.16b, v18.16b, v11.16b\n"
+ "and v15.16b, v19.16b, v12.16b\n"
+ "sshr v8.4s, v8.4s, #31\n"
+ "sshr v9.4s, v9.4s, #31\n"
+ "sshr v14.4s, v14.4s, #31\n"
+ "sshr v15.4s, v15.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v8.4s\n"
+ "sqadd v17.4s, v17.4s, v9.4s\n"
+ "sqadd v18.4s, v18.4s, v14.4s\n"
+ "sqadd v19.4s, v19.4s, v15.4s\n"
+ "and v8.16b, v20.16b, v11.16b\n"
+ "and v9.16b, v21.16b, v12.16b\n"
+ "and v14.16b, v22.16b, v11.16b\n"
+ "and v15.16b, v23.16b, v12.16b\n"
+ "sshr v8.4s, v8.4s, #31\n"
+ "sshr v9.4s, v9.4s, #31\n"
+ "sshr v14.4s, v14.4s, #31\n"
+ "sshr v15.4s, v15.4s, #31\n"
+ "sqadd v20.4s, v20.4s, v8.4s\n"
+ "sqadd v21.4s, v21.4s, v9.4s\n"
+ "sqadd v22.4s, v22.4s, v14.4s\n"
+ "sqadd v23.4s, v23.4s, v15.4s\n"
+ "and v8.16b, v24.16b, v11.16b\n"
+ "and v9.16b, v25.16b, v12.16b\n"
+ "and v14.16b, v26.16b, v11.16b\n"
+ "and v15.16b, v27.16b, v12.16b\n"
+ "sshr v8.4s, v8.4s, #31\n"
+ "sshr v9.4s, v9.4s, #31\n"
+ "sshr v14.4s, v14.4s, #31\n"
+ "sshr v15.4s, v15.4s, #31\n"
+ "sqadd v24.4s, v24.4s, v8.4s\n"
+ "sqadd v25.4s, v25.4s, v9.4s\n"
+ "sqadd v26.4s, v26.4s, v14.4s\n"
+ "sqadd v27.4s, v27.4s, v15.4s\n"
+ "and v8.16b, v28.16b, v11.16b\n"
+ "and v9.16b, v29.16b, v12.16b\n"
+ "and v14.16b, v30.16b, v11.16b\n"
+ "and v15.16b, v31.16b, v12.16b\n"
+ "sshr v8.4s, v8.4s, #31\n"
+ "sshr v9.4s, v9.4s, #31\n"
+ "sshr v14.4s, v14.4s, #31\n"
+ "sshr v15.4s, v15.4s, #31\n"
+ "sqadd v28.4s, v28.4s, v8.4s\n"
+ "sqadd v29.4s, v29.4s, v9.4s\n"
+ "sqadd v30.4s, v30.4s, v14.4s\n"
+ "sqadd v31.4s, v31.4s, v15.4s\n"
+#endif
+ // At this point we have reduced the problem of correctly implementing
+ // rounding divide-by-power-of-two, to what the SRSHL instruction can
+ // do.
+ "srshl v16.4s, v16.4s, v11.4s\n"
+ "srshl v17.4s, v17.4s, v12.4s\n"
+ "srshl v18.4s, v18.4s, v11.4s\n"
+ "srshl v19.4s, v19.4s, v12.4s\n"
+ "srshl v20.4s, v20.4s, v11.4s\n"
+ "srshl v21.4s, v21.4s, v12.4s\n"
+ "srshl v22.4s, v22.4s, v11.4s\n"
+ "srshl v23.4s, v23.4s, v12.4s\n"
+ "srshl v24.4s, v24.4s, v11.4s\n"
+ "srshl v25.4s, v25.4s, v12.4s\n"
+ "srshl v26.4s, v26.4s, v11.4s\n"
+ "srshl v27.4s, v27.4s, v12.4s\n"
+ "srshl v28.4s, v28.4s, v11.4s\n"
+ "srshl v29.4s, v29.4s, v12.4s\n"
+ "srshl v30.4s, v30.4s, v11.4s\n"
+ "srshl v31.4s, v31.4s, v12.4s\n"
+
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "sqxtn2 v17.8h, v19.4s\n"
+ "sqxtn v18.4h, v20.4s\n"
+ "sqxtn2 v18.8h, v21.4s\n"
+ "sqxtn v19.4h, v22.4s\n"
+ "sqxtn2 v19.8h, v23.4s\n"
+ "sqxtn v20.4h, v24.4s\n"
+ "sqxtn2 v20.8h, v25.4s\n"
+ "sqxtn v21.4h, v26.4s\n"
+ "sqxtn2 v21.8h, v27.4s\n"
+ "sqxtn v22.4h, v28.4s\n"
+ "sqxtn2 v22.8h, v29.4s\n"
+ "sqxtn v23.4h, v30.4s\n"
+ "sqxtn2 v23.8h, v31.4s\n"
+
+ // At this point, v24 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+ "add v17.8h, v17.8h, v14.8h\n"
+ "add v18.8h, v18.8h, v14.8h\n"
+ "add v19.8h, v19.8h, v14.8h\n"
+ "add v20.8h, v20.8h, v14.8h\n"
+ "add v21.8h, v21.8h, v14.8h\n"
+ "add v22.8h, v22.8h, v14.8h\n"
+ "add v23.8h, v23.8h, v14.8h\n"
+
+ // Cast-and-saturate from int16 to uint8
+ "sqxtun v16.8b, v16.8h\n"
+ "sqxtun2 v16.16b, v17.8h\n"
+ "sqxtun v17.8b, v18.8h\n"
+ "sqxtun2 v17.16b, v19.8h\n"
+ "sqxtun v18.8b, v20.8h\n"
+ "sqxtun2 v18.16b, v21.8h\n"
+ "sqxtun v19.8b, v22.8h\n"
+ "sqxtun2 v19.16b, v23.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "umax v16.16b, v16.16b, v14.16b\n"
+ "umax v17.16b, v17.16b, v14.16b\n"
+ "umax v18.16b, v18.16b, v14.16b\n"
+ "umax v19.16b, v19.16b, v14.16b\n"
+
+ // Apply the clamp_max bound
+ "umin v16.16b, v16.16b, v15.16b\n"
+ "umin v17.16b, v17.16b, v15.16b\n"
+ "umin v18.16b, v18.16b, v15.16b\n"
+ "umin v19.16b, v19.16b, v15.16b\n"
+
+ // Arrange for all of the final 8bit values to sit in the low 64 bits of
+ // 128bit NEON registers, so they can be written by 64bit st1 store
+ // instructions with byte alignment.
+ "dup d20, v16.d[1]\n"
+ "dup d21, v17.d[1]\n"
+ "dup d22, v18.d[1]\n"
+ "dup d23, v19.d[1]\n"
+
+ // Compute how much of the 8x8 block of destination 8bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, some 8x8 blocks along the boundaries do not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #8\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "31:\n"
+
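+ // By describing the destination as an (address, stride) pair in (x3, x4),
+ // the same st1 sequence below serves both the direct-store fast path and
+ // the dst_tmp_buf slow path.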
+ // Write our 8bit values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v16.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v20.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v20)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v17.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v17)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v21.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v18.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v18)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v22.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v22)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v19.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v19)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v23.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v23)
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 41f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #8\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "41:\n"
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "sqxtn2 v17.8h, v19.4s\n"
+ "sqxtn v18.4h, v20.4s\n"
+ "sqxtn2 v18.8h, v21.4s\n"
+ "sqxtn v19.4h, v22.4s\n"
+ "sqxtn2 v19.8h, v23.4s\n"
+ "sqxtn v20.4h, v24.4s\n"
+ "sqxtn2 v20.8h, v25.4s\n"
+ "sqxtn v21.4h, v26.4s\n"
+ "sqxtn2 v21.8h, v27.4s\n"
+ "sqxtn v22.4h, v28.4s\n"
+ "sqxtn2 v22.8h, v29.4s\n"
+ "sqxtn v23.4h, v30.4s\n"
+ "sqxtn2 v23.8h, v31.4s\n"
+
+ // At this point, v24 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+ "add v17.8h, v17.8h, v14.8h\n"
+ "add v18.8h, v18.8h, v14.8h\n"
+ "add v19.8h, v19.8h, v14.8h\n"
+ "add v20.8h, v20.8h, v14.8h\n"
+ "add v21.8h, v21.8h, v14.8h\n"
+ "add v22.8h, v22.8h, v14.8h\n"
+ "add v23.8h, v23.8h, v14.8h\n"
+
+ // Cast-and-saturate from int16 to int8
+ "sqxtn v16.8b, v16.8h\n"
+ "sqxtn2 v16.16b, v17.8h\n"
+ "sqxtn v17.8b, v18.8h\n"
+ "sqxtn2 v17.16b, v19.8h\n"
+ "sqxtn v18.8b, v20.8h\n"
+ "sqxtn2 v18.16b, v21.8h\n"
+ "sqxtn v19.8b, v22.8h\n"
+ "sqxtn2 v19.16b, v23.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.16b, v16.16b, v14.16b\n"
+ "smax v17.16b, v17.16b, v14.16b\n"
+ "smax v18.16b, v18.16b, v14.16b\n"
+ "smax v19.16b, v19.16b, v14.16b\n"
+
+ // Apply the clamp_max bound
+ "smin v16.16b, v16.16b, v15.16b\n"
+ "smin v17.16b, v17.16b, v15.16b\n"
+ "smin v18.16b, v18.16b, v15.16b\n"
+ "smin v19.16b, v19.16b, v15.16b\n"
+
+ // Arrange for all of the final 8bit values to sit in the low 64 bits of
+ // 128bit NEON registers, so they can be written by 64bit st1 store
+ // instructions with byte alignment.
+ "dup d20, v16.d[1]\n"
+ "dup d21, v17.d[1]\n"
+ "dup d22, v18.d[1]\n"
+ "dup d23, v19.d[1]\n"
+
+ // Compute how much of the 8x8 block of destination 8bit values that
+ // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, some 8x8 blocks along the boundaries do not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 130f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #8\n"
+ "b 131f\n"
+ "130:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "131:\n"
+
+ // Write our 8bit values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v16.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v20.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v20)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v17.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v17)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v21.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v18.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v18)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v22.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v22)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v19.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v19)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v23.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v23)
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 141f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "150:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "151:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 151b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #8\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 150b\n"
+ "141:\n"
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "saddw v16.4s, v16.4s, v14.4h\n"
+ "saddw v17.4s, v17.4s, v14.4h\n"
+ "saddw v18.4s, v18.4s, v14.4h\n"
+ "saddw v19.4s, v19.4s, v14.4h\n"
+ "saddw v20.4s, v20.4s, v14.4h\n"
+ "saddw v21.4s, v21.4s, v14.4h\n"
+ "saddw v22.4s, v22.4s, v14.4h\n"
+ "saddw v23.4s, v23.4s, v14.4h\n"
+ "saddw v24.4s, v24.4s, v14.4h\n"
+ "saddw v25.4s, v25.4s, v14.4h\n"
+ "saddw v26.4s, v26.4s, v14.4h\n"
+ "saddw v27.4s, v27.4s, v14.4h\n"
+ "saddw v28.4s, v28.4s, v14.4h\n"
+ "saddw v29.4s, v29.4s, v14.4h\n"
+ "saddw v30.4s, v30.4s, v14.4h\n"
+ "saddw v31.4s, v31.4s, v14.4h\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "sqxtn2 v17.8h, v19.4s\n"
+ "sqxtn v18.4h, v20.4s\n"
+ "sqxtn2 v18.8h, v21.4s\n"
+ "sqxtn v19.4h, v22.4s\n"
+ "sqxtn2 v19.8h, v23.4s\n"
+ "sqxtn v20.4h, v24.4s\n"
+ "sqxtn2 v20.8h, v25.4s\n"
+ "sqxtn v21.4h, v26.4s\n"
+ "sqxtn2 v21.8h, v27.4s\n"
+ "sqxtn v22.4h, v28.4s\n"
+ "sqxtn2 v22.8h, v29.4s\n"
+ "sqxtn v23.4h, v30.4s\n"
+ "sqxtn2 v23.8h, v31.4s\n"
+
+ // At this point, v24 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.8h, w2\n" // clamp_min
+ "dup v15.8h, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.8h, v16.8h, v14.8h\n"
+ "smax v17.8h, v17.8h, v14.8h\n"
+ "smax v18.8h, v18.8h, v14.8h\n"
+ "smax v19.8h, v19.8h, v14.8h\n"
+ "smax v20.8h, v20.8h, v14.8h\n"
+ "smax v21.8h, v21.8h, v14.8h\n"
+ "smax v22.8h, v22.8h, v14.8h\n"
+ "smax v23.8h, v23.8h, v14.8h\n"
+ // Apply the clamp_max bound
+ "smin v16.8h, v16.8h, v15.8h\n"
+ "smin v17.8h, v17.8h, v15.8h\n"
+ "smin v18.8h, v18.8h, v15.8h\n"
+ "smin v19.8h, v19.8h, v15.8h\n"
+ "smin v20.8h, v20.8h, v15.8h\n"
+ "smin v21.8h, v21.8h, v15.8h\n"
+ "smin v22.8h, v22.8h, v15.8h\n"
+ "smin v23.8h, v23.8h, v15.8h\n"
+
+ // Compute how much of the 8x8 block of destination 16bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+        // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 230f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #16\n"
+ "b 231f\n"
+ "230:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "231:\n"
+
+ // Write our 16bit values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v16.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v17.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v17)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v18.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v18)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v19.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v19)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v20.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v20)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v21.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v22.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v22)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v23.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v23)
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 241f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "250:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "251:\n"
+ "ldrsh w7, [x3, x5, lsl #1]\n"
+ "strh w7, [x4, x5, lsl #1]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 251b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #16\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 250b\n"
+ "241:\n"
+ "add %[dst_ptr], %[dst_ptr], #16\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+ // Since the store type is the same as the accum type, no need for
+ // downcast. There's also no need for clamp by min/max.
+
+        // Compute how much of the 8x8 block of destination 32bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+        // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 330f\n"
+ // Not all of the 8x8 block fits.
+ // Write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "st1 {v16.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v16)
+ "st1 {v17.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v17)
+ "st1 {v18.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v18)
+ "st1 {v19.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v19)
+ "st1 {v20.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v20)
+ "st1 {v21.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v21)
+ "st1 {v22.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v22)
+ "st1 {v23.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v23)
+ "st1 {v24.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v24)
+ "st1 {v25.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v25)
+ "st1 {v26.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v26)
+ "st1 {v27.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v27)
+ "st1 {v28.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v28)
+ "st1 {v29.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v29)
+ "st1 {v30.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v30)
+ "st1 {v31.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v31)
+
+ "b 331f\n"
+
+ "330:\n"
+ // Yes, all of the 8x8 block fits.
+ "mov x4, %[dst_ptr]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v16.4s, v17.4s}, [x3], #32\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ "add x4, x4, x11\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v18.4s, v19.4s}, [x3], #32\n"
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ "add x4, x4, x11\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v20.4s, v21.4s}, [x3], #32\n"
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ "add x4, x4, x11\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v22.4s, v23.4s}, [x3], #32\n"
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ "add x4, x4, x11\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v24.4s, v25.4s}, [x3], #32\n"
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ "add x4, x4, x11\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v26.4s, v27.4s}, [x3], #32\n"
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ "add x4, x4, x11\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v28.4s, v29.4s}, [x3], #32\n"
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ "add x4, x4, x11\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov x3, x4\n"
+ "st1 {v30.4s, v31.4s}, [x3], #32\n"
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ "331:\n"
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 341f\n"
+
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "350:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "351:\n"
+ "ldr w7, [x3, x5, lsl #2]\n"
+ "str w7, [x4, x5, lsl #2]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 351b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #32\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 350b\n"
+ "341:\n"
+ "add %[dst_ptr], %[dst_ptr], #32\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
+
+ // Reload some params --- we had used x5 -- x7 for a few other things
+ // since the last time we had loaded them.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "cmp %w[row], w7\n"
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add %w[row], %w[row], #8\n"
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ "mov %w[row], w6\n" // Move back to first row.
+ "add %w[col], %w[col], #8\n" // Move to the next column.
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
+ "mov %[dst_ptr], %[dst_col_ptr]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "cmp %w[col], w8\n"
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 4.
+ "mov w1, #4\n"
+
+ "ble 1b\n"
+
+ // clang-format on
+
+ : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
+ [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
+ [dst_type_id] "r"(params.dst_type_id)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+ "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
+ "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
+// Similar to the above 8-bit dotprod kernel, but specialized for the case of
+// RHS cols == 1.
+// Relevant target CPUs for this kernel include ARM Cortex-A76,
+// since these are 64-bit, out-of-order cores with dotprod support.
+void Kernel8bitNeonDotprodOutOfOrder1Col(const KernelParams8bit<8, 8>& params) {
+ profiler::ScopeLabel label(
+ "Kernel (kNeonDotprod, optimized for out-of-order cores)");
+
+ CheckOffsetsInKernelParams8bit(params);
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v31 are int32 accumulators.
+ // During accumulation, v0 -- v15 are used to load int8 data from LHS and
+  // RHS. At least v0 and v1 are used to load an 8x4 block of LHS, and v2
+  // is used to load a 4x1 block of RHS, like this:
+ //
+ // int8 RHS 4x1 block
+ // /-------\
+ // |v2.b[0]|
+ // | ... |
+ // |v2.b[3]|
+ // \-------/
+ // int8 LHS 8x4 block
+ // /---------------------\ /--------\
+ // |v0.b[0] ... v0.b[3] | |v16.s[0]|
+ // | ... ... | | ... |
+ // |v0.b[12] ... v0.b[15]| |v16.s[3]|
+ // |v1.b[0] ... v1.b[3] | |v17.s[0]|
+ // | ... ... | | ... |
+ // |v1.b[12] ... v1.b[15]| |v17.s[3]|
+ // \---------------------/ \--------/
+ // int32 accumulators 8x1 block
+ //
+ // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step
+ // is repeated 4 times, using 4x more registers for LHS and RHS, so that
+ // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15.
+ //
+ // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are
+ // unused, and v8 -- v15 are used for loading parameters used for the
+ // post-accumulation part of the kernel.
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+        // Load the first 32 bytes of LHS data and the first 8 bytes of RHS data.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.8b}, [%[rhs_ptr]]\n"
+ "add %[rhs_ptr], %[rhs_ptr], #32\n"
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 4.
+ "mov w1, #4\n"
+
+ // Perform the first few multiply-adds on the data that we have already
+ // loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ // Ordinary kernel inner loop (over depth), the simpler loop that the
+ // above was an equivalent 4x-partially-unrolled version of.
+
+ // Reminder - w1 is how many levels of depth we have already loaded
+ // data for, w12 is the total depth.
+ "cmp w1, w12\n"
+ "beq 79f\n"
+
+ "2:\n"
+
+ // Because of the data that we have already loaded, we can start the
+ // loop body right away with some multiply-adds.
+ // Each iteration of this loop advances by 4 levels of depth.
+ "add w1, w1, #4\n"
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
+ // Loop termination condition.
+ "cmp w1, w12\n"
+ "ld1 {v2.8b}, [%[rhs_ptr]]\n"
+ "add %[rhs_ptr], %[rhs_ptr], #32\n"
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+
+ "blt 2b\n"
+
+ "79:\n"
+ // End of the inner loop on depth. Now perform the remaining
+ // multiply-adds of the last 4 levels of depth, for which the LHS
+ // and RHS data is already loaded.
+
+ ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
+
+ // End of accumulation. The registers v16 -- v31 contain the final
+ // int32 accumulator values of the current 8x8 destination block.
+ // We now have to compute the final 8-bit values from these int32
+ // accumulators, and advance to the next 8x8 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "cmp %w[row], w7\n" // Have we finished the last row?
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "mov %[lhs_col_ptr], x5\n" // Go back to first row
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "cmp %w[col], w8\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+ "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+
+ // Load some parameters needed for the end work on current block.
+ RUY_MAKE_ZERO(v8)
+ "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
+ "ins v13.h[4], w4\n" // dst_zero_point
+ "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
+ "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "dup v9.4s, w3\n" // create prod_zp_depth_vec
+ "add x5, x4, %x[row], lsl #2\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "csel x4, x4, x5, eq\n"
+
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+ "add x5, x1, %x[row], lsl #2\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "csel x1, x1, x5, eq\n"
+
+ // Load 8 bias values.
+ "ld1 {v14.4s}, [x1], #16\n"
+ "ld1 {v15.4s}, [x1]\n"
+
+ // Now that we know what LHS and RHS data the next iteration of the
+        // main loop will need to load, we start loading the first 32 bytes of
+        // LHS and the first 8 bytes of RHS into v0 -- v2, as we don't need
+        // v0 -- v2 anymore in the rest of the work on the current block.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.8b}, [%[rhs_ptr]]\n"
+ "add %[rhs_ptr], %[rhs_ptr], #32\n"
+
+ // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
+ // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "add v14.4s, v14.4s, v9.4s\n"
+ "add v15.4s, v15.4s, v9.4s\n"
+
+ // Perform the bias-addition (per the above, we have just folded into
+ // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+ "add v16.4s, v16.4s, v14.4s\n"
+ "add v17.4s, v17.4s, v15.4s\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
+ "beq 401f\n"
+ "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
+ "add x3, x3, %x[col], lsl #2\n"
+ "ld1 {v14.4s}, [x3], #16\n"
+ "ld1 {v15.4s}, [x3]\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
+ "dup v10.4s, w5\n" // create lhs_zero_point_vec
+ // Subtract rhs_sums * lhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "mls v16.4s, v10.4s, v14.s[0]\n"
+ "mls v17.4s, v10.4s, v14.s[0]\n"
+ "401:\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
+ "beq 402f\n"
+ "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
+ "add x2, x2, %x[row], lsl #2\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
+ // Load 4 lhs_sums values.
+ "ld1 {v11.4s}, [x2], #16\n"
+ "ld1 {v12.4s}, [x2]\n"
+ "ins v13.s[1], w5\n" // rhs_zero_point
+ // Compute lhs_sums * rhs_zero_point.
+ "mul v11.4s, v11.4s, v13.s[1]\n"
+ "mul v12.4s, v12.4s, v13.s[1]\n"
+ // Subtract lhs_sums * rhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "sub v16.4s, v16.4s, v11.4s\n"
+ "sub v17.4s, v17.4s, v12.4s\n"
+
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
+
+ "402:\n"
+
+ // At this point we have computed the final int32 values. Now we
+ // start down-quantizing them to obtain the final 8bit values from them.
+
+ // As part of this down-quantization, our int32 values will be
+ // multiplied by a multiplier that has a fixed-point component and an
+ // exponent component.
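+
+        // In scalar terms this is roughly (editorial sketch, using
+        // gemmlowp-style helper names that are not defined in this file):
+        //   acc = SaturatingRoundingDoublingHighMul(acc << left_shift,
+        //                                           multiplier_fixedpoint);
+        //   acc = RoundingDivideByPOT(acc, right_shift);
+        // where left_shift = max(exponent, 0) and right_shift = -min(exponent, 0).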
+
+        // Load the exponent part of the multiplier.
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "add x5, x1, %x[row], lsl #2\n"
+ "csel x1, x1, x5, eq\n"
+
+ "ldr q9, [x1]\n"
+ "ldr q10, [x1, #16]\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n"
+ "beq 403f\n"
+ "smax v11.4s, v9.4s, v8.4s\n"
+ "smax v12.4s, v10.4s, v8.4s\n"
+ "sshl v16.4s, v16.4s, v11.4s\n"
+ "sshl v17.4s, v17.4s, v12.4s\n"
+ "403:\n"
+
+ "ldr q14, [x4]\n" // multiplier_fixedpoint
+ "ldr q15, [x4, #16]\n" // multiplier_fixedpoint
+
+ "smin v11.4s, v9.4s, v8.4s\n"
+ "smin v12.4s, v10.4s, v8.4s\n"
+
+ // Apply the fixed-point part of the multiplier.
+ "sqrdmulh v16.4s, v16.4s, v14.4s\n"
+ "sqrdmulh v17.4s, v17.4s, v15.4s\n"
+
+ // We have some rounding division-by-power-of-two to do. This should
+ // always use "round to nearest". We allow for some
+ // freedom in how ties are broken, to strike a good compromise of
+ // performance on given hardware vs. perfect agreement of results
+ // across hardware.
+ //
+ // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation
+ // defined tie-breaks to help performance. On NEON, this means that we
+ // can just use the NEON rounding instructions, such as srshl. They
+ // happen to be breaking ties upward.
+ //
+ // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
+ // break-ties-away-from zero, as described in Appendix B of
+ // https://arxiv.org/pdf/1712.05877.pdf
+        // When we wrote that, we thought that it would be less biased
+        // than the NEON upward tie-breaks, and we had observed some
+        // improvement on some model. However, that is only less biased for
+        // data centered at zero, which was likely the case in that model,
+        // but is not always the case. If we wanted something more consistently
+        // unbiased then we should try breaking ties toward nearest-even.
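+        //
+        // Worked example of the fixup below (editorial, illustrative numbers):
+        // with a right shift of 3 (shift register lane == -3), an accumulator
+        // value of -12 is exactly -1.5 after the shift; srshl alone rounds this
+        // tie upward to -1, while break-ties-away-from-zero requires -2. The
+        // and/sshr/sqadd sequence subtracts 1 from negative values that are
+        // about to be right-shifted, after which srshl yields -2 as desired,
+        // and non-tie values are unaffected.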
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
+ // Fix up values to be right-shifted, so that the (round to nearest,
+ // break ties upward) behavior of srshl applied to these fixed-up
+ // values, produces the same result as the desired (round to nearest,
+ // break ties away from zero) behavior on the original values.
+ "and v8.16b, v16.16b, v11.16b\n"
+ "and v9.16b, v17.16b, v12.16b\n"
+ "sshr v8.4s, v8.4s, #31\n"
+ "sshr v9.4s, v9.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v8.4s\n"
+ "sqadd v17.4s, v17.4s, v9.4s\n"
+
+#endif
+ // At this point we have reduced the problem of correctly implementing
+ // rounding divide-by-power-of-two, to what the SRSHL instruction can
+ // do.
+ "srshl v16.4s, v16.4s, v11.4s\n"
+ "srshl v17.4s, v17.4s, v12.4s\n"
+
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ // All data in v16 at this point.
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+
+ // Cast-and-saturate from int16 to uint8, leaving all data in the
+ // lower half of v16.
+ "sqxtun v16.8b, v16.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "umax v16.16b, v16.16b, v14.16b\n"
+
+ // Apply the clamp_max bound
+ "umin v16.16b, v16.16b, v15.16b\n"
+
+ // Make it so that all of the final 8bit values are stored in the
+ // first 64bits of 128bit NEON registers, so they can be stored
+ // by 64bit st1 store instructions with byte alignment.
+ "dup d20, v16.d[1]\n"
+
+ // Compute how much of the 8x1 block of destination 8bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x1, there are some 8x1 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x1 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+
+ // Test if w1==8, i.e. if all of the 8x1 block fits.
+ "cmp w1, w3\n"
+ // Yes, all of the 8x1 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #8\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "31:\n"
+
+ // Write our 8bit values to the destination
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v16.8b}, [x3]\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+
+        // If all of the 8x1 block fits, we just finished writing it to the
+        // destination, so we skip the next part.
+        "beq 41f\n"
+        // Not all of the 8x1 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "41:\n"
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "add v16.8h, v16.8h, v14.8h\n"
+
+ // Cast-and-saturate from int16 to uint8
+ "sqxtn v16.8b, v16.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.16b, v16.16b, v14.16b\n"
+
+ // Apply the clamp_max bound
+ "smin v16.16b, v16.16b, v15.16b\n"
+
+ // Make it so that all of the final 8bit values are stored in the
+ // first 64bits of 128bit NEON registers, so they can be stored
+ // by 64bit st1 store instructions with byte alignment.
+ "dup d20, v16.d[1]\n"
+
+ // Compute how much of the 8x1 block of destination 8bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+        // of 8x1, there are some 8x1 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x1 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+
+ // Test if w1==8, i.e. if all of the 8x1 block fits.
+ "cmp w1, w3\n"
+ // Yes, all of the 8x1 block fits, go to fast path.
+ "beq 130f\n"
+ // Not all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #8\n"
+ "b 131f\n"
+ "130:\n"
+        // Yes, all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "131:\n"
+
+ // Write our 8bit values to the destination
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v16.8b}, [x3]\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+
+        // If all of the 8x1 block fits, we just finished writing it to the
+        // destination, so we skip the next part.
+        "beq 141f\n"
+        // Not all of the 8x1 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "150:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "151:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 151b\n"
+ "141:\n"
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "saddw v16.4s, v16.4s, v14.4h\n"
+ "saddw v17.4s, v17.4s, v14.4h\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.8h, w2\n" // clamp_min
+ "dup v15.8h, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.8h, v16.8h, v14.8h\n"
+ // Apply the clamp_max bound
+ "smin v16.8h, v16.8h, v15.8h\n"
+
+ // Compute how much of the 8x1 block of destination 16bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+        // of 8x1, there are some 8x1 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x1 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+
+        // Test if w1==8, i.e. if all of the 8x1 block fits.
+ "cmp w1, w3\n"
+ // Yes, all of the 8x1 block fits, go to fast path.
+ "beq 230f\n"
+ // Not all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #16\n"
+ "b 231f\n"
+ "230:\n"
+ // Yes, all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "231:\n"
+
+ // Write our 16bit values to the destination
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v16.8h}, [x3]\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+
+ // If all of the 8x1 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 241f\n"
+ // Not all of the 8x1 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "250:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "251:\n"
+ "ldrsh w7, [x3, x5, lsl #1]\n"
+ "strh w7, [x4, x5, lsl #1]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 251b\n"
+ "241:\n"
+ "add %[dst_ptr], %[dst_ptr], #16\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+ // Since the store type is the same as the accum type, no need for
+ // downcast. There's also no need for clamp by min/max.
+
+ // Compute how much of the 8x1 block of destination 32 bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x1, there are some 8x1 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x1 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+        // Compute w2 = how many cols fit
+        "csel w2, w2, w3, le\n"
+
+        // Test if w1==8, i.e. if all of the 8x1 block fits.
+ "cmp w1, w3\n"
+ // Yes, all of the 8x1 block fits, go to fast path.
+ "beq 330f\n"
+ // Not all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #16\n"
+
+ // Write our 32bit values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v16.4s}, [x3], x4\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v17.4s}, [x3], x4\n"
+ RUY_MAKE_ZERO(v17)
+
+ "b 331f\n"
+
+ "330:\n"
+ // Yes, all of the 8x1 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x4, %[dst_ptr]\n"
+ "mov x3, x4\n"
+
+ // Write our 32bit values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v16.4s, v17.4s}, [x3], #32\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+
+ "331:\n"
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+
+        // If all of the 8x1 block fits, we just finished writing it to the
+        // destination, so we skip the next part.
+        "beq 341f\n"
+
+        // Not all of the 8x1 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "350:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "mov w5, #0\n"
+ "351:\n"
+ "ldr w7, [x3, x5, lsl #2]\n"
+ "str w7, [x4, x5, lsl #2]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 351b\n"
+ "341:\n"
+ "add %[dst_ptr], %[dst_ptr], #32\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
+
+ // Reload some params --- we had used x5 -- x7 for a few other things
+ // since the last time we had loaded them.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "cmp %w[row], w7\n"
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add %w[row], %w[row], #8\n"
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ "mov %w[row], w6\n" // Move back to first row.
+ "add %w[col], %w[col], #8\n" // Move to the next column.
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
+ "mov %[dst_ptr], %[dst_col_ptr]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "cmp %w[col], w8\n"
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 4.
+ "mov w1, #4\n"
+
+ "ble 1b\n"
+
+ // clang-format on
+
+ : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
+ [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
+ [dst_type_id] "r"(params.dst_type_id)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+ "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17");
+}
+
+// Variant of the above Kernel8bitNeonDotprodOutOfOrder, tuned for in-order
+// CPUs. Specifically here, the relevant in-order CPUs are ARM Cortex-A55r1,
+// since these are 64-bit and support dotprod.
+//
+// While this kernel does not have a direct equivalent in gemmlowp, it was
+// developed based on insights that David Mansell at ARM shared with their
+// contribution of gemmlowp kernels tuned for Cortex-A55r1, with very helpful
+// comments. Specifically, see this comment about tuning for Cortex-A55r1:
+// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412
+void Kernel8bitNeonDotprodInOrder(const KernelParams8bit<8, 8>& params) {
+ profiler::ScopeLabel label(
+ "Kernel (kNeonDotprod, optimized for in-order cores)");
+
+ CheckOffsetsInKernelParams8bit(params);
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v31 are int32 accumulators.
+ // During accumulation, v0 -- v3 are used to load int8 data from LHS and
+ // RHS.
+ //
+ // int8 RHS 4x8 block
+ // /-----------------------------------------\
+ // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]|
+ // | ... ... |
+ // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]|
+ // \-----------------------------------------/
+ // int8 LHS 8x4 block
+ // /---------------------\ /-----------------------------------------\
+ // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]|
+ // | ... ... | | ... ... |
+ // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]|
+ // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]|
+ // | ... ... | | ... ... |
+ // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]|
+ // \---------------------/ \-----------------------------------------/
+ // int32 accumulators 8x8 block
+ //
+ // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because
+ // we did not observe a benefit of such partial unrolling on in-order CPUs.
+ //
+ // v4 -- v7 are unused, and v8 -- v15 are used for loading parameters used for
+ // the post-accumulation part of the kernel.
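+  //
+  // Editorial note on the accumulator layout (restating the diagram above):
+  // column c of the 8x8 block uses the register pair (v(16+2c), v(17+2c)),
+  // with the even register holding rows 0..3 and the odd register holding
+  // rows 4..7. This is inferred from the sdot pattern below rather than
+  // stated explicitly in the original comment.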
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ RUY_MAKE_ZERO(v16)
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ RUY_MAKE_ZERO(v17)
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ RUY_MAKE_ZERO(v18)
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ RUY_MAKE_ZERO(v19)
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ RUY_MAKE_ZERO(v20)
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ RUY_MAKE_ZERO(v21)
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ RUY_MAKE_ZERO(v22)
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+ // Load the first 32 bytes of LHS and RHS data.
+ "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
+ "ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ // Perform the first few multiply-adds on the data that we have already
+ // loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ RUY_MAKE_ZERO(v28)
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ RUY_MAKE_ZERO(v29)
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ RUY_MAKE_ZERO(v30)
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+ RUY_MAKE_ZERO(v31)
+
+
+ "1:\n"
+
+ "add x5, %[lhs_ptr], x12, lsl #3\n"
+ "sub x5, x5, #32\n"
+ "cmp %[lhs_ptr], x5\n"
+
+ "beq 79f\n"
+
+ // Main accumulation loop
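+        // Editorial note (hedged): the loop below loads LHS/RHS data as 64-bit
+        // ldr d / ldr x pairs followed by ins into the upper half, instead of
+        // single 128-bit ld1 loads. This appears to follow the Cortex-A55r1
+        // tuning advice referenced above, where narrower loads interleave
+        // better with the sdot instructions on these in-order cores.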
+ "2:\n"
+ ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
+ "ldr x1, [%[lhs_ptr], #8]\n"
+ ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
+ "ldr x3, [%[rhs_ptr], #8]\n"
+ ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
+ "ldr x4, [%[rhs_ptr], #24]\n"
+ ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
+ "ldr d0, [%[lhs_ptr], #0]\n"
+ ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
+ "ins v0.d[1], x1\n"
+ ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
+ "ldr x2, [%[lhs_ptr], #24]\n"
+ ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
+ "add %[lhs_ptr], %[lhs_ptr], #32\n"
+ ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
+ "ldr d2, [%[rhs_ptr], #0]\n"
+ ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
+ "ins v2.d[1], x3\n"
+ ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
+ "cmp %[lhs_ptr], x5\n"
+ ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
+ "add %[rhs_ptr], %[rhs_ptr], #32\n"
+ ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
+ "ldr d3, [%[rhs_ptr], #-16]\n"
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ "ldr d1, [%[lhs_ptr], #-16]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ "ins v3.d[1], x4\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ "ins v1.d[1], x2\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+ "blt 2b\n"
+
+ // Last accumulation steps, nothing left to load.
+ "79:\n"
+ ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
+ "cmp %w[row], w7\n" // Have we finished the last row?
+ ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
+ ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
+ ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
+ ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
+ ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
+ ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
+ ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
+ ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
+ ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
+ ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
+
+ // End of accumulation. The registers v16 -- v31 contain the final
+ // int32 accumulator values of the current 8x8 destination block.
+ // We now have to compute the final 8-bit values from these int32
+ // accumulators, and advance to the next 8x8 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "mov %[lhs_col_ptr], x5\n" // Go back to first row
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "cmp %w[col], w8\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+ // Load some parameters needed for the end work on current block.
+ RUY_MAKE_ZERO(v8)
+ "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+ "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
+ "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
+ "ins v13.h[4], w4\n" // dst_zero_point
+ "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
+ "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "dup v9.4s, w3\n" // create prod_zp_depth_vec
+ "add x5, x4, %x[row], lsl #2\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "csel x4, x4, x5, eq\n"
+
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+ "add x5, x1, %x[row], lsl #2\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "csel x1, x1, x5, eq\n"
+
+ // Load 8 bias values.
+ "ld1 {v14.2s}, [x1], #8\n"
+ "ldr x5, [x1], #8\n"
+ "ins v14.d[1], x5\n"
+ "ld1 {v15.2s}, [x1], #8\n"
+ "ldr x5, [x1], #8\n"
+ "ins v15.d[1], x5\n"
+
+ // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
+ // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "add v14.4s, v14.4s, v9.4s\n"
+ "add v15.4s, v15.4s, v9.4s\n"
+ // Perform the bias-addition (per the above, we have just folded into
+ // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+ "add v16.4s, v16.4s, v14.4s\n"
+ "add v17.4s, v17.4s, v15.4s\n"
+ "add v18.4s, v18.4s, v14.4s\n"
+ "add v19.4s, v19.4s, v15.4s\n"
+ "add v20.4s, v20.4s, v14.4s\n"
+ "add v21.4s, v21.4s, v15.4s\n"
+ "add v22.4s, v22.4s, v14.4s\n"
+ "add v23.4s, v23.4s, v15.4s\n"
+ "add v24.4s, v24.4s, v14.4s\n"
+ "add v25.4s, v25.4s, v15.4s\n"
+ "add v26.4s, v26.4s, v14.4s\n"
+ "add v27.4s, v27.4s, v15.4s\n"
+ "add v28.4s, v28.4s, v14.4s\n"
+ "add v29.4s, v29.4s, v15.4s\n"
+ "add v30.4s, v30.4s, v14.4s\n"
+ "add v31.4s, v31.4s, v15.4s\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
+ "beq 401f\n"
+ "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
+ "add x3, x3, %x[col], lsl #2\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
+ "dup v10.4s, w5\n" // create lhs_zero_point_vec
+ // Load 8 rhs_sums values.
+ "ld1 {v14.2s}, [x3], #8\n"
+ "ldr x7, [x3], #8\n"
+ "ld1 {v15.2s}, [x3], #8\n"
+ "ins v14.d[1], x7\n"
+ "ldr x7, [x3], #8\n"
+ "ins v15.d[1], x7\n"
+ // Subtract rhs_sums * lhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "mls v16.4s, v10.4s, v14.s[0]\n"
+ "mls v17.4s, v10.4s, v14.s[0]\n"
+ "mls v18.4s, v10.4s, v14.s[1]\n"
+ "mls v19.4s, v10.4s, v14.s[1]\n"
+ "mls v20.4s, v10.4s, v14.s[2]\n"
+ "mls v21.4s, v10.4s, v14.s[2]\n"
+ "mls v22.4s, v10.4s, v14.s[3]\n"
+ "mls v23.4s, v10.4s, v14.s[3]\n"
+ "mls v24.4s, v10.4s, v15.s[0]\n"
+ "mls v25.4s, v10.4s, v15.s[0]\n"
+ "mls v26.4s, v10.4s, v15.s[1]\n"
+ "mls v27.4s, v10.4s, v15.s[1]\n"
+ "mls v28.4s, v10.4s, v15.s[2]\n"
+ "mls v29.4s, v10.4s, v15.s[2]\n"
+ "mls v30.4s, v10.4s, v15.s[3]\n"
+ "mls v31.4s, v10.4s, v15.s[3]\n"
+ "401:\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
+ "beq 402f\n"
+ "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
+ "add x2, x2, %x[row], lsl #2\n"
+ "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
+ "ins v13.s[1], w5\n" // rhs_zero_point
+ // Load 8 lhs_sums values.
+ "ld1 {v11.2s}, [x2], #8\n"
+ "ldr x6, [x2], #8\n"
+ "ins v11.d[1], x6\n"
+ "ld1 {v12.2s}, [x2], #8\n"
+ "ldr x6, [x2], #8\n"
+ "ins v12.d[1], x6\n"
+ // Compute lhs_sums * rhs_zero_point.
+ "mul v11.4s, v11.4s, v13.s[1]\n"
+ "mul v12.4s, v12.4s, v13.s[1]\n"
+ // Subtract lhs_sums * rhs_zero_point, per
+ // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ "sub v16.4s, v16.4s, v11.4s\n"
+ "sub v17.4s, v17.4s, v12.4s\n"
+ "sub v18.4s, v18.4s, v11.4s\n"
+ "sub v19.4s, v19.4s, v12.4s\n"
+ "sub v20.4s, v20.4s, v11.4s\n"
+ "sub v21.4s, v21.4s, v12.4s\n"
+ "sub v22.4s, v22.4s, v11.4s\n"
+ "sub v23.4s, v23.4s, v12.4s\n"
+ "sub v24.4s, v24.4s, v11.4s\n"
+ "sub v25.4s, v25.4s, v12.4s\n"
+ "sub v26.4s, v26.4s, v11.4s\n"
+ "sub v27.4s, v27.4s, v12.4s\n"
+ "sub v28.4s, v28.4s, v11.4s\n"
+ "sub v29.4s, v29.4s, v12.4s\n"
+ "sub v30.4s, v30.4s, v11.4s\n"
+ "sub v31.4s, v31.4s, v12.4s\n"
+
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
+
+ "402:\n"
+
+ // At this point we have computed the final int32 values. Now we
+ // start down-quantizing them to obtain the final 8bit values from them.
+
+ // As part of this down-quantization, our int32 values will be
+ // multiplied by a multiplier that has a fixed-point component and an
+ // exponent component.
+
+        // Load the exponent part of the multiplier.
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+ "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "add x5, x1, %x[row], lsl #2\n"
+ "csel x1, x1, x5, eq\n"
+
+ "ldr q9, [x1]\n"
+ "ldr q10, [x1, #16]\n"
+
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n"
+ "beq 403f\n"
+ "smax v11.4s, v9.4s, v8.4s\n"
+ "smax v12.4s, v10.4s, v8.4s\n"
+ "sshl v16.4s, v16.4s, v11.4s\n"
+ "sshl v17.4s, v17.4s, v12.4s\n"
+ "sshl v18.4s, v18.4s, v11.4s\n"
+ "sshl v19.4s, v19.4s, v12.4s\n"
+ "sshl v20.4s, v20.4s, v11.4s\n"
+ "sshl v21.4s, v21.4s, v12.4s\n"
+ "sshl v22.4s, v22.4s, v11.4s\n"
+ "sshl v23.4s, v23.4s, v12.4s\n"
+ "sshl v24.4s, v24.4s, v11.4s\n"
+ "sshl v25.4s, v25.4s, v12.4s\n"
+ "sshl v26.4s, v26.4s, v11.4s\n"
+ "sshl v27.4s, v27.4s, v12.4s\n"
+ "sshl v28.4s, v28.4s, v11.4s\n"
+ "sshl v29.4s, v29.4s, v12.4s\n"
+ "sshl v30.4s, v30.4s, v11.4s\n"
+ "sshl v31.4s, v31.4s, v12.4s\n"
+ "403:\n"
+
+ "ldr q14, [x4]\n" // multiplier_fixedpoint
+ "ldr q15, [x4, #16]\n" // multiplier_fixedpoint
+
+ "smin v11.4s, v9.4s, v8.4s\n"
+ "smin v12.4s, v10.4s, v8.4s\n"
+
+ // Apply the fixed-point part of the multiplier.
+ //
+ // ... and, interleaved into that:
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 32 bytes of
+ // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
+ // in the rest of the work on the current block.
+ "ld1 {v0.8b}, [%[lhs_ptr]], #8\n"
+ "sqrdmulh v16.4s, v16.4s, v14.4s\n"
+ "ldr x1, [%[lhs_ptr]], #8\n"
+ "sqrdmulh v17.4s, v17.4s, v15.4s\n"
+ "ld1 {v1.8b}, [%[lhs_ptr]], #8\n"
+ "sqrdmulh v18.4s, v18.4s, v14.4s\n"
+ "ldr x2, [%[lhs_ptr]], #8\n"
+ "sqrdmulh v19.4s, v19.4s, v15.4s\n"
+ "ld1 {v2.8b}, [%[rhs_ptr]], #8\n"
+ "sqrdmulh v20.4s, v20.4s, v14.4s\n"
+ "ldr x5, [%[rhs_ptr]], #8\n"
+ "sqrdmulh v21.4s, v21.4s, v15.4s\n"
+ "ld1 {v3.8b}, [%[rhs_ptr]], #8\n"
+ "sqrdmulh v22.4s, v22.4s, v14.4s\n"
+ "ldr x6, [%[rhs_ptr]], #8\n"
+ "sqrdmulh v23.4s, v23.4s, v15.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v14.4s\n"
+ "sqrdmulh v25.4s, v25.4s, v15.4s\n"
+ "sqrdmulh v26.4s, v26.4s, v14.4s\n"
+ "sqrdmulh v27.4s, v27.4s, v15.4s\n"
+ "sqrdmulh v28.4s, v28.4s, v14.4s\n"
+ "sqrdmulh v29.4s, v29.4s, v15.4s\n"
+ "sqrdmulh v30.4s, v30.4s, v14.4s\n"
+ "sqrdmulh v31.4s, v31.4s, v15.4s\n"
+
+ // We have some rounding division-by-power-of-two to do. This should
+ // always use "round to nearest". We allow for some
+ // freedom in how ties are broken, to strike a good compromise of
+ // performance on given hardware vs. perfect agreement of results
+ // across hardware.
+ //
+ // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation
+ // defined tie-breaks to help performance. On NEON, this means that we
+ // can just use the NEON rounding instructions, such as srshl. They
+ // happen to be breaking ties upward.
+ //
+ // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
+ // break-ties-away-from zero, as described in Appendix B of
+ // https://arxiv.org/pdf/1712.05877.pdf
+        // When we wrote that, we thought that it would be less biased
+        // than the NEON upward tie-breaks, and we had observed some
+        // improvement on some model. However, that is only less biased for
+        // data centered at zero, which was likely the case in that model,
+        // but is not always the case. If we wanted something more consistently
+        // unbiased then we should try breaking ties toward nearest-even.
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
+ // Fix up values to be right-shifted, so that the (round to nearest,
+ // break ties upward) behavior of srshl applied to these fixed-up
+ // values, produces the same result as the desired (round to nearest,
+ // break ties away from zero) behavior on the original values.
+ "and v8.16b, v16.16b, v11.16b\n"
+ "and v9.16b, v17.16b, v12.16b\n"
+ "and v14.16b, v18.16b, v11.16b\n"
+ "and v15.16b, v19.16b, v12.16b\n"
+ "sshr v8.4s, v8.4s, #31\n"
+ "sshr v9.4s, v9.4s, #31\n"
+ "sshr v14.4s, v14.4s, #31\n"
+ "sshr v15.4s, v15.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v8.4s\n"
+ "sqadd v17.4s, v17.4s, v9.4s\n"
+ "sqadd v18.4s, v18.4s, v14.4s\n"
+ "sqadd v19.4s, v19.4s, v15.4s\n"
+ "and v8.16b, v20.16b, v11.16b\n"
+ "and v9.16b, v21.16b, v12.16b\n"
+ "and v14.16b, v22.16b, v11.16b\n"
+ "and v15.16b, v23.16b, v12.16b\n"
+ "sshr v8.4s, v8.4s, #31\n"
+ "sshr v9.4s, v9.4s, #31\n"
+ "sshr v14.4s, v14.4s, #31\n"
+ "sshr v15.4s, v15.4s, #31\n"
+ "sqadd v20.4s, v20.4s, v8.4s\n"
+ "sqadd v21.4s, v21.4s, v9.4s\n"
+ "sqadd v22.4s, v22.4s, v14.4s\n"
+ "sqadd v23.4s, v23.4s, v15.4s\n"
+ "and v8.16b, v24.16b, v11.16b\n"
+ "and v9.16b, v25.16b, v12.16b\n"
+ "and v14.16b, v26.16b, v11.16b\n"
+ "and v15.16b, v27.16b, v12.16b\n"
+ "sshr v8.4s, v8.4s, #31\n"
+ "sshr v9.4s, v9.4s, #31\n"
+ "sshr v14.4s, v14.4s, #31\n"
+ "sshr v15.4s, v15.4s, #31\n"
+ "sqadd v24.4s, v24.4s, v8.4s\n"
+ "sqadd v25.4s, v25.4s, v9.4s\n"
+ "sqadd v26.4s, v26.4s, v14.4s\n"
+ "sqadd v27.4s, v27.4s, v15.4s\n"
+ "and v8.16b, v28.16b, v11.16b\n"
+ "and v9.16b, v29.16b, v12.16b\n"
+ "and v14.16b, v30.16b, v11.16b\n"
+ "and v15.16b, v31.16b, v12.16b\n"
+ "sshr v8.4s, v8.4s, #31\n"
+ "sshr v9.4s, v9.4s, #31\n"
+ "sshr v14.4s, v14.4s, #31\n"
+ "sshr v15.4s, v15.4s, #31\n"
+ "sqadd v28.4s, v28.4s, v8.4s\n"
+ "sqadd v29.4s, v29.4s, v9.4s\n"
+ "sqadd v30.4s, v30.4s, v14.4s\n"
+ "sqadd v31.4s, v31.4s, v15.4s\n"
+#endif
+ // At this point we have reduced the problem of correctly implementing
+ // rounding divide-by-power-of-two, to what the SRSHL instruction can
+ // do.
+ "srshl v16.4s, v16.4s, v11.4s\n"
+ "srshl v17.4s, v17.4s, v12.4s\n"
+ "srshl v18.4s, v18.4s, v11.4s\n"
+ "srshl v19.4s, v19.4s, v12.4s\n"
+ "srshl v20.4s, v20.4s, v11.4s\n"
+ "srshl v21.4s, v21.4s, v12.4s\n"
+ "srshl v22.4s, v22.4s, v11.4s\n"
+ "srshl v23.4s, v23.4s, v12.4s\n"
+ "srshl v24.4s, v24.4s, v11.4s\n"
+ "srshl v25.4s, v25.4s, v12.4s\n"
+ "srshl v26.4s, v26.4s, v11.4s\n"
+ "srshl v27.4s, v27.4s, v12.4s\n"
+ "ins v0.d[1], x1\n"
+ "srshl v28.4s, v28.4s, v11.4s\n"
+ "ins v1.d[1], x2\n"
+ "srshl v29.4s, v29.4s, v12.4s\n"
+ "ins v2.d[1], x5\n"
+ "srshl v30.4s, v30.4s, v11.4s\n"
+ "ins v3.d[1], x6\n"
+ "srshl v31.4s, v31.4s, v12.4s\n"
+
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
+ "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
+ "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "sqxtn2 v17.8h, v19.4s\n"
+ "sqxtn v18.4h, v20.4s\n"
+ "sqxtn2 v18.8h, v21.4s\n"
+ "sqxtn v19.4h, v22.4s\n"
+ "sqxtn2 v19.8h, v23.4s\n"
+ "sqxtn v20.4h, v24.4s\n"
+ "sqxtn2 v20.8h, v25.4s\n"
+ "sqxtn v21.4h, v26.4s\n"
+ "sqxtn2 v21.8h, v27.4s\n"
+ "sqxtn v22.4h, v28.4s\n"
+ "sqxtn2 v22.8h, v29.4s\n"
+ "sqxtn v23.4h, v30.4s\n"
+ "sqxtn2 v23.8h, v31.4s\n"
+
+ // Destination zero_point
+ "dup v14.8h, v13.h[4]\n"
+ // At this point, v24 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Add the destination zero point
+ "add v16.8h, v16.8h, v14.8h\n"
+ "add v17.8h, v17.8h, v14.8h\n"
+ "add v18.8h, v18.8h, v14.8h\n"
+ "add v19.8h, v19.8h, v14.8h\n"
+ "add v20.8h, v20.8h, v14.8h\n"
+ "add v21.8h, v21.8h, v14.8h\n"
+ "add v22.8h, v22.8h, v14.8h\n"
+ "add v23.8h, v23.8h, v14.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ // Cast-and-saturate from int16 to uint8
+ "sqxtun v16.8b, v16.8h\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "sqxtun2 v16.16b, v17.8h\n"
+ "sqxtun v17.8b, v18.8h\n"
+ "sqxtun2 v17.16b, v19.8h\n"
+ "sqxtun v18.8b, v20.8h\n"
+ "sqxtun2 v18.16b, v21.8h\n"
+ "sqxtun v19.8b, v22.8h\n"
+ "sqxtun2 v19.16b, v23.8h\n"
+
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Compute how much of the 8x8 block of destination 8bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
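+ // Added commentary, a scalar sketch of the interleaved csel logic below
+ // (illustrative names):
+ //   rows_fit = std::min(8, dst_rows - row);  // ends up in w1
+ //   cols_fit = std::min(8, dst_cols - col);  // ends up in w2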
+ "sub w1, %w[dst_rows], %w[row]\n"
+ // Apply the clamp_min bound
+ "umax v16.16b, v16.16b, v14.16b\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "umax v17.16b, v17.16b, v14.16b\n"
+ "mov w3, #8\n"
+ "umax v18.16b, v18.16b, v14.16b\n"
+ "cmp w1, #8\n"
+ "umax v19.16b, v19.16b, v14.16b\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ // Apply the clamp_max bound
+ "umin v16.16b, v16.16b, v15.16b\n"
+ "cmp w2, #8\n"
+ "umin v17.16b, v17.16b, v15.16b\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+ "umin v18.16b, v18.16b, v15.16b\n"
+ "umin v19.16b, v19.16b, v15.16b\n"
+
+ // Make it so that all of the final 8bit values are stored in the
+ // first 64bits of 128bit NEON registers, so they can be stored
+ // by 64bit st1 store instructions with byte alignment.
+ "dup d20, v16.d[1]\n"
+ "dup d21, v17.d[1]\n"
+ "dup d22, v18.d[1]\n"
+ "dup d23, v19.d[1]\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #8\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "31:\n"
+
+ // Write our 8bit values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v16.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v20.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v20)
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
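+ // (Added note: the sdot instructions are emitted as raw .word encodings
+ // so that this file still assembles with toolchains that do not know the
+ // dotprod extension.)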
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v17.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v17)
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v21.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v18.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v18)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v22.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v22)
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v19.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v19)
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v23.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v23)
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 41f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
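+ // Added commentary, a rough scalar equivalent of this copy loop
+ // (illustrative names; the destination is 8-bit and column-major with a
+ // stride of x11 bytes, and the staged block in dst_tmp_buf has a stride
+ // of 8 bytes):
+ //
+ //   for (int c = 0; c < cols_fit /* w2 */; ++c) {    // loop 50:
+ //     for (int r = 0; r < rows_fit /* w1 */; ++r) {  // loop 51:
+ //       dst_ptr[r + c * dst_stride] = dst_tmp_buf[r + c * 8];
+ //     }
+ //   }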
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #8\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "41:\n"
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "sqxtn2 v17.8h, v19.4s\n"
+ "sqxtn v18.4h, v20.4s\n"
+ "sqxtn2 v18.8h, v21.4s\n"
+ "sqxtn v19.4h, v22.4s\n"
+ "sqxtn2 v19.8h, v23.4s\n"
+ "sqxtn v20.4h, v24.4s\n"
+ "sqxtn2 v20.8h, v25.4s\n"
+ "sqxtn v21.4h, v26.4s\n"
+ "sqxtn2 v21.8h, v27.4s\n"
+ "sqxtn v22.4h, v28.4s\n"
+ "sqxtn2 v22.8h, v29.4s\n"
+ "sqxtn v23.4h, v30.4s\n"
+ "sqxtn2 v23.8h, v31.4s\n"
+
+ // Destination zero_point
+ "dup v14.8h, v13.h[4]\n"
+ // At this point, v24 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Add the destination zero point
+ "add v16.8h, v16.8h, v14.8h\n"
+ "add v17.8h, v17.8h, v14.8h\n"
+ "add v18.8h, v18.8h, v14.8h\n"
+ "add v19.8h, v19.8h, v14.8h\n"
+ "add v20.8h, v20.8h, v14.8h\n"
+ "add v21.8h, v21.8h, v14.8h\n"
+ "add v22.8h, v22.8h, v14.8h\n"
+ "add v23.8h, v23.8h, v14.8h\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ // Cast-and-saturate from int16 to int8
+ "sqxtn v16.8b, v16.8h\n"
+ "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "sqxtn2 v16.16b, v17.8h\n"
+ "sqxtn v17.8b, v18.8h\n"
+ "sqxtn2 v17.16b, v19.8h\n"
+ "sqxtn v18.8b, v20.8h\n"
+ "sqxtn2 v18.16b, v21.8h\n"
+ "sqxtn v19.8b, v22.8h\n"
+ "sqxtn2 v19.16b, v23.8h\n"
+
+ "dup v14.16b, w2\n" // clamp_min
+ "dup v15.16b, w3\n" // clamp_max
+
+ // Compute how much of the 8x8 block of destination 8bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ // Apply the clamp_min bound
+ "smax v16.16b, v16.16b, v14.16b\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "smax v17.16b, v17.16b, v14.16b\n"
+ "mov w3, #8\n"
+ "smax v18.16b, v18.16b, v14.16b\n"
+ "cmp w1, #8\n"
+ "smax v19.16b, v19.16b, v14.16b\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ // Apply the clamp_max bound
+ "smin v16.16b, v16.16b, v15.16b\n"
+ "cmp w2, #8\n"
+ "smin v17.16b, v17.16b, v15.16b\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+ "smin v18.16b, v18.16b, v15.16b\n"
+ "smin v19.16b, v19.16b, v15.16b\n"
+
+ // Make it so that all of the final 8bit values are stored in the
+ // first 64bits of 128bit NEON registers, so they can be stored
+ // by 64bit st1 store instructions with byte alignment.
+ "dup d20, v16.d[1]\n"
+ "dup d21, v17.d[1]\n"
+ "dup d22, v18.d[1]\n"
+ "dup d23, v19.d[1]\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 130f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #8\n"
+ "b 131f\n"
+ "130:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "131:\n"
+
+ // Write our 8bit values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v16.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v20.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v20)
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v17.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v17)
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v21.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v18.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v18)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v22.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v22)
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v19.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v19)
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v23.8b}, [x3], x4\n"
+ RUY_MAKE_ZERO(v23)
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 141f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "150:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "151:\n"
+ "ldrb w7, [x3, w5, uxtw]\n"
+ "strb w7, [x4, w5, uxtw]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 151b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #8\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 150b\n"
+ "141:\n"
+ "add %[dst_ptr], %[dst_ptr], #8\n"
+
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
+
+ // Add the destination zero point
+ "dup v14.8h, v13.h[4]\n"
+ "saddw v16.4s, v16.4s, v14.4h\n"
+ "saddw v17.4s, v17.4s, v14.4h\n"
+ "saddw v18.4s, v18.4s, v14.4h\n"
+ "saddw v19.4s, v19.4s, v14.4h\n"
+ "saddw v20.4s, v20.4s, v14.4h\n"
+ "saddw v21.4s, v21.4s, v14.4h\n"
+ "saddw v22.4s, v22.4s, v14.4h\n"
+ "saddw v23.4s, v23.4s, v14.4h\n"
+ "saddw v24.4s, v24.4s, v14.4h\n"
+ "saddw v25.4s, v25.4s, v14.4h\n"
+ "saddw v26.4s, v26.4s, v14.4h\n"
+ "saddw v27.4s, v27.4s, v14.4h\n"
+ "saddw v28.4s, v28.4s, v14.4h\n"
+ "saddw v29.4s, v29.4s, v14.4h\n"
+ "saddw v30.4s, v30.4s, v14.4h\n"
+ "saddw v31.4s, v31.4s, v14.4h\n"
+
+ // Cast-and-saturate from int32 to int16
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtn v17.4h, v18.4s\n"
+ "sqxtn2 v17.8h, v19.4s\n"
+ "sqxtn v18.4h, v20.4s\n"
+ "sqxtn2 v18.8h, v21.4s\n"
+ "sqxtn v19.4h, v22.4s\n"
+ "sqxtn2 v19.8h, v23.4s\n"
+ "sqxtn v20.4h, v24.4s\n"
+ "sqxtn2 v20.8h, v25.4s\n"
+ "sqxtn v21.4h, v26.4s\n"
+ "sqxtn2 v21.8h, v27.4s\n"
+ "sqxtn v22.4h, v28.4s\n"
+ "sqxtn2 v22.8h, v29.4s\n"
+ "sqxtn v23.4h, v30.4s\n"
+ "sqxtn2 v23.8h, v31.4s\n"
+
+ // At this point, v24 -- v31 aren't used anymore for the current block,
+ // so we can start clearing these accumulators for the next block
+ // (next iteration of the main loop).
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // Load the clamp_min, clamp_max bounds
+ "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.8h, w2\n" // clamp_min
+ "dup v15.8h, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "smax v16.8h, v16.8h, v14.8h\n"
+ "smax v17.8h, v17.8h, v14.8h\n"
+ "smax v18.8h, v18.8h, v14.8h\n"
+ "smax v19.8h, v19.8h, v14.8h\n"
+ "smax v20.8h, v20.8h, v14.8h\n"
+ "smax v21.8h, v21.8h, v14.8h\n"
+ "smax v22.8h, v22.8h, v14.8h\n"
+ "smax v23.8h, v23.8h, v14.8h\n"
+ // Apply the clamp_max bound
+ "smin v16.8h, v16.8h, v15.8h\n"
+ "smin v17.8h, v17.8h, v15.8h\n"
+ "smin v18.8h, v18.8h, v15.8h\n"
+ "smin v19.8h, v19.8h, v15.8h\n"
+ "smin v20.8h, v20.8h, v15.8h\n"
+ "smin v21.8h, v21.8h, v15.8h\n"
+ "smin v22.8h, v22.8h, v15.8h\n"
+ "smin v23.8h, v23.8h, v15.8h\n"
+
+ // Compute how much of the 8x8 block of destination 16bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 230f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #16\n"
+ "b 231f\n"
+ "230:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "231:\n"
+
+ // Write our 16bit values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v16.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v17.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v17)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v18.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v18)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v19.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v19)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v20.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v20)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v21.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v22.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v22)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "st1 {v23.8h}, [x3], x4\n"
+ RUY_MAKE_ZERO(v23)
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 241f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "250:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "251:\n"
+ "ldrsh w7, [x3, x5, lsl #1]\n"
+ "strh w7, [x4, x5, lsl #1]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 251b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #16\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 250b\n"
+ "241:\n"
+ "add %[dst_ptr], %[dst_ptr], #16\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
+
+ RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
+
+ "ld1 {v0.8b}, [%[lhs_ptr]], #8\n"
+ "ldr x1, [%[lhs_ptr]], #8\n"
+ "ld1 {v1.8b}, [%[lhs_ptr]], #8\n"
+ "ldr x2, [%[lhs_ptr]], #8\n"
+ "ld1 {v2.8b}, [%[rhs_ptr]], #8\n"
+ "ldr x5, [%[rhs_ptr]], #8\n"
+ "ld1 {v3.8b}, [%[rhs_ptr]], #8\n"
+ "ldr x6, [%[rhs_ptr]], #8\n"
+ "ins v0.d[1], x1\n"
+ "ins v1.d[1], x2\n"
+ "ins v2.d[1], x5\n"
+ "ins v3.d[1], x6\n"
+
+ // Since the store type is the same as the accum type, no need for
+ // downcast. There's also no need for clamp by min/max.
+
+ // Compute how much of the 8x8 block of destination 32bit values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 330f\n"
+ // Not all of the 8x8 block fits.
+ // Write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "st1 {v16.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v16)
+ "st1 {v17.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v17)
+ "st1 {v18.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v18)
+ "st1 {v19.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v19)
+ "st1 {v20.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v20)
+ "st1 {v21.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v21)
+ "st1 {v22.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v22)
+ "st1 {v23.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v23)
+ "st1 {v24.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v24)
+ "st1 {v25.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v25)
+ "st1 {v26.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v26)
+ "st1 {v27.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v27)
+ "st1 {v28.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v28)
+ "st1 {v29.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v29)
+ "st1 {v30.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v30)
+ "st1 {v31.4s}, [x3], #16\n"
+ RUY_MAKE_ZERO(v31)
+
+ "b 331f\n"
+
+ "330:\n"
+ // Yes, all of the 8x8 block fits.
+ "mov x4, %[dst_ptr]\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v16.4s, v17.4s}, [x4], x11\n"
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v18.4s, v19.4s}, [x4], x11\n"
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v20.4s, v21.4s}, [x4], x11\n"
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v22.4s, v23.4s}, [x4], x11\n"
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v24.4s, v25.4s}, [x4], x11\n"
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v26.4s, v27.4s}, [x4], x11\n"
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v28.4s, v29.4s}, [x4], x11\n"
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "st1 {v30.4s, v31.4s}, [x4], x11\n"
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ "331:\n"
+
+ // For the next block: perform the first few multiply-adds on the data
+ // that we have already loaded.
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
+ ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
+ ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
+ ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 341f\n"
+
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "350:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "351:\n"
+ "ldr w7, [x3, x5, lsl #2]\n"
+ "str w7, [x4, x5, lsl #2]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 351b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #32\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 350b\n"
+ "341:\n"
+ "add %[dst_ptr], %[dst_ptr], #32\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
+
+ // Reload some params --- we had used x5 -- x7 for a few other things
+ // since the last time we had loaded them.
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "cmp %w[row], w7\n"
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add %w[row], %w[row], #8\n"
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ "mov %w[row], w6\n" // Move back to first row.
+ "add %w[col], %w[col], #8\n" // Move to the next column.
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
+ "mov %[dst_ptr], %[dst_col_ptr]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "cmp %w[col], w8\n"
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ble 1b\n"
+
+ // clang-format on
+
+ : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
+ [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
+ [dst_type_id] "r"(params.dst_type_id)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+ "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
+ "v26", "v27", "v28", "v29", "v30", "v31");
+}
+#undef RUY_OFFSET_BIAS
+#undef RUY_OFFSET_LHS_SUMS
+#undef RUY_OFFSET_RHS_SUMS
+#undef RUY_OFFSET_LHS_BASE_PTR
+#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT
+#undef RUY_OFFSET_MULTIPLIER_EXPONENT
+#undef RUY_OFFSET_RHS_BASE_PTR
+#undef RUY_OFFSET_DST_BASE_PTR
+#undef RUY_OFFSET_LHS_ZERO_POINT
+#undef RUY_OFFSET_RHS_ZERO_POINT
+#undef RUY_OFFSET_DST_ZERO_POINT
+#undef RUY_OFFSET_PROD_ZP_DEPTH
+#undef RUY_OFFSET_START_ROW
+#undef RUY_OFFSET_START_COL
+#undef RUY_OFFSET_LAST_ROW
+#undef RUY_OFFSET_LAST_COL
+#undef RUY_OFFSET_DST_ROWS
+#undef RUY_OFFSET_DST_COLS
+#undef RUY_OFFSET_LHS_STRIDE
+#undef RUY_OFFSET_RHS_STRIDE
+#undef RUY_OFFSET_DST_STRIDE
+#undef RUY_OFFSET_DEPTH
+#undef RUY_OFFSET_CLAMP_MIN
+#undef RUY_OFFSET_CLAMP_MAX
+#undef RUY_OFFSET_FLAGS
+
+#define RUY_OFFSET_LHS_BASE_PTR 0
+#define RUY_OFFSET_RHS_BASE_PTR 8
+#define RUY_OFFSET_DST_BASE_PTR 16
+#define RUY_OFFSET_BIAS 24
+#define RUY_OFFSET_START_ROW 32
+#define RUY_OFFSET_START_COL 36
+#define RUY_OFFSET_LAST_ROW 40
+#define RUY_OFFSET_LAST_COL 44
+#define RUY_OFFSET_LHS_STRIDE 56
+#define RUY_OFFSET_RHS_STRIDE 60
+#define RUY_OFFSET_DST_STRIDE 64
+#define RUY_OFFSET_DEPTH 68
+#define RUY_OFFSET_CLAMP_MIN 72
+#define RUY_OFFSET_CLAMP_MAX 76
+#define RUY_OFFSET_FLAGS 80
+
+template <typename Params>
+void CheckOffsetsInKernelParamsFloat(const Params&) {
+ static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
+ static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, "");
+ static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, "");
+ static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
+ static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
+ static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, "");
+ static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
+ static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
+ static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
+ static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
+ static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
+ static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
+ static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
+ static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
+ static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
+}
+
+// Just a plain float kernel; good enough for out-of-order cores.
+// The closest to it in the gemmlowp collection would be
+// NEON_64bit_GEMM_Float32_WithScalar,
+// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3925
+//
+// Besides ruy-ification, the main nuance here is that we stick to an 8x8
+// width instead of the wider 12x8 that the register space permits and that
+// the aforementioned gemmlowp kernel uses. Ruy likes powers of two for now
+// and we don't have evidence that going beyond 8x8 is needed.
+void KernelFloatNeonOutOfOrder(const KernelParamsFloat<8, 8>& params) {
+ CheckOffsetsInKernelParamsFloat(params);
+ profiler::ScopeLabel label(
+ "Kernel (kNeon, optimized for out-of-order cores)");
+
+ const float* lhs_col_ptr = params.lhs_base_ptr;
+ const float* rhs_col_ptr = params.rhs_base_ptr;
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ float* dst_col_ptr = params.dst_base_ptr;
+ float* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v31 are accumulators.
+ // During accumulation, v0 -- v15 are used to load data from LHS and RHS.
+ // At least v0 and v1 are used to load an 8x1 block of LHS, and v2 and
+ // v3 are used to load a 1x8 block of RHS, like this:
+ //
+ // RHS 1x8 block
+ // /-----------------------------------------\
+ // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]|
+ // \-----------------------------------------/
+ // LHS 8x1 block
+ // /---------------------\ /-----------------------------------------\
+ // | v0.s[0] | |v16.s[0] ... v30.s[0]|
+ // | ... | | ... ... |
+ // | v0.s[3] | |v16.s[3] ... v30.s[3]|
+ // | v1.s[0] | |v17.s[0] ... v31.s[0]|
+ // | ... | | ... ... |
+ // | v1.s[3] | |v17.s[3] ... v31.s[3]|
+ // \---------------------/ \-----------------------------------------/
+ // accumulators 8x8 block
+ //
+ // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step
+ // is repeated 4 times, using 4x more registers for LHS and RHS, so that
+ // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15.
+ //
+ // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are
+ // unused, and v8 -- v15 are used to load parameters for the
+ // post-accumulation part of the kernel.
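+ //
+ // Added commentary, a scalar sketch of the elementary step described
+ // above (illustrative names): for one level of depth, the 8x1 LHS column
+ // lhs[0..7] and the 1x8 RHS row rhs[0..7] update the column-major 8x8
+ // accumulator block acc[] by an outer product:
+ //
+ //   for (int c = 0; c < 8; ++c) {
+ //     for (int r = 0; r < 8; ++r) {
+ //       acc[r + 8 * c] += lhs[r] * rhs[c];
+ //     }
+ //   }
+ //
+ // Each "fmla vA.4s, vL.4s, vR.s[j]" below performs 4 of these
+ // multiply-adds at once.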
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+ // Load the first 32 bytes of LHS and RHS data.
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
+ "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 1.
+ "mov w1, #1\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ "fmla v16.4s, v0.4s, v2.s[0]\n"
+ "fmla v18.4s, v0.4s, v2.s[1]\n"
+ "fmla v20.4s, v0.4s, v2.s[2]\n"
+ "fmla v22.4s, v0.4s, v2.s[3]\n"
+
+#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING)
+ "cmp w12, #8\n"
+ "blt 78f\n"
+ "and w2, w12, #-4\n"
+
+ "ld1 {v4.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v5.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v6.4s}, [%[rhs_ptr]], #16\n"
+ "ld1 {v7.4s}, [%[rhs_ptr]], #16\n"
+
+ "ld1 {v8.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v9.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v10.4s}, [%[rhs_ptr]], #16\n"
+ "ld1 {v11.4s}, [%[rhs_ptr]], #16\n"
+
+ "ld1 {v12.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v13.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v14.4s}, [%[rhs_ptr]], #16\n"
+ "ld1 {v15.4s}, [%[rhs_ptr]], #16\n"
+ "mov w1, #4\n"
+
+ "80:\n"
+
+ "add %[lhs_ptr], %[lhs_ptr], #128\n"
+ "add %[rhs_ptr], %[rhs_ptr], #128\n"
+
+ "fmla v24.4s, v0.4s, v3.s[0]\n"
+ "fmla v26.4s, v0.4s, v3.s[1]\n"
+ "fmla v28.4s, v0.4s, v3.s[2]\n"
+ "fmla v30.4s, v0.4s, v3.s[3]\n"
+ "ldr q0, [%[lhs_ptr], #-128]\n"
+ "fmla v25.4s, v1.4s, v3.s[0]\n"
+ "fmla v27.4s, v1.4s, v3.s[1]\n"
+ "fmla v29.4s, v1.4s, v3.s[2]\n"
+ "fmla v31.4s, v1.4s, v3.s[3]\n"
+ "ldr q3, [%[rhs_ptr], #-112]\n"
+ "fmla v17.4s, v1.4s, v2.s[0]\n"
+ "fmla v19.4s, v1.4s, v2.s[1]\n"
+ "fmla v21.4s, v1.4s, v2.s[2]\n"
+ "fmla v23.4s, v1.4s, v2.s[3]\n"
+ "ldr q1, [%[lhs_ptr], #-112]\n"
+ "fmla v16.4s, v4.4s, v6.s[0]\n"
+ "fmla v18.4s, v4.4s, v6.s[1]\n"
+ "ldr q2, [%[rhs_ptr], #-128]\n"
+ "fmla v20.4s, v4.4s, v6.s[2]\n"
+ "fmla v22.4s, v4.4s, v6.s[3]\n"
+
+ "fmla v24.4s, v4.4s, v7.s[0]\n"
+ "fmla v26.4s, v4.4s, v7.s[1]\n"
+ "fmla v28.4s, v4.4s, v7.s[2]\n"
+ "fmla v30.4s, v4.4s, v7.s[3]\n"
+ "ldr q4, [%[lhs_ptr], #-96]\n"
+ "fmla v25.4s, v5.4s, v7.s[0]\n"
+ "fmla v27.4s, v5.4s, v7.s[1]\n"
+ "fmla v29.4s, v5.4s, v7.s[2]\n"
+ "fmla v31.4s, v5.4s, v7.s[3]\n"
+ "ldr q7, [%[rhs_ptr], #-80]\n"
+ "fmla v17.4s, v5.4s, v6.s[0]\n"
+ "fmla v19.4s, v5.4s, v6.s[1]\n"
+ "fmla v21.4s, v5.4s, v6.s[2]\n"
+ "fmla v23.4s, v5.4s, v6.s[3]\n"
+ "ldr q5, [%[lhs_ptr], #-80]\n"
+ "fmla v16.4s, v8.4s, v10.s[0]\n"
+ "fmla v18.4s, v8.4s, v10.s[1]\n"
+ "ldr q6, [%[rhs_ptr], #-96]\n"
+ "fmla v20.4s, v8.4s, v10.s[2]\n"
+ "fmla v22.4s, v8.4s, v10.s[3]\n"
+
+ "fmla v24.4s, v8.4s, v11.s[0]\n"
+ "fmla v26.4s, v8.4s, v11.s[1]\n"
+ "fmla v28.4s, v8.4s, v11.s[2]\n"
+ "fmla v30.4s, v8.4s, v11.s[3]\n"
+ "ldr q8, [%[lhs_ptr], #-64]\n"
+ "fmla v25.4s, v9.4s, v11.s[0]\n"
+ "fmla v27.4s, v9.4s, v11.s[1]\n"
+ "fmla v29.4s, v9.4s, v11.s[2]\n"
+ "fmla v31.4s, v9.4s, v11.s[3]\n"
+ "ldr q11, [%[rhs_ptr], #-48]\n"
+ "fmla v17.4s, v9.4s, v10.s[0]\n"
+ "fmla v19.4s, v9.4s, v10.s[1]\n"
+ "fmla v21.4s, v9.4s, v10.s[2]\n"
+ "fmla v23.4s, v9.4s, v10.s[3]\n"
+ "ldr q9, [%[lhs_ptr], #-48]\n"
+ "fmla v16.4s, v12.4s, v14.s[0]\n"
+ "fmla v18.4s, v12.4s, v14.s[1]\n"
+ "ldr q10, [%[rhs_ptr], #-64]\n"
+ "fmla v20.4s, v12.4s, v14.s[2]\n"
+ "fmla v22.4s, v12.4s, v14.s[3]\n"
+
+ "fmla v24.4s, v12.4s, v15.s[0]\n"
+ "fmla v26.4s, v12.4s, v15.s[1]\n"
+ "fmla v28.4s, v12.4s, v15.s[2]\n"
+ "fmla v30.4s, v12.4s, v15.s[3]\n"
+ "ldr q12, [%[lhs_ptr], #-32]\n"
+ "fmla v25.4s, v13.4s, v15.s[0]\n"
+ "fmla v27.4s, v13.4s, v15.s[1]\n"
+ "fmla v29.4s, v13.4s, v15.s[2]\n"
+ "fmla v31.4s, v13.4s, v15.s[3]\n"
+ "ldr q15, [%[rhs_ptr], #-16]\n"
+ "fmla v17.4s, v13.4s, v14.s[0]\n"
+ "fmla v19.4s, v13.4s, v14.s[1]\n"
+ "fmla v21.4s, v13.4s, v14.s[2]\n"
+ "fmla v23.4s, v13.4s, v14.s[3]\n"
+ "ldr q13, [%[lhs_ptr], #-16]\n"
+ "fmla v16.4s, v0.4s, v2.s[0]\n"
+ "fmla v18.4s, v0.4s, v2.s[1]\n"
+ "ldr q14, [%[rhs_ptr], #-32]\n"
+ "fmla v20.4s, v0.4s, v2.s[2]\n"
+ "fmla v22.4s, v0.4s, v2.s[3]\n"
+
+ "add w1, w1, #4\n"
+ "cmp w1, w2\n"
+ "blt 80b\n"
+
+ "fmla v16.4s, v4.4s, v6.s[0]\n"
+ "fmla v18.4s, v4.4s, v6.s[1]\n"
+ "fmla v20.4s, v4.4s, v6.s[2]\n"
+ "fmla v22.4s, v4.4s, v6.s[3]\n"
+ "fmla v24.4s, v4.4s, v7.s[0]\n"
+ "fmla v26.4s, v4.4s, v7.s[1]\n"
+ "fmla v28.4s, v4.4s, v7.s[2]\n"
+ "fmla v30.4s, v4.4s, v7.s[3]\n"
+ "fmla v25.4s, v5.4s, v7.s[0]\n"
+ "fmla v27.4s, v5.4s, v7.s[1]\n"
+ "fmla v29.4s, v5.4s, v7.s[2]\n"
+ "fmla v31.4s, v5.4s, v7.s[3]\n"
+ "fmla v17.4s, v5.4s, v6.s[0]\n"
+ "fmla v19.4s, v5.4s, v6.s[1]\n"
+ "fmla v21.4s, v5.4s, v6.s[2]\n"
+ "fmla v23.4s, v5.4s, v6.s[3]\n"
+
+ "fmla v16.4s, v8.4s, v10.s[0]\n"
+ "fmla v18.4s, v8.4s, v10.s[1]\n"
+ "fmla v20.4s, v8.4s, v10.s[2]\n"
+ "fmla v22.4s, v8.4s, v10.s[3]\n"
+ "fmla v24.4s, v8.4s, v11.s[0]\n"
+ "fmla v26.4s, v8.4s, v11.s[1]\n"
+ "fmla v28.4s, v8.4s, v11.s[2]\n"
+ "fmla v30.4s, v8.4s, v11.s[3]\n"
+ "fmla v25.4s, v9.4s, v11.s[0]\n"
+ "fmla v27.4s, v9.4s, v11.s[1]\n"
+ "fmla v29.4s, v9.4s, v11.s[2]\n"
+ "fmla v31.4s, v9.4s, v11.s[3]\n"
+ "fmla v17.4s, v9.4s, v10.s[0]\n"
+ "fmla v19.4s, v9.4s, v10.s[1]\n"
+ "fmla v21.4s, v9.4s, v10.s[2]\n"
+ "fmla v23.4s, v9.4s, v10.s[3]\n"
+
+ "fmla v16.4s, v12.4s, v14.s[0]\n"
+ "fmla v18.4s, v12.4s, v14.s[1]\n"
+ "fmla v20.4s, v12.4s, v14.s[2]\n"
+ "fmla v22.4s, v12.4s, v14.s[3]\n"
+ "fmla v24.4s, v12.4s, v15.s[0]\n"
+ "fmla v26.4s, v12.4s, v15.s[1]\n"
+ "fmla v28.4s, v12.4s, v15.s[2]\n"
+ "fmla v30.4s, v12.4s, v15.s[3]\n"
+ "fmla v25.4s, v13.4s, v15.s[0]\n"
+ "fmla v27.4s, v13.4s, v15.s[1]\n"
+ "fmla v29.4s, v13.4s, v15.s[2]\n"
+ "fmla v31.4s, v13.4s, v15.s[3]\n"
+ "fmla v17.4s, v13.4s, v14.s[0]\n"
+ "fmla v19.4s, v13.4s, v14.s[1]\n"
+ "fmla v21.4s, v13.4s, v14.s[2]\n"
+ "fmla v23.4s, v13.4s, v14.s[3]\n"
+
+ "78:\n"
+#endif
+
+ // Accumulation loop
+ "cmp w1, w12\n"
+ "beq 79f\n"
+
+ "2:\n"
+ "fmla v24.4s, v0.4s, v3.s[0]\n"
+ "fmla v26.4s, v0.4s, v3.s[1]\n"
+ "ld1 {v4.4s}, [%[rhs_ptr]], #16\n"
+ "fmla v28.4s, v0.4s, v3.s[2]\n"
+ "fmla v30.4s, v0.4s, v3.s[3]\n"
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
+ "fmla v25.4s, v1.4s, v3.s[0]\n"
+ "fmla v27.4s, v1.4s, v3.s[1]\n"
+ "add w1, w1, #1\n"
+ "fmla v29.4s, v1.4s, v3.s[2]\n"
+ "fmla v31.4s, v1.4s, v3.s[3]\n"
+ "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
+ "fmla v17.4s, v1.4s, v2.s[0]\n"
+ "fmla v19.4s, v1.4s, v2.s[1]\n"
+ "cmp w1, w12\n"
+ "fmla v21.4s, v1.4s, v2.s[2]\n"
+ "fmla v23.4s, v1.4s, v2.s[3]\n"
+ "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
+ "fmla v16.4s, v0.4s, v4.s[0]\n"
+ "fmla v18.4s, v0.4s, v4.s[1]\n"
+ "mov v2.16b, v4.16b\n"
+ "fmla v20.4s, v0.4s, v4.s[2]\n"
+ "fmla v22.4s, v0.4s, v4.s[3]\n"
+ "blt 2b\n"
+
+ "79:\n"
+
+ // End of the inner loop on depth. Now perform the remaining
+ // multiply-adds of the last level of depth, for which the LHS
+ // and RHS data is already loaded.
+
+ "fmla v24.4s, v0.4s, v3.s[0]\n"
+ "fmla v26.4s, v0.4s, v3.s[1]\n"
+ "fmla v28.4s, v0.4s, v3.s[2]\n"
+ "fmla v30.4s, v0.4s, v3.s[3]\n"
+ "fmla v25.4s, v1.4s, v3.s[0]\n"
+ "fmla v27.4s, v1.4s, v3.s[1]\n"
+ "fmla v29.4s, v1.4s, v3.s[2]\n"
+ "fmla v31.4s, v1.4s, v3.s[3]\n"
+ "fmla v17.4s, v1.4s, v2.s[0]\n"
+ "fmla v19.4s, v1.4s, v2.s[1]\n"
+ "fmla v21.4s, v1.4s, v2.s[2]\n"
+ "fmla v23.4s, v1.4s, v2.s[3]\n"
+
+ // End of accumulation. The registers v16 -- v31 contain the final
+ // float accumulator values of the current 8x8 destination block.
+ // We now have to compute the final destination values from these
+ // accumulators, and advance to the next 8x8 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "cmp %w[row], w7\n" // Have we finished the last row?
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "mov %[lhs_col_ptr], x5\n" // Go back to first row
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "cmp %w[col], w8\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+ "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+
+ // Load some parameters needed for the end work on current block.
+ "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+
+ // Offset these base pointers as needed given the current row, col.
+ "add x5, x1, %x[row], lsl #2\n"
+
+ "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "csel x1, x1, x5, eq\n"
+
+ // Load 8 bias values.
+ "ld1 {v14.4s}, [x1], #16\n"
+ "ld1 {v15.4s}, [x1]\n"
+
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 32 bytes of
+ // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
+ // in the rest of the work on the current block.
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
+ "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
+
+ // Perform the bias-addition. (This is the float path, so there is no
+ // zero-point term folded into the bias, unlike in the quantized kernels.)
+ "fadd v16.4s, v16.4s, v14.4s\n"
+ "fadd v17.4s, v17.4s, v15.4s\n"
+ "fadd v18.4s, v18.4s, v14.4s\n"
+ "fadd v19.4s, v19.4s, v15.4s\n"
+ "fadd v20.4s, v20.4s, v14.4s\n"
+ "fadd v21.4s, v21.4s, v15.4s\n"
+ "fadd v22.4s, v22.4s, v14.4s\n"
+ "fadd v23.4s, v23.4s, v15.4s\n"
+ "fadd v24.4s, v24.4s, v14.4s\n"
+ "fadd v25.4s, v25.4s, v15.4s\n"
+ "fadd v26.4s, v26.4s, v14.4s\n"
+ "fadd v27.4s, v27.4s, v15.4s\n"
+ "fadd v28.4s, v28.4s, v14.4s\n"
+ "fadd v29.4s, v29.4s, v15.4s\n"
+ "fadd v30.4s, v30.4s, v14.4s\n"
+ "fadd v31.4s, v31.4s, v15.4s\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.4s, w2\n" // clamp_min
+ "dup v15.4s, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "fmax v16.4s, v16.4s, v14.4s\n"
+ "fmax v17.4s, v17.4s, v14.4s\n"
+ "fmax v18.4s, v18.4s, v14.4s\n"
+ "fmax v19.4s, v19.4s, v14.4s\n"
+ "fmax v20.4s, v20.4s, v14.4s\n"
+ "fmax v21.4s, v21.4s, v14.4s\n"
+ "fmax v22.4s, v22.4s, v14.4s\n"
+ "fmax v23.4s, v23.4s, v14.4s\n"
+ "fmax v24.4s, v24.4s, v14.4s\n"
+ "fmax v25.4s, v25.4s, v14.4s\n"
+ "fmax v26.4s, v26.4s, v14.4s\n"
+ "fmax v27.4s, v27.4s, v14.4s\n"
+ "fmax v28.4s, v28.4s, v14.4s\n"
+ "fmax v29.4s, v29.4s, v14.4s\n"
+ "fmax v30.4s, v30.4s, v14.4s\n"
+ "fmax v31.4s, v31.4s, v14.4s\n"
+
+ // Apply the clamp_max bound
+ "fmin v16.4s, v16.4s, v15.4s\n"
+ "fmin v17.4s, v17.4s, v15.4s\n"
+ "fmin v18.4s, v18.4s, v15.4s\n"
+ "fmin v19.4s, v19.4s, v15.4s\n"
+ "fmin v20.4s, v20.4s, v15.4s\n"
+ "fmin v21.4s, v21.4s, v15.4s\n"
+ "fmin v22.4s, v22.4s, v15.4s\n"
+ "fmin v23.4s, v23.4s, v15.4s\n"
+ "fmin v24.4s, v24.4s, v15.4s\n"
+ "fmin v25.4s, v25.4s, v15.4s\n"
+ "fmin v26.4s, v26.4s, v15.4s\n"
+ "fmin v27.4s, v27.4s, v15.4s\n"
+ "fmin v28.4s, v28.4s, v15.4s\n"
+ "fmin v29.4s, v29.4s, v15.4s\n"
+ "fmin v30.4s, v30.4s, v15.4s\n"
+ "fmin v31.4s, v31.4s, v15.4s\n"
+
+ // Compute how much of the 8x8 block of destination float values that
+ // we have computed, fit in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #32\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "31:\n"
+
+ // Write our float values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "str q16, [x3, #0]\n"
+ "str q17, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ "str q18, [x3, #0]\n"
+ "str q19, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ "str q20, [x3, #0]\n"
+ "str q21, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ "str q22, [x3, #0]\n"
+ "str q23, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ "str q24, [x3, #0]\n"
+ "str q25, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ "str q26, [x3, #0]\n"
+ "str q27, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ "str q28, [x3, #0]\n"
+ "str q29, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ "str q30, [x3, #0]\n"
+ "str q31, [x3, #16]\n"
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 41f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldr w7, [x3, x5, lsl #2]\n"
+ "str w7, [x4, x5, lsl #2]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #32\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "41:\n"
+ "add %[dst_ptr], %[dst_ptr], #32\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ // Reload some params --- we had used x5 -- x7 for a few other things
+ // since the last time we had loaded them.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "cmp %w[row], w7\n"
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add %w[row], %w[row], #8\n"
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ "mov %w[row], w6\n" // Move back to first row.
+ "add %w[col], %w[col], #8\n" // Move to the next column.
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
+ "mov %[dst_ptr], %[dst_col_ptr]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "cmp %w[col], w8\n"
+
+ // w1 is the number of levels of depth that we have already loaded
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently 1.
+ "mov w1, #1\n"
+
+ "ble 1b\n"
+
+ // clang-format on
+
+ : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
+ [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+ "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
+ "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
+// Variant of KernelFloatNeonOutOfOrder tuned for in-order CPUs that do not
+// support dotprod (while dotprod by itself is not relevant to floating-point,
+// this additional bit of information that we have about the target happens to
+// be useful here).
+//
+// So a typical target CPU here would be ARM Cortex-A53 or the original
+// Cortex-A55.
+//
+// This kernel is similar to and inspired by gemmlowp's
+// NEON_64bit_GEMM_Float32_WithScalar_A53,
+// which was contributed by David Mansell with very helpful
+// comments. Specifically, see this comment about tuning for Cortex-A53:
+// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215
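+//
+// Added commentary: one Cortex-A53-oriented pattern visible in the
+// accumulation loop below, inherited from the gemmlowp kernel referenced
+// above, is splitting each 128-bit vector load into a 64-bit vector load
+// plus a 64-bit general-purpose load, later merged with "ins", roughly:
+//
+//   "ldr x2, [%[lhs_ptr], #8]\n"   // high 64 bits via a GPR load
+//   "ldr d0, [%[lhs_ptr]], #32\n"  // low 64 bits, post-incrementing
+//   "ins v0.d[1], x2\n"            // merge into the 128-bit register
+//
+// On these in-order cores this reportedly schedules better around the
+// fmla stream than a single 128-bit load would.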
+void KernelFloatNeonInOrder(const KernelParamsFloat<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)");
+
+ CheckOffsetsInKernelParamsFloat(params);
+
+ const float* lhs_col_ptr = params.lhs_base_ptr;
+ const float* rhs_col_ptr = params.rhs_base_ptr;
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ float* dst_col_ptr = params.dst_base_ptr;
+ float* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v31 are accumulators.
+ // During accumulation, v0 -- v3 are used to load data from LHS and RHS.
+ //
+ // RHS 1x8 block
+ // /-----------------------------------------\
+ // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]|
+ // \-----------------------------------------/
+ // LHS 8x1 block
+ // /---------------------\ /-----------------------------------------\
+ // | v0.s[0] | |v16.s[0] ... v30.s[0]|
+ // | ... | | ... ... |
+ // | v0.s[3] | |v16.s[3] ... v30.s[3]|
+ // | v1.s[0] | |v17.s[0] ... v31.s[0]|
+ // | ... | | ... ... |
+ // | v1.s[3] | |v17.s[3] ... v31.s[3]|
+ // \---------------------/ \-----------------------------------------/
+ // accumulators 8x8 block
+ //
+ // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because
+ // we did not observe a benefit of such partial unrolling on in-order CPUs.
+ //
+ // v4 -- v7 are unused, and v8 -- v15 are used to load parameters
+ // for the post-accumulation part of the kernel.
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(v16)
+ // Load the first 32 bytes of LHS and RHS data.
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v17)
+ "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v18)
+ "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v19)
+ "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v20)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n")
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n")
+ RUY_MAKE_ZERO(v22)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n")
+ RUY_MAKE_ZERO(v23)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n")
+ RUY_MAKE_ZERO(v24)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n")
+ RUY_MAKE_ZERO(v25)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n")
+ RUY_MAKE_ZERO(v26)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
+ RUY_MAKE_ZERO(v27)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // w1 is the number of levels of depth for which we still need to load
+ // LHS and RHS data. Corresponding to the initial ld1 instructions
+ // above, this is currently depth - 1.
+ "sub w1, w12, #1\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ "cmp w1, #0\n"
+ "fmla v16.4s, v0.4s, v2.s[0]\n"
+ "fmla v18.4s, v0.4s, v2.s[1]\n"
+ "fmla v20.4s, v0.4s, v2.s[2]\n"
+ "fmla v22.4s, v0.4s, v2.s[3]\n"
+
+ // Accumulation loop
+ "beq 79f\n"
+
+ "2:\n"
+
+ "fmla v24.4s, v0.4s, v3.s[0]\n"
+ "ldr x2, [%[lhs_ptr], #8]\n"
+ "fmla v26.4s, v0.4s, v3.s[1]\n"
+ "ldr x3, [%[lhs_ptr], #24]\n"
+ "fmla v28.4s, v0.4s, v3.s[2]\n"
+ "ldr x5, [%[rhs_ptr], #24]\n"
+ "fmla v30.4s, v0.4s, v3.s[3]\n"
+ "ldr x4, [%[rhs_ptr], #8]\n"
+ "fmla v25.4s, v1.4s, v3.s[0]\n"
+ "subs w1, w1, #1\n"
+ "ldr d0, [%[lhs_ptr]], #32\n"
+ "fmla v27.4s, v1.4s, v3.s[1]\n"
+ "fmla v29.4s, v1.4s, v3.s[2]\n"
+ "fmla v31.4s, v1.4s, v3.s[3]\n"
+ "ins v0.d[1], x2\n"
+ "ldr d3, [%[rhs_ptr], #16]\n"
+ "fmla v17.4s, v1.4s, v2.s[0]\n"
+ "fmla v19.4s, v1.4s, v2.s[1]\n"
+ "ins v3.d[1], x5\n"
+ "ldr d4, [%[rhs_ptr]], #32\n"
+ "fmla v21.4s, v1.4s, v2.s[2]\n"
+ "fmla v23.4s, v1.4s, v2.s[3]\n"
+ "fmla v16.4s, v0.4s, v4.s[0]\n"
+ "ins v4.d[1], x4\n"
+ "ldr d1, [%[lhs_ptr], #-16]\n"
+ "fmla v18.4s, v0.4s, v4.s[1]\n"
+ "fmla v20.4s, v0.4s, v4.s[2]\n"
+ "ins v1.d[1], x3\n"
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
+ "mov v2.16b, v4.16b\n"
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
+ "fmla v22.4s, v0.4s, v4.s[3]\n"
+ "bne 2b\n"
+
+ "79:\n"
+
+ // End of the inner loop on depth. Now perform the remaining
+ // multiply-adds of the last level of depth, for which the LHS
+ // and RHS data is already loaded.
+
+ "fmla v24.4s, v0.4s, v3.s[0]\n"
+ "fmla v26.4s, v0.4s, v3.s[1]\n"
+ "fmla v28.4s, v0.4s, v3.s[2]\n"
+ "fmla v30.4s, v0.4s, v3.s[3]\n"
+ "fmla v25.4s, v1.4s, v3.s[0]\n"
+ "fmla v27.4s, v1.4s, v3.s[1]\n"
+ "fmla v29.4s, v1.4s, v3.s[2]\n"
+ "fmla v31.4s, v1.4s, v3.s[3]\n"
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "fmla v17.4s, v1.4s, v2.s[0]\n"
+ "fmla v19.4s, v1.4s, v2.s[1]\n"
+ "fmla v21.4s, v1.4s, v2.s[2]\n"
+ "fmla v23.4s, v1.4s, v2.s[3]\n"
+
+ // End of accumulation. The registers v16 -- v31 contain the final
+ // float accumulator values of the current 8x8 destination block.
+ // We now have to compute the final destination values from these
+ // accumulators, and advance to the next 8x8 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "cmp %w[row], w7\n" // Have we finished the last row?
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "mov %[lhs_col_ptr], x5\n" // Go back to first row
+ // Now we need to advance to the next column. If we already
+ // finished the last column, then in principle we are done, however
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "cmp %w[col], w8\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+ "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+
+ // Load some parameters needed for the end work on current block.
+ "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+
+ // Offset these base pointers as needed given the current row, col.
+ "add x5, x1, %x[row], lsl #2\n"
+
+ "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "csel x1, x1, x5, eq\n"
+
+ // Load 8 bias values.
+ "ld1 {v14.4s}, [x1], #16\n"
+ "ld1 {v15.4s}, [x1]\n"
+
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 32 bytes of
+ // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
+ // in the rest of the work on the current block.
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
+ "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
+
+ // Perform the bias-addition. (This is the float path, so there is no
+ // zero-point term folded into the bias, unlike in the quantized kernels.)
+ "fadd v16.4s, v16.4s, v14.4s\n"
+ "fadd v17.4s, v17.4s, v15.4s\n"
+ "fadd v18.4s, v18.4s, v14.4s\n"
+ "fadd v19.4s, v19.4s, v15.4s\n"
+ "fadd v20.4s, v20.4s, v14.4s\n"
+ "fadd v21.4s, v21.4s, v15.4s\n"
+ "fadd v22.4s, v22.4s, v14.4s\n"
+ "fadd v23.4s, v23.4s, v15.4s\n"
+ "fadd v24.4s, v24.4s, v14.4s\n"
+ "fadd v25.4s, v25.4s, v15.4s\n"
+ "fadd v26.4s, v26.4s, v14.4s\n"
+ "fadd v27.4s, v27.4s, v15.4s\n"
+ "fadd v28.4s, v28.4s, v14.4s\n"
+ "fadd v29.4s, v29.4s, v15.4s\n"
+ "fadd v30.4s, v30.4s, v14.4s\n"
+ "fadd v31.4s, v31.4s, v15.4s\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.4s, w2\n" // clamp_min
+ "dup v15.4s, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "fmax v16.4s, v16.4s, v14.4s\n"
+ "fmax v17.4s, v17.4s, v14.4s\n"
+ "fmax v18.4s, v18.4s, v14.4s\n"
+ "fmax v19.4s, v19.4s, v14.4s\n"
+ "fmax v20.4s, v20.4s, v14.4s\n"
+ "fmax v21.4s, v21.4s, v14.4s\n"
+ "fmax v22.4s, v22.4s, v14.4s\n"
+ "fmax v23.4s, v23.4s, v14.4s\n"
+ "fmax v24.4s, v24.4s, v14.4s\n"
+ "fmax v25.4s, v25.4s, v14.4s\n"
+ "fmax v26.4s, v26.4s, v14.4s\n"
+ "fmax v27.4s, v27.4s, v14.4s\n"
+ "fmax v28.4s, v28.4s, v14.4s\n"
+ "fmax v29.4s, v29.4s, v14.4s\n"
+ "fmax v30.4s, v30.4s, v14.4s\n"
+ "fmax v31.4s, v31.4s, v14.4s\n"
+
+ // Apply the clamp_max bound
+ "fmin v16.4s, v16.4s, v15.4s\n"
+ "fmin v17.4s, v17.4s, v15.4s\n"
+ "fmin v18.4s, v18.4s, v15.4s\n"
+ "fmin v19.4s, v19.4s, v15.4s\n"
+ "fmin v20.4s, v20.4s, v15.4s\n"
+ "fmin v21.4s, v21.4s, v15.4s\n"
+ "fmin v22.4s, v22.4s, v15.4s\n"
+ "fmin v23.4s, v23.4s, v15.4s\n"
+ "fmin v24.4s, v24.4s, v15.4s\n"
+ "fmin v25.4s, v25.4s, v15.4s\n"
+ "fmin v26.4s, v26.4s, v15.4s\n"
+ "fmin v27.4s, v27.4s, v15.4s\n"
+ "fmin v28.4s, v28.4s, v15.4s\n"
+ "fmin v29.4s, v29.4s, v15.4s\n"
+ "fmin v30.4s, v30.4s, v15.4s\n"
+ "fmin v31.4s, v31.4s, v15.4s\n"
+
+        // Compute how much of the 8x8 block of destination float values
+        // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
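+        // The handling below amounts to (sketch):
+        //   rows_fit = min(8, dst_rows - row);
+        //   cols_fit = min(8, dst_cols - col);
+        //   if (rows_fit == 8 && cols_fit == 8) {
+        //     store the 8x8 block directly into the destination matrix;
+        //   } else {
+        //     store it into dst_tmp_buf (column stride = 8 floats), then copy
+        //     the rows_fit x cols_fit part into the destination with the
+        //     scalar loop at label 50 below.
+        //   }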
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #32\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "31:\n"
+
+        // Write our float values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "str q16, [x3, #0]\n"
+ "str q17, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ "str q18, [x3, #0]\n"
+ "str q19, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ "str q20, [x3, #0]\n"
+ "str q21, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ "str q22, [x3, #0]\n"
+ "str q23, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ "str q24, [x3, #0]\n"
+ "str q25, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ "str q26, [x3, #0]\n"
+ "str q27, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ "str q28, [x3, #0]\n"
+ "str q29, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ "str q30, [x3, #0]\n"
+ "str q31, [x3, #16]\n"
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 41f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldr w7, [x3, x5, lsl #2]\n"
+ "str w7, [x4, x5, lsl #2]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #32\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "41:\n"
+ "add %[dst_ptr], %[dst_ptr], #32\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ // Reload some params --- we had used x5 -- x7 for a few other things
+ // since the last time we had loaded them.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "cmp %w[row], w7\n"
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add %w[row], %w[row], #8\n"
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ "mov %w[row], w6\n" // Move back to first row.
+ "add %w[col], %w[col], #8\n" // Move to the next column.
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
+ "mov %[dst_ptr], %[dst_col_ptr]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "cmp %w[col], w8\n"
+
+ // w1 is the number of levels of depth that remain to load
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently depth - 1.
+ "sub w1, w12, #1\n"
+
+ "ble 1b\n"
+
+ // clang-format on
+
+ : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
+ [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+ "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
+ "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
+// Variant of KernelFloatNeonInOrder tuned for in-order CPUs that do
+// support dotprod (while dotprod by itself is not relevant to floating-point,
+// this additional bit of information that we have about the target happens to
+// be useful here).
+//
+// So a typical target CPU here would be ARM Cortex-A55r1.
+//
+// This kernel is similar to and inspired by gemmlowp's
+// NEON_64bit_GEMM_Float32_WithScalar_A55r1,
+// which was contributed by David Mansell with very helpful
+// comments. Specifically, see this comment about tuning for Cortex-A55r1:
+// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412
+void KernelFloatNeonDotprodInOrder(const KernelParamsFloat<8, 8>& params) {
+ profiler::ScopeLabel label(
+ "Kernel (kNeonDotprod, optimized for in-order cores)");
+
+ CheckOffsetsInKernelParamsFloat(params);
+
+ const float* lhs_col_ptr = params.lhs_base_ptr;
+ const float* rhs_col_ptr = params.rhs_base_ptr;
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ float* dst_col_ptr = params.dst_base_ptr;
+ float* dst_ptr = dst_col_ptr;
+ int row = params.start_row;
+ int col = params.start_col;
+
+ // The asm kernel below has the following NEON register allocation:
+ //
+ // v16 -- v31 are accumulators.
+ // During accumulation, v0 -- v3 are used to load data from LHS and RHS.
+ //
+ // RHS 1x8 block
+ // /-----------------------------------------\
+ // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]|
+ // \-----------------------------------------/
+ // LHS 8x1 block
+ // /---------------------\ /-----------------------------------------\
+ // | v0.s[0] | |v16.s[0] ... v30.s[0]|
+ // | ... | | ... ... |
+ // | v0.s[3] | |v16.s[3] ... v30.s[3]|
+ // | v1.s[0] | |v17.s[0] ... v31.s[0]|
+ // | ... | | ... ... |
+ // | v1.s[3] | |v17.s[3] ... v31.s[3]|
+ // \---------------------/ \-----------------------------------------/
+ // accumulators 8x8 block
+ //
+ // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because
+ // we did not observe a benefit of such partial unrolling on in-order CPUs.
+ //
+  // v4 -- v7 are unused, and v8 -- v15 are used for loading parameters needed
+ // for the post-accumulation part of the kernel.
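+  // Per level of depth, the multiply-add performed below amounts to (sketch):
+  //   for (int c = 0; c < 8; c++)
+  //     for (int r = 0; r < 8; r++)
+  //       acc[r][c] += lhs[r] * rhs[c];
+  // with acc[r][c] held in lane (r % 4) of register v(16 + 2*c + r/4).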
+ asm volatile(
+#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
+
+ // clang-format off
+
+ // Load some parameters into registers.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+ "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+ "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+ "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+ "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+ "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+
+ // Clear accumulators.
+ RUY_MAKE_ZERO(v16)
+ // Load the first 32 bytes of LHS and RHS data.
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v17)
+ "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v18)
+ "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v19)
+ "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
+ RUY_MAKE_ZERO(v20)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n")
+ RUY_MAKE_ZERO(v21)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n")
+ RUY_MAKE_ZERO(v22)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n")
+ RUY_MAKE_ZERO(v23)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n")
+ RUY_MAKE_ZERO(v24)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n")
+ RUY_MAKE_ZERO(v25)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n")
+ RUY_MAKE_ZERO(v26)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
+ RUY_MAKE_ZERO(v27)
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // w1 is the number of levels of depth that remain to load
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently depth - 1.
+ "sub w1, w12, #1\n"
+
+ // Main loop of the whole GEMM, over rows and columns of the
+ // destination matrix.
+ "1:\n"
+
+ "cmp w1, #0\n"
+ "fmla v16.4s, v0.4s, v2.s[0]\n"
+ "fmla v18.4s, v0.4s, v2.s[1]\n"
+ "fmla v20.4s, v0.4s, v2.s[2]\n"
+ "fmla v22.4s, v0.4s, v2.s[3]\n"
+
+ // Accumulation loop
+ "beq 79f\n"
+
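+        // Inner loop over depth, scheduled for in-order cores: the 128-bit
+        // loads are split into 64-bit "ldr d"/"ldr x" pairs recombined with
+        // "ins", so that the individual loads can dual-issue alongside the
+        // fmla instructions (the same technique as in the gemmlowp A55r1
+        // kernel referenced above).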
+ "2:\n"
+
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n")
+ "fmla v24.4s, v0.4s, v3.s[0]\n"
+ "ldr x2, [%[lhs_ptr], #8]\n"
+ "fmla v26.4s, v0.4s, v3.s[1]\n"
+ "ldr x3, [%[lhs_ptr], #24]\n"
+ "fmla v28.4s, v0.4s, v3.s[2]\n"
+ "ldr x5, [%[rhs_ptr], #24]\n"
+ "fmla v30.4s, v0.4s, v3.s[3]\n"
+ "ldr d0, [%[lhs_ptr]], #32\n"
+ "fmla v25.4s, v1.4s, v3.s[0]\n"
+ "ldr x4, [%[rhs_ptr], #8]\n"
+ "fmla v27.4s, v1.4s, v3.s[1]\n"
+ "subs w1, w1, #1\n"
+ "fmla v29.4s, v1.4s, v3.s[2]\n"
+ "ins v0.d[1], x2\n"
+ "fmla v31.4s, v1.4s, v3.s[3]\n"
+ "ldr d3, [%[rhs_ptr], #16]\n"
+ "fmla v17.4s, v1.4s, v2.s[0]\n"
+ "ins v3.d[1], x5\n"
+ "fmla v19.4s, v1.4s, v2.s[1]\n"
+ "ldr d4, [%[rhs_ptr]], #32\n"
+ "fmla v21.4s, v1.4s, v2.s[2]\n"
+ "ins v4.d[1], x4\n"
+ "fmla v23.4s, v1.4s, v2.s[3]\n"
+ RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n")
+ "fmla v16.4s, v0.4s, v4.s[0]\n"
+ "ldr d1, [%[lhs_ptr], #-16]\n"
+ "fmla v18.4s, v0.4s, v4.s[1]\n"
+ "ins v1.d[1], x3\n"
+ "fmla v20.4s, v0.4s, v4.s[2]\n"
+ "mov v2.16b, v4.16b\n"
+ "fmla v22.4s, v0.4s, v4.s[3]\n"
+ "bne 2b\n"
+
+ "79:\n"
+
+ // End of the inner loop on depth. Now perform the remaining
+ // multiply-adds of the last level of depth, for which the LHS
+ // and RHS data is already loaded.
+
+ "fmla v24.4s, v0.4s, v3.s[0]\n"
+ "fmla v26.4s, v0.4s, v3.s[1]\n"
+ "fmla v28.4s, v0.4s, v3.s[2]\n"
+ "fmla v30.4s, v0.4s, v3.s[3]\n"
+ "fmla v25.4s, v1.4s, v3.s[0]\n"
+ "fmla v27.4s, v1.4s, v3.s[1]\n"
+ "fmla v29.4s, v1.4s, v3.s[2]\n"
+ "fmla v31.4s, v1.4s, v3.s[3]\n"
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "fmla v17.4s, v1.4s, v2.s[0]\n"
+ "fmla v19.4s, v1.4s, v2.s[1]\n"
+ "fmla v21.4s, v1.4s, v2.s[2]\n"
+ "fmla v23.4s, v1.4s, v2.s[3]\n"
+
+ // End of accumulation. The registers v16 -- v31 contain the final
+        // float32 accumulator values of the current 8x8 destination block.
+        // We now have to add the bias, apply the clamp bounds and store these
+        // values, and advance to the next 8x8 block. We intertwine
+ // these two aspects whenever possible for optimal pipelining, both
+ // at the data flow level (prefetch data for next block as early as
+ // possible) and instruction pipelining level (some of the next-block
+ // work can dual-issue with some of the final work on the current
+ // block).
+
+ // Logic to advance to the next block in preparation for the next
+ // iteration of the main loop. For now, we only want to compute
+ // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
+ // not yet ready to update the values of row and col, as we still need
+ // the current values for the rest of the work on the current block.
+
+ "cmp %w[row], w7\n" // Have we finished the last row?
+ "bge 4f\n" // If finished last row, go to 4
+ // Not finished last row: then advance to next row.
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
+ "b 5f\n"
+ "4:\n" // Finished last row...
+ "mov %[lhs_col_ptr], x5\n" // Go back to first row
+ // Now we need to advance to the next column. If we already
+        // finished the last column, then in principle we are done; however,
+ // we can't just return here, as we need to allow the end work of the
+ // current block to complete. The good news is that at this point it
+ // doesn't matter what data we load for the next column, since
+ // we will exit from the main loop below before actually storing
+ // anything computed from that data.
+ "cmp %w[col], w8\n" // Have we finished the last column?
+ "bge 5f\n" // If yes, just carry on without updating the column pointer.
+ // Not finished last column: then advance to next column.
+ "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
+ "5:\n"
+
+ // Set the LHS and RHS data pointers to the start of the columns just
+ // computed.
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n"
+ "mov %[rhs_ptr], %[rhs_col_ptr]\n"
+
+ // Load some parameters needed for the end work on current block.
+ "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
+
+ // Offset these base pointers as needed given the current row, col.
+ "add x5, x1, %x[row], lsl #2\n"
+
+ "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
+ "csel x1, x1, x5, eq\n"
+
+ // Load 8 bias values.
+ "ld1 {v14.4s}, [x1], #16\n"
+ "ld1 {v15.4s}, [x1]\n"
+
+ // Now that we know what LHS and RHS data the next iteration of the
+ // main loop will need to load, we start loading the first 32 bytes of
+ // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
+ // in the rest of the work on the current block.
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v1.4s}, [%[lhs_ptr]], #16\n"
+ "ld1 {v2.4s}, [%[rhs_ptr]], #16\n"
+ "ld1 {v3.4s}, [%[rhs_ptr]], #16\n"
+
+        // Perform the bias-addition.
+ "fadd v16.4s, v16.4s, v14.4s\n"
+ "fadd v17.4s, v17.4s, v15.4s\n"
+ "fadd v18.4s, v18.4s, v14.4s\n"
+ "fadd v19.4s, v19.4s, v15.4s\n"
+ "fadd v20.4s, v20.4s, v14.4s\n"
+ "fadd v21.4s, v21.4s, v15.4s\n"
+ "fadd v22.4s, v22.4s, v14.4s\n"
+ "fadd v23.4s, v23.4s, v15.4s\n"
+ "fadd v24.4s, v24.4s, v14.4s\n"
+ "fadd v25.4s, v25.4s, v15.4s\n"
+ "fadd v26.4s, v26.4s, v14.4s\n"
+ "fadd v27.4s, v27.4s, v15.4s\n"
+ "fadd v28.4s, v28.4s, v14.4s\n"
+ "fadd v29.4s, v29.4s, v15.4s\n"
+ "fadd v30.4s, v30.4s, v14.4s\n"
+ "fadd v31.4s, v31.4s, v15.4s\n"
+
+ // Load the clamp_min, clamp_max bounds
+ "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
+ "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
+ "dup v14.4s, w2\n" // clamp_min
+ "dup v15.4s, w3\n" // clamp_max
+
+ // Apply the clamp_min bound
+ "fmax v16.4s, v16.4s, v14.4s\n"
+ "fmax v17.4s, v17.4s, v14.4s\n"
+ "fmax v18.4s, v18.4s, v14.4s\n"
+ "fmax v19.4s, v19.4s, v14.4s\n"
+ "fmax v20.4s, v20.4s, v14.4s\n"
+ "fmax v21.4s, v21.4s, v14.4s\n"
+ "fmax v22.4s, v22.4s, v14.4s\n"
+ "fmax v23.4s, v23.4s, v14.4s\n"
+ "fmax v24.4s, v24.4s, v14.4s\n"
+ "fmax v25.4s, v25.4s, v14.4s\n"
+ "fmax v26.4s, v26.4s, v14.4s\n"
+ "fmax v27.4s, v27.4s, v14.4s\n"
+ "fmax v28.4s, v28.4s, v14.4s\n"
+ "fmax v29.4s, v29.4s, v14.4s\n"
+ "fmax v30.4s, v30.4s, v14.4s\n"
+ "fmax v31.4s, v31.4s, v14.4s\n"
+
+ // Apply the clamp_max bound
+ "fmin v16.4s, v16.4s, v15.4s\n"
+ "fmin v17.4s, v17.4s, v15.4s\n"
+ "fmin v18.4s, v18.4s, v15.4s\n"
+ "fmin v19.4s, v19.4s, v15.4s\n"
+ "fmin v20.4s, v20.4s, v15.4s\n"
+ "fmin v21.4s, v21.4s, v15.4s\n"
+ "fmin v22.4s, v22.4s, v15.4s\n"
+ "fmin v23.4s, v23.4s, v15.4s\n"
+ "fmin v24.4s, v24.4s, v15.4s\n"
+ "fmin v25.4s, v25.4s, v15.4s\n"
+ "fmin v26.4s, v26.4s, v15.4s\n"
+ "fmin v27.4s, v27.4s, v15.4s\n"
+ "fmin v28.4s, v28.4s, v15.4s\n"
+ "fmin v29.4s, v29.4s, v15.4s\n"
+ "fmin v30.4s, v30.4s, v15.4s\n"
+ "fmin v31.4s, v31.4s, v15.4s\n"
+
+        // Compute how much of the 8x8 block of destination float values
+        // we have computed fits in the destination matrix. Typically, all of
+ // it fits, but when the destination matrix shape is not a multiple
+ // of 8x8, there are some 8x8 blocks along the boundaries that do
+ // not fit entirely.
+ "sub w1, %w[dst_rows], %w[row]\n"
+ "sub w2, %w[dst_cols], %w[col]\n"
+ "mov w3, #8\n"
+ "cmp w1, #8\n"
+ // Compute w1 = how many rows of the 8x8 block fit
+ "csel w1, w1, w3, le\n"
+ "cmp w2, #8\n"
+ // Compute w2 = how many cols of the 8x8 block fit
+ "csel w2, w2, w3, le\n"
+
+ // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits.
+ "cmp w1, w3\n"
+ "ccmp w2, w3, 0, eq\n"
+ // Yes, all of the 8x8 block fits, go to fast path.
+ "beq 30f\n"
+ // Not all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write to dst_tmp_buf
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, #32\n"
+ "b 31f\n"
+ "30:\n"
+ // Yes, all of the 8x8 block fits.
+ // Set (x3 address, x4 stride) to write directly to destination matrix.
+ "mov x3, %[dst_ptr]\n"
+ "mov x4, x11\n"
+ "31:\n"
+
+        // Write our float values to the destination described by
+ // (x3 address, x4 stride).
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ "str q16, [x3, #0]\n"
+ "str q17, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v16)
+ RUY_MAKE_ZERO(v17)
+ "str q18, [x3, #0]\n"
+ "str q19, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v18)
+ RUY_MAKE_ZERO(v19)
+ "str q20, [x3, #0]\n"
+ "str q21, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v20)
+ RUY_MAKE_ZERO(v21)
+ "str q22, [x3, #0]\n"
+ "str q23, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v22)
+ RUY_MAKE_ZERO(v23)
+ "str q24, [x3, #0]\n"
+ "str q25, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v24)
+ RUY_MAKE_ZERO(v25)
+ "str q26, [x3, #0]\n"
+ "str q27, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v26)
+ RUY_MAKE_ZERO(v27)
+ "str q28, [x3, #0]\n"
+ "str q29, [x3, #16]\n"
+ "add x3, x3, x4\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
+ RUY_MAKE_ZERO(v28)
+ RUY_MAKE_ZERO(v29)
+ "str q30, [x3, #0]\n"
+ "str q31, [x3, #16]\n"
+ RUY_MAKE_ZERO(v30)
+ RUY_MAKE_ZERO(v31)
+
+ // If all of the 8x8 block fits, we just finished writing it to the
+ // destination, so we skip the next part.
+ "beq 41f\n"
+ // Not all of the 8x8 block fits in the destination matrix. We just
+ // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
+ // it to copy into the destination matrix the part that fits.
+ "mov x3, %[dst_tmp_buf]\n"
+ "mov x4, %[dst_ptr]\n"
+ "mov w6, #0\n"
+ "50:\n"
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
+ "mov w5, #0\n"
+ "51:\n"
+ "ldr w7, [x3, x5, lsl #2]\n"
+ "str w7, [x4, x5, lsl #2]\n"
+ "add w5, w5, #1\n"
+ "cmp w5, w1\n"
+ "blt 51b\n"
+ "add w6, w6, #1\n"
+ "add x3, x3, #32\n"
+ "add x4, x4, x11\n"
+ "cmp w6, w2\n"
+ "blt 50b\n"
+ "41:\n"
+ "add %[dst_ptr], %[dst_ptr], #32\n"
+ // At this point we have completely finished writing values to the
+ // destination matrix for the current block.
+
+ // Reload some params --- we had used x5 -- x7 for a few other things
+ // since the last time we had loaded them.
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+ "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+ "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+ // Move to the next block of the destination matrix, for the next iter
+ // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+ // been updated earlier.
+ // Have we reached the end row?
+ "cmp %w[row], w7\n"
+ "beq 20f\n" // yes, end row.
+ // Not end row. Move to the next row.
+ "add %w[row], %w[row], #8\n"
+ "b 21f\n"
+ "20:\n"
+ // Was already at end row.
+ "mov %w[row], w6\n" // Move back to first row.
+ "add %w[col], %w[col], #8\n" // Move to the next column.
+ "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
+ "mov %[dst_ptr], %[dst_col_ptr]\n"
+ "21:\n"
+
+ // Main loop exit condition: have we hit the end column?
+ "cmp %w[col], w8\n"
+
+ // w1 is the number of levels of depth that remain to load
+ // LHS and RHS data for. Corresponding to the initial ld1 instructions
+ // above, this is currently depth - 1.
+ "sub w1, w12, #1\n"
+
+ "ble 1b\n"
+
+ // clang-format on
+
+ : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+ : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
+ [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf)
+ : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+ "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
+ "v26", "v27", "v28", "v29", "v30", "v31");
+}
+#undef RUY_OFFSET_BIAS
+#undef RUY_OFFSET_FLAGS
+#undef RUY_OFFSET_LHS_BASE_PTR
+#undef RUY_OFFSET_CLAMP_MIN
+#undef RUY_OFFSET_CLAMP_MAX
+#undef RUY_OFFSET_START_ROW
+#undef RUY_OFFSET_LAST_ROW
+#undef RUY_OFFSET_LAST_COL
+#undef RUY_OFFSET_LHS_STRIDE
+#undef RUY_OFFSET_RHS_STRIDE
+#undef RUY_OFFSET_DST_STRIDE
+#undef RUY_OFFSET_DEPTH
+#undef RUY_OFFSET_START_COL
+#undef RUY_OFFSET_RHS_BASE_PTR
+#undef RUY_OFFSET_DST_BASE_PTR
+
+#endif // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+} // namespace ruy
diff --git a/ruy/kernel_avx2.cc b/ruy/kernel_avx2.cc
new file mode 100644
index 0000000..13fe22b
--- /dev/null
+++ b/ruy/kernel_avx2.cc
@@ -0,0 +1,1664 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cstdint>
+
+#include "ruy/check_macros.h"
+#include "ruy/kernel.h"
+#include "ruy/opt_set.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+#include <immintrin.h> // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+static constexpr int kAvx8bitBlockSize = 8;
+static constexpr int kAvx8bitInnerSize = 4;
+
+namespace {
+namespace intrin_utils {
+
+inline __m256i mm256_n_loadu_epi32(int n, const std::int32_t* src) {
+ switch (n) {
+ case 0:
+ return _mm256_setzero_si256();
+ case 1:
+ return _mm256_setr_m128i(_mm_setr_epi32(src[0], 0, 0, 0),
+ _mm_setzero_si128());
+ case 2:
+ return _mm256_setr_m128i(_mm_setr_epi32(src[0], src[1], 0, 0),
+ _mm_setzero_si128());
+ case 3:
+ return _mm256_setr_m128i(_mm_setr_epi32(src[0], src[1], src[2], 0),
+ _mm_setzero_si128());
+ case 4:
+ return _mm256_castsi128_si256(
+ _mm_loadu_si128(reinterpret_cast<__m128i const*>(src)));
+ case 5:
+ return _mm256_setr_epi32(src[0], src[1], src[2], src[3], src[4], 0, 0, 0);
+ case 6:
+ return _mm256_setr_epi32(src[0], src[1], src[2], src[3], src[4], src[5],
+ 0, 0);
+ case 7:
+ return _mm256_setr_epi32(src[0], src[1], src[2], src[3], src[4], src[5],
+ src[6], 0);
+ case 8:
+ return _mm256_loadu_si256(reinterpret_cast<__m256i const*>(src));
+ default:
+ RUY_DCHECK_LT(n, 9);
+ return _mm256_setzero_si256();
+ }
+}
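+// For example (a sketch of the semantics): mm256_n_loadu_epi32(3, src)
+// returns {src[0], src[1], src[2], 0, 0, 0, 0, 0}.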
+
+inline void mm256_n_storeu_cvtepi32_epi8(std::uint8_t* dst, int residual_rows,
+ const __m256i v) {
+ // Select bytes 0, 4, 8, 12 within each lane, effectively truncating.
+ const __m256i repack_perm = _mm256_set1_epi32(0x0c080400);
+ __m256i shuffled_v;
+ if (residual_rows > 1) {
+ // This selects 0, 4, 8, 12, 0, 4, 8, 12, ..., but we only use the first 4
+ // in each 128-bit lane.
+ shuffled_v = _mm256_shuffle_epi8(v, repack_perm);
+ }
+ switch (residual_rows) {
+ case 0:
+ break;
+ case 1:
+ dst[0] = _mm256_extract_epi8(v, 0);
+ break;
+ case 2:
+ _mm_storeu_si16(dst, _mm256_extracti128_si256(shuffled_v, 0));
+ break;
+ case 3: {
+ __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 0);
+ _mm_storeu_si16(dst, trailing_packed);
+ dst[2] = _mm_extract_epi8(trailing_packed, 2);
+ break;
+ }
+ case 4:
+ _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0));
+ break;
+ case 5:
+ _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0));
+ dst[4] = _mm256_extract_epi8(shuffled_v, 16);
+ break;
+ case 6:
+ _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0));
+ _mm_storeu_si16(dst + 4, _mm256_extracti128_si256(shuffled_v, 1));
+ break;
+ case 7: {
+ _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0));
+ __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 1);
+ _mm_storeu_si16(dst + 4, trailing_packed);
+ dst[6] = _mm_extract_epi8(trailing_packed, 2);
+ break;
+ }
+ case 8:
+ _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0));
+ _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1));
+ break;
+ default:
+ RUY_DCHECK_LE(residual_rows, 8);
+ break;
+ }
+}
+
+inline void mm256_storeu_cvtepi32_epi8(std::uint8_t* dst, const __m256i v) {
+ // Select bytes 0, 4, 8, 12 within each lane, effectively truncating.
+ const __m256i repack_perm = _mm256_set1_epi32(0x0c080400);
+ const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm);
+ _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0));
+ _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1));
+}
+
+inline void mm256_n_storeu_cvtepi32_epi8(std::int8_t* dst, int residual_rows,
+ const __m256i v) {
+ intrin_utils::mm256_n_storeu_cvtepi32_epi8(
+ reinterpret_cast<std::uint8_t*>(dst), residual_rows, v);
+}
+
+inline void mm256_storeu_cvtepi32_epi8(std::int8_t* dst, const __m256i v) {
+ // Select bytes 0, 4, 8, 12 within each lane, effectively truncating.
+ const __m256i repack_perm = _mm256_set1_epi32(0x0c080400);
+ const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm);
+ _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0));
+ _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1));
+}
+
+inline void mm256_n_storeu_cvtepi32_epi16(std::int16_t* dst, int residual_rows,
+ const __m256i v) {
+ // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively
+ // truncating each 32-bit value to 16 bits.
+ const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100);
+ __m256i shuffled_v;
+ __m128i shuffled_v_low;
+ if (residual_rows > 1) {
+ shuffled_v = _mm256_shuffle_epi8(v, repack_perm);
+ shuffled_v_low = _mm256_extracti128_si256(shuffled_v, 0);
+ } else {
+ shuffled_v_low = _mm256_extracti128_si256(v, 0);
+ }
+ switch (residual_rows) {
+ case 0:
+ break;
+ case 1:
+ _mm_storeu_si16(dst, shuffled_v_low);
+ break;
+ case 2:
+ _mm_storeu_si32(dst, shuffled_v_low);
+ break;
+ case 3: {
+ _mm_storeu_si32(dst, shuffled_v_low);
+ dst[2] = _mm_extract_epi16(shuffled_v_low, 2);
+ break;
+ }
+ case 4:
+ _mm_storeu_si64(dst, shuffled_v_low);
+ break;
+ case 5:
+ _mm_storeu_si64(dst, shuffled_v_low);
+ dst[4] = _mm256_extract_epi16(shuffled_v, 8);
+ break;
+ case 6:
+ _mm_storeu_si64(dst, shuffled_v_low);
+ _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1));
+ break;
+ case 7: {
+ _mm_storeu_si64(dst, shuffled_v_low);
+ __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 1);
+ _mm_storeu_si32(dst + 4, trailing_packed);
+ dst[6] = _mm_extract_epi16(trailing_packed, 2);
+ break;
+ }
+ case 8:
+ _mm_storeu_si64(dst, _mm256_extracti128_si256(shuffled_v, 0));
+ _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(shuffled_v, 1));
+ break;
+ default:
+ RUY_DCHECK_LE(residual_rows, 8);
+ break;
+ }
+}
+
+inline void mm256_storeu_cvtepi32_epi16(std::int16_t* dst, const __m256i v) {
+ // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively
+ // truncating each 32-bit value to 16 bits.
+ const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100);
+ const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm);
+ _mm_storeu_si64(dst, _mm256_extracti128_si256(shuffled_v, 0));
+ _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(shuffled_v, 1));
+}
+
+inline void mm256_n_storeu_epi32(std::int32_t* dst, int residual_rows,
+ const __m256i v) {
+ const __m128i v_low = _mm256_extracti128_si256(v, 0);
+ switch (residual_rows) {
+ case 0:
+ break;
+ case 1:
+ _mm_storeu_si32(dst, v_low);
+ break;
+ case 2:
+ _mm_storeu_si64(dst, v_low);
+ break;
+ case 3: {
+ __m128i trailing_packed = v_low;
+ _mm_storeu_si64(dst, trailing_packed);
+ dst[2] = _mm_extract_epi32(trailing_packed, 2);
+ break;
+ }
+ case 4:
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
+ break;
+ case 5:
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
+ dst[4] = _mm256_extract_epi32(v, 4);
+ break;
+ case 6:
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
+ _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(v, 1));
+ break;
+ case 7: {
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
+ __m128i trailing_packed = _mm256_extracti128_si256(v, 1);
+ _mm_storeu_si64(dst + 4, trailing_packed);
+ dst[6] = _mm_extract_epi32(trailing_packed, 2);
+ break;
+ }
+ case 8:
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v);
+ break;
+ default:
+ RUY_DCHECK_LE(residual_rows, 8);
+ break;
+ }
+}
+
+inline void mm256_storeu_epi32(std::int32_t* dst, const __m256i v) {
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v);
+}
+
+inline float mm256_get1_ps(const __m256 a, int i) {
+ __m256i ai = _mm256_castps_si256(a);
+ int float_val_as_int;
+ switch (i) {
+ case 0:
+ float_val_as_int = _mm256_extract_epi32(ai, 0);
+ break;
+ case 1:
+ float_val_as_int = _mm256_extract_epi32(ai, 1);
+ break;
+ case 2:
+ float_val_as_int = _mm256_extract_epi32(ai, 2);
+ break;
+ case 3:
+ float_val_as_int = _mm256_extract_epi32(ai, 3);
+ break;
+ case 4:
+ float_val_as_int = _mm256_extract_epi32(ai, 4);
+ break;
+ case 5:
+ float_val_as_int = _mm256_extract_epi32(ai, 5);
+ break;
+ case 6:
+ float_val_as_int = _mm256_extract_epi32(ai, 6);
+ break;
+ case 7:
+ float_val_as_int = _mm256_extract_epi32(ai, 7);
+ break;
+ default:
+ RUY_DCHECK_LT(i, 8);
+ return .0f;
+ }
+ return reinterpret_cast<float&>(float_val_as_int);
+}
+
+inline __m256 mm256_n_loadu_ps(int i, const float* src) {
+ switch (i) {
+ case 0:
+ return _mm256_setzero_ps();
+ case 1:
+ return _mm256_setr_m128(_mm_setr_ps(src[0], .0f, .0f, .0f),
+ _mm_setzero_ps());
+ case 2:
+ return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], .0f, .0f),
+ _mm_setzero_ps());
+ case 3:
+ return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], .0f),
+ _mm_setzero_ps());
+ case 4:
+ return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], src[3]),
+ _mm_setzero_ps());
+ case 5:
+ return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], .0f, .0f,
+ .0f);
+ case 6:
+ return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5], .0f,
+ .0f);
+ case 7:
+ return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5],
+ src[6], .0f);
+ case 8:
+ return _mm256_loadu_ps(src);
+ default:
+ RUY_DCHECK_LT(i, 9);
+ return _mm256_setzero_ps();
+ }
+}
+
+inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) {
+ for (int i = 0; i < residual_rows; ++i) {
+ dst[i] = intrin_utils::mm256_get1_ps(v, i);
+ }
+}
+} // namespace intrin_utils
+} // namespace
+
+void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel kAvx2 8-bit");
+ const std::int8_t splitter_idx_data[32] = {
+ 0, 1, 4, 5, 8, 9, 12, 13, //
+ 2, 3, 6, 7, 10, 11, 14, 15, //
+ 0, 1, 4, 5, 8, 9, 12, 13, //
+ 2, 3, 6, 7, 10, 11, 14, 15 //
+ };
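+  // Shuffle control for the _mm256_shuffle_epi8 below: within each 128-bit
+  // lane, gather bytes {0, 1, 4, 5, 8, 9, 12, 13} followed by bytes
+  // {2, 3, 6, 7, 10, 11, 14, 15}; together with the cvtepi8_epi16 further
+  // down, this is what separates lhs_16_bit_low from lhs_16_bit_high.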
+
+ std::int32_t dst_stride;
+ if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
+ (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
+ dst_stride = params.dst_stride;
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ dst_stride = params.dst_stride / sizeof(std::int16_t);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ dst_stride = params.dst_stride / sizeof(std::int32_t);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ int bias_ptr_block_increment =
+ params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;
+
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ const std::int32_t* bias_col_ptr = params.bias;
+ if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+ bias_col_ptr += params.start_row;
+ }
+
+ for (int col = params.start_col; col <= params.last_col;
+ col += kAvx8bitBlockSize) {
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ const std::int32_t* bias_ptr = bias_col_ptr;
+
+ const std::int32_t lhs_zero_point = params.lhs_zero_point;
+ const bool has_rhs_sums_offsets =
+ (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
+ std::int32_t rhs_sums_offsets[8];
+ if (has_rhs_sums_offsets) {
+ const __m256i rhs_sums_offset_v = _mm256_mullo_epi32(
+ _mm256_set1_epi32(lhs_zero_point),
+ _mm256_loadu_si256(
+ reinterpret_cast<__m256i const*>(&params.rhs_sums[col])));
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
+ rhs_sums_offset_v);
+ }
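+    // In scalar terms (sketch):
+    //   rhs_sums_offsets[i] = lhs_zero_point * rhs_sums[col + i],
+    // the zero-point correction subtracted below from every accumulator of
+    // column (col + i).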
+
+ for (int row = params.start_row; row <= params.last_row;
+ row += kAvx8bitBlockSize) {
+ const int residual_rows =
+ std::min(params.dst_rows - row, kAvx8bitBlockSize);
+ const int residual_cols =
+ std::min(params.dst_cols - col, kAvx8bitBlockSize);
+
+ const __m256i splitter_idx = _mm256_loadu_si256(
+ reinterpret_cast<__m256i const*>(splitter_idx_data));
+
+ __m256i accum_data_v0;
+ __m256i accum_data_v1;
+ __m256i accum_data_v2;
+ __m256i accum_data_v3;
+ __m256i accum_data_v4;
+ __m256i accum_data_v5;
+ __m256i accum_data_v6;
+ __m256i accum_data_v7;
+
+ // Initialize with bias.
+ __m256i initial_accum_data =
+ intrin_utils::mm256_n_loadu_epi32(residual_rows, bias_ptr);
+ bias_ptr += bias_ptr_block_increment;
+
+ // Adjustments common across columns.
+ const std::int32_t rhs_zero_point = params.rhs_zero_point;
+ if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
+ const __m256i lhs_sums_offset = _mm256_mullo_epi32(
+ _mm256_set1_epi32(rhs_zero_point),
+ _mm256_loadu_si256(
+ reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
+ initial_accum_data =
+ _mm256_sub_epi32(initial_accum_data, lhs_sums_offset);
+ }
+ const std::int32_t prod_zp_depth = params.prod_zp_depth;
+ if (prod_zp_depth) {
+ initial_accum_data = _mm256_add_epi32(initial_accum_data,
+ _mm256_set1_epi32(prod_zp_depth));
+ }
+
+ // Adjustments differing across columns.
+ if (has_rhs_sums_offsets) {
+ accum_data_v0 = _mm256_sub_epi32(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0]));
+ accum_data_v1 = _mm256_sub_epi32(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1]));
+ accum_data_v2 = _mm256_sub_epi32(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2]));
+ accum_data_v3 = _mm256_sub_epi32(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3]));
+ accum_data_v4 = _mm256_sub_epi32(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4]));
+ accum_data_v5 = _mm256_sub_epi32(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5]));
+ accum_data_v6 = _mm256_sub_epi32(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6]));
+ accum_data_v7 = _mm256_sub_epi32(
+ initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7]));
+ } else {
+ accum_data_v0 = initial_accum_data;
+ accum_data_v1 = initial_accum_data;
+ accum_data_v2 = initial_accum_data;
+ accum_data_v3 = initial_accum_data;
+ accum_data_v4 = initial_accum_data;
+ accum_data_v5 = initial_accum_data;
+ accum_data_v6 = initial_accum_data;
+ accum_data_v7 = initial_accum_data;
+ }
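+      // Combining the adjustments above, the starting accumulator value for
+      // element (r, c) of this block is, in scalar terms (sketch):
+      //   acc[r][c] = bias[r] - rhs_zero_point * lhs_sums[row + r]
+      //               - lhs_zero_point * rhs_sums[col + c] + prod_zp_depth
+      // with each correction term present only when the corresponding flag
+      // or value is set.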
+
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
+ const __m256i lhs_data =
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
+ const __m256i rhs_data_8bit =
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr));
+
+ // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
+ std::int32_t rhs_data[16];
+ const __m128i rhs_data_bottom_lane =
+ _mm256_castsi256_si128(rhs_data_8bit);
+ const __m128i rhs_data_top_lane =
+ _mm256_extracti128_si256(rhs_data_8bit, 1);
+ const __m256i rhs_16_bit_dup_low =
+ _mm256_cvtepi8_epi16(rhs_data_bottom_lane);
+ const __m256i rhs_16_bit_dup_high =
+ _mm256_cvtepi8_epi16(rhs_data_top_lane);
+ // Now that we have cast the RHS data, we store it so that each value
+ // can be separately loaded in the accumulation loop.
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data),
+ rhs_16_bit_dup_low);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8),
+ rhs_16_bit_dup_high);
+
+ // NOTE: There may be opportunities for permuting the data in the
+ // packing code instead of here.
+ const __m256i lhs_data_split =
+ _mm256_shuffle_epi8(lhs_data, splitter_idx);
+ const __m256i lhs_data_split_expand_bottom =
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0));
+ const __m256i lhs_data_split_expand_top =
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1));
+
+ // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
+ const __m256i lhs_16_bit_low = _mm256_permute2x128_si256(
+ lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
+ // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
+ const __m256i lhs_16_bit_high = _mm256_permute2x128_si256(
+ lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
+ // Accumulate for column 0.
+ {
+ const std::int32_t low_rhs_value = rhs_data[0];
+ const std::int32_t high_rhs_value = rhs_data[1];
+
+ const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
+ const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
+
+ accum_data_v0 = _mm256_add_epi32(
+ accum_data_v0,
+ _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_data_v0 = _mm256_add_epi32(
+ accum_data_v0,
+ _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ }
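+        // Each _mm256_madd_epi16 above multiplies 16-bit lanes pairwise and
+        // adds adjacent products into 32-bit lanes, so this pair of madds
+        // accumulates all kAvx8bitInnerSize (4) levels of depth of this
+        // column into the 8 row accumulators. Columns 1 through 7 below
+        // follow the same pattern.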
+ // Accumulate for column 1.
+ {
+ const std::int32_t low_rhs_value = rhs_data[2];
+ const std::int32_t high_rhs_value = rhs_data[3];
+
+ const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
+ const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
+
+ accum_data_v1 = _mm256_add_epi32(
+ accum_data_v1,
+ _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_data_v1 = _mm256_add_epi32(
+ accum_data_v1,
+ _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ }
+ // Accumulate for column 2.
+ {
+ const std::int32_t low_rhs_value = rhs_data[4];
+ const std::int32_t high_rhs_value = rhs_data[5];
+
+ const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
+ const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
+
+ accum_data_v2 = _mm256_add_epi32(
+ accum_data_v2,
+ _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_data_v2 = _mm256_add_epi32(
+ accum_data_v2,
+ _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ }
+ // Accumulate for column 3.
+ {
+ const std::int32_t low_rhs_value = rhs_data[6];
+ const std::int32_t high_rhs_value = rhs_data[7];
+
+ const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
+ const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
+
+ accum_data_v3 = _mm256_add_epi32(
+ accum_data_v3,
+ _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_data_v3 = _mm256_add_epi32(
+ accum_data_v3,
+ _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ }
+ // Accumulate for column 4.
+ {
+ const std::int32_t low_rhs_value = rhs_data[8];
+ const std::int32_t high_rhs_value = rhs_data[9];
+
+ const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
+ const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
+
+ accum_data_v4 = _mm256_add_epi32(
+ accum_data_v4,
+ _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_data_v4 = _mm256_add_epi32(
+ accum_data_v4,
+ _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ }
+ // Accumulate for column 5.
+ {
+ const std::int32_t low_rhs_value = rhs_data[10];
+ const std::int32_t high_rhs_value = rhs_data[11];
+
+ const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
+ const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
+
+ accum_data_v5 = _mm256_add_epi32(
+ accum_data_v5,
+ _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_data_v5 = _mm256_add_epi32(
+ accum_data_v5,
+ _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ }
+ // Accumulate for column 6.
+ {
+ const std::int32_t low_rhs_value = rhs_data[12];
+ const std::int32_t high_rhs_value = rhs_data[13];
+
+ const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
+ const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
+
+ accum_data_v6 = _mm256_add_epi32(
+ accum_data_v6,
+ _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_data_v6 = _mm256_add_epi32(
+ accum_data_v6,
+ _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ }
+ // Accumulate for column 7.
+ {
+ const std::int32_t low_rhs_value = rhs_data[14];
+ const std::int32_t high_rhs_value = rhs_data[15];
+
+ const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
+ const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
+
+ accum_data_v7 = _mm256_add_epi32(
+ accum_data_v7,
+ _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_data_v7 = _mm256_add_epi32(
+ accum_data_v7,
+ _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ }
+
+ lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ }
+
+ if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+ __m256i m_vector;
+ __m256i e_vector;
+ // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+ if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) {
+ m_vector = intrin_utils::mm256_n_loadu_epi32(
+ residual_rows, &params.multiplier_fixedpoint[row]);
+ e_vector = intrin_utils::mm256_n_loadu_epi32(
+ residual_rows, &params.multiplier_exponent[row]);
+ } else {
+ // These arrays have size LhsCols, and are pre-filled.
+ m_vector = _mm256_set1_epi32(params.multiplier_fixedpoint[0]);
+ e_vector = _mm256_set1_epi32(params.multiplier_exponent[0]);
+ }
+
+ const __m256i m_64bit_low =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0));
+ const __m256i m_64bit_high =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1));
+
+ const __m256i zero_vector = _mm256_setzero_si256();
+ const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector);
+ const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector);
+ const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector);
+ const __m256i final_right_shift =
+ _mm256_add_epi32(right_shift, _mm256_set1_epi32(31));
+ const __m256i final_right_shift_low = _mm256_cvtepi32_epi64(
+ _mm256_extracti128_si256(final_right_shift, 0));
+ const __m256i final_right_shift_high = _mm256_cvtepi32_epi64(
+ _mm256_extracti128_si256(final_right_shift, 1));
+ // Really we want 0x100000000, but use half to avoid overflowing.
+ const __m256i convert_to_signed_halved =
+ _mm256_srlv_epi32(_mm256_set1_epi32(0x80000000), right_shift);
+ const __m256i convert_to_unsigned_64 =
+ _mm256_set1_epi64x(0x8000000000000000);
+
+ __m256i post_scaling_offset = _mm256_add_epi32(
+ convert_to_signed_halved, convert_to_signed_halved);
+
+ const __m256i offset_vector =
+ _mm256_slli_epi64(_mm256_set1_epi64x(1), 30);
+ // Really these should be shifted by neg_e_vector, but tests pass when
+ // using right_shift.
+ const __m256i offset_vector_low = _mm256_add_epi64(
+ _mm256_sllv_epi64(offset_vector,
+ _mm256_cvtepi32_epi64(
+ _mm256_extracti128_si256(right_shift, 0))),
+ convert_to_unsigned_64);
+ const __m256i offset_vector_high = _mm256_add_epi64(
+ _mm256_sllv_epi64(offset_vector,
+ _mm256_cvtepi32_epi64(
+ _mm256_extracti128_si256(right_shift, 1))),
+ convert_to_unsigned_64);
+
+ if (params.dst_zero_point) {
+ const __m256i dst_zero_point =
+ _mm256_set1_epi32(params.dst_zero_point);
+ // The post-scaling offset is subtracted later, so this has the effect
+ // of adding the zero point.
+ post_scaling_offset =
+ _mm256_sub_epi32(post_scaling_offset, dst_zero_point);
+ }
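+      // Overall, ignoring rounding details, the rescaling below computes
+      // (sketch):
+      //   acc = (((acc << left_shift) * multiplier_fixedpoint) >> (31 + right_shift))
+      //           + dst_zero_point
+      // with left_shift = max(exponent, 0) and right_shift = max(-exponent, 0).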
+
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
+ RUY_DCHECK(false);
+#endif
+ const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
+
+ // We cannot do
+ //
+ // scaled_v_low =
+ // _mm256_srav_epi64(scaled_v_low, final_right_shift_low);
+ // scaled_v_high =
+ // _mm256_srav_epi64(scaled_v_high, final_right_shift_high);
+ //
+ // since this instruction is not in AVX2. Instead we use
+ // _mm256_srlv_epi64, but this is an unsigned shift, so we applied
+ // offsets before (convert_to_unsigned_64) and after
+ // (convert_to_signed_halved).
+ //
+ // The overall process is, for 64-bit scaled accumulator:
+      // unsigned_accum = signed_accum + (1 << 63);
+ // unsigned_accum = (unsigned_accum >> right_shift) >> 31;
+ // signed_accum = unsigned_accum - ((1 << 32) >> right_shift) / 2 * 2;
+
+ // There are various ways to repack the results, in the absence of
+ // _mm256_cvtepi64_epi32() or anything like it.
+ // A.
+ // accum_data_v[j] =
+ // _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6),
+ // _mm256_extract_epi32(scaled_v_high, 4),
+ // _mm256_extract_epi32(scaled_v_high, 2),
+ // _mm256_extract_epi32(scaled_v_high, 0),
+ // _mm256_extract_epi32(scaled_v_low, 6),
+ // _mm256_extract_epi32(scaled_v_low, 4),
+ // _mm256_extract_epi32(scaled_v_low, 2),
+ // _mm256_extract_epi32(scaled_v_low, 0));
+ // B.
+ // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8);
+ // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8);
+ // accum_data_v[j] =
+ // _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2),
+ // _mm256_extract_epi64(scaled_v_high, 0),
+ // _mm256_extract_epi64(scaled_v_low, 2),
+ // _mm256_extract_epi64(scaled_v_low, 0));
+ // C.
+ // scaled_v_low =
+ // _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm);
+ // scaled_v_high =
+ // _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm);
+ // accum_data_v[j] =
+ // _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20);
+ //
+ // However, we choose the following because it uses two lighter
+ // instructions. The permutation does have a longer latency, but this
+ // loop can be unrolled.
+ // D.
+ // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+ // __m256i results =
+ // _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+ // results = _mm256_permutevar8x32_epi32(results, repack_perm);
+ // accum_data_v[j] = _mm256_sub_epi32(results, post_scaling_offset);
+ {
+ __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m256i scaled_v_low = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
+ m_64bit_low);
+ __m256i scaled_v_high = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
+
+ scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+ __m256i results =
+ _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+ results = _mm256_permutevar8x32_epi32(results, repack_perm);
+
+ accum_data_v0 = _mm256_sub_epi32(results, post_scaling_offset);
+ }
+ {
+ __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v1, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m256i scaled_v_low = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
+ m_64bit_low);
+ __m256i scaled_v_high = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
+
+ scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+ __m256i results =
+ _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+ results = _mm256_permutevar8x32_epi32(results, repack_perm);
+
+ accum_data_v1 = _mm256_sub_epi32(results, post_scaling_offset);
+ }
+ {
+ __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v2, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m256i scaled_v_low = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
+ m_64bit_low);
+ __m256i scaled_v_high = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
+
+ scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+ __m256i results =
+ _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+ results = _mm256_permutevar8x32_epi32(results, repack_perm);
+
+ accum_data_v2 = _mm256_sub_epi32(results, post_scaling_offset);
+ }
+ {
+ __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v3, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m256i scaled_v_low = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
+ m_64bit_low);
+ __m256i scaled_v_high = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
+
+ scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+ __m256i results =
+ _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+ results = _mm256_permutevar8x32_epi32(results, repack_perm);
+
+ accum_data_v3 = _mm256_sub_epi32(results, post_scaling_offset);
+ }
+ {
+ __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v4, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m256i scaled_v_low = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
+ m_64bit_low);
+ __m256i scaled_v_high = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
+
+ scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+ __m256i results =
+ _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+ results = _mm256_permutevar8x32_epi32(results, repack_perm);
+
+ accum_data_v4 = _mm256_sub_epi32(results, post_scaling_offset);
+ }
+ {
+ __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v5, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m256i scaled_v_low = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
+ m_64bit_low);
+ __m256i scaled_v_high = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
+
+ scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+ __m256i results =
+ _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+ results = _mm256_permutevar8x32_epi32(results, repack_perm);
+
+ accum_data_v5 = _mm256_sub_epi32(results, post_scaling_offset);
+ }
+ {
+ __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v6, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m256i scaled_v_low = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
+ m_64bit_low);
+ __m256i scaled_v_high = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
+
+ scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+ __m256i results =
+ _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+ results = _mm256_permutevar8x32_epi32(results, repack_perm);
+
+ accum_data_v6 = _mm256_sub_epi32(results, post_scaling_offset);
+ }
+ {
+ __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v7, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m256i scaled_v_low = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
+ m_64bit_low);
+ __m256i scaled_v_high = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
+
+ scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+ __m256i results =
+ _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+ results = _mm256_permutevar8x32_epi32(results, repack_perm);
+
+ accum_data_v7 = _mm256_sub_epi32(results, post_scaling_offset);
+ }
+ }
+ const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
+ const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
+ const bool store_full_block = (residual_rows == kAvx8bitBlockSize) &&
+ (residual_cols == kAvx8bitBlockSize);
+
+ __m256i accum_data_v[kAvx8bitBlockSize];
+ if (!store_full_block) {
+ accum_data_v[0] = accum_data_v0;
+ accum_data_v[1] = accum_data_v1;
+ accum_data_v[2] = accum_data_v2;
+ accum_data_v[3] = accum_data_v3;
+ accum_data_v[4] = accum_data_v4;
+ accum_data_v[5] = accum_data_v5;
+ accum_data_v[6] = accum_data_v6;
+ accum_data_v[7] = accum_data_v7;
+ }
+
+ if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+ std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
+ if (store_full_block) {
+ accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v);
+ accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v);
+ accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v);
+ accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v);
+ accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v);
+ accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v);
+ accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v);
+ accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v);
+ accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v);
+ accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v);
+ accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v);
+ accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v);
+ accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v);
+ accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v);
+ accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v);
+ accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v);
+ intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[0 * dst_stride],
+ accum_data_v0);
+ intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[1 * dst_stride],
+ accum_data_v1);
+ intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[2 * dst_stride],
+ accum_data_v2);
+ intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[3 * dst_stride],
+ accum_data_v3);
+ intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[4 * dst_stride],
+ accum_data_v4);
+ intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[5 * dst_stride],
+ accum_data_v5);
+ intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[6 * dst_stride],
+ accum_data_v6);
+ intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[7 * dst_stride],
+ accum_data_v7);
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ __m256i result = accum_data_v[j];
+ result = _mm256_min_epi32(result, clamp_max_v);
+ result = _mm256_max_epi32(result, clamp_min_v);
+ intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows,
+ result);
+ tmp_ptr += dst_stride;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+ std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
+ if (store_full_block) {
+ accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v);
+ accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v);
+ accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v);
+ accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v);
+ accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v);
+ accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v);
+ accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v);
+ accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v);
+ accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v);
+ accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v);
+ accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v);
+ accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v);
+ accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v);
+ accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v);
+ accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v);
+ accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v);
+ intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[0], accum_data_v0);
+ intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[dst_stride],
+ accum_data_v1);
+ intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[2 * dst_stride],
+ accum_data_v2);
+ intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[3 * dst_stride],
+ accum_data_v3);
+ intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[4 * dst_stride],
+ accum_data_v4);
+ intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[5 * dst_stride],
+ accum_data_v5);
+ intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[6 * dst_stride],
+ accum_data_v6);
+ intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[7 * dst_stride],
+ accum_data_v7);
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ __m256i result = accum_data_v[j];
+ result = _mm256_min_epi32(result, clamp_max_v);
+ result = _mm256_max_epi32(result, clamp_min_v);
+ intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows,
+ result);
+ tmp_ptr += dst_stride;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+ if (store_full_block) {
+ accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v);
+ accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v);
+ accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v);
+ accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v);
+ accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v);
+ accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v);
+ accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v);
+ accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v);
+ accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v);
+ accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v);
+ accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v);
+ accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v);
+ accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v);
+ accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v);
+ accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v);
+ accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v);
+ intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[0], accum_data_v0);
+ intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[dst_stride],
+ accum_data_v1);
+ intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[2 * dst_stride],
+ accum_data_v2);
+ intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[3 * dst_stride],
+ accum_data_v3);
+ intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[4 * dst_stride],
+ accum_data_v4);
+ intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[5 * dst_stride],
+ accum_data_v5);
+ intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[6 * dst_stride],
+ accum_data_v6);
+ intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[7 * dst_stride],
+ accum_data_v7);
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ __m256i result = accum_data_v[j];
+ result = _mm256_min_epi32(result, clamp_max_v);
+ result = _mm256_max_epi32(result, clamp_min_v);
+ intrin_utils::mm256_n_storeu_cvtepi32_epi16(tmp_ptr, residual_rows,
+ result);
+ tmp_ptr += dst_stride;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ if (store_full_block) {
+ std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+ intrin_utils::mm256_storeu_epi32(&tmp_ptr[0], accum_data_v0);
+ intrin_utils::mm256_storeu_epi32(&tmp_ptr[dst_stride], accum_data_v1);
+ intrin_utils::mm256_storeu_epi32(&tmp_ptr[2 * dst_stride],
+ accum_data_v2);
+ intrin_utils::mm256_storeu_epi32(&tmp_ptr[3 * dst_stride],
+ accum_data_v3);
+ intrin_utils::mm256_storeu_epi32(&tmp_ptr[4 * dst_stride],
+ accum_data_v4);
+ intrin_utils::mm256_storeu_epi32(&tmp_ptr[5 * dst_stride],
+ accum_data_v5);
+ intrin_utils::mm256_storeu_epi32(&tmp_ptr[6 * dst_stride],
+ accum_data_v6);
+ intrin_utils::mm256_storeu_epi32(&tmp_ptr[7 * dst_stride],
+ accum_data_v7);
+ } else {
+ std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
+ for (int j = 0; j < residual_cols; ++j) {
+ intrin_utils::mm256_n_storeu_epi32(dst_block_ptr, residual_rows,
+ accum_data_v[j]);
+ dst_block_ptr += dst_stride;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
+ } // End row-block loop.
+
+ dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
+ kAvx8bitBlockSize * params.dst_stride);
+ rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
+ } // End col-block loop.
+} // NOLINT(readability/fn_size)
+
+void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel kAvx2 8-bit GEMV");
+
+ RUY_DCHECK_EQ(params.dst_cols, 1);
+ RUY_DCHECK_EQ(params.last_col, 0);
+ RUY_DCHECK_EQ(params.start_col, 0);
+
+ const std::int8_t splitter_idx_data[32] = {
+ 0, 1, 4, 5, 8, 9, 12, 13, //
+ 2, 3, 6, 7, 10, 11, 14, 15, //
+ 0, 1, 4, 5, 8, 9, 12, 13, //
+ 2, 3, 6, 7, 10, 11, 14, 15 //
+ };
+
+ int bias_ptr_block_increment =
+ params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;
+
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ const std::int32_t* bias_col_ptr = params.bias;
+ if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+ bias_col_ptr += params.start_row;
+ }
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ const std::int32_t* bias_ptr = bias_col_ptr;
+
+ const std::int32_t lhs_zero_point = params.lhs_zero_point;
+ const bool has_rhs_sums_offsets =
+ (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
+ std::int32_t rhs_sums_offsets[8];
+ if (has_rhs_sums_offsets) {
+ const __m256i rhs_sums_offset_v = _mm256_mullo_epi32(
+ _mm256_set1_epi32(lhs_zero_point),
+ _mm256_loadu_si256(
+ reinterpret_cast<__m256i const*>(&params.rhs_sums[0])));
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
+ rhs_sums_offset_v);
+ }
+
+ for (int row = params.start_row; row <= params.last_row;
+ row += kAvx8bitBlockSize) {
+ const int residual_rows =
+ std::min(params.dst_rows - row, kAvx8bitBlockSize);
+
+ const __m256i splitter_idx =
+ _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data));
+
+ __m256i accum_data_v0;
+
+ // Initialize with bias.
+ __m256i initial_accum_data =
+ intrin_utils::mm256_n_loadu_epi32(residual_rows, bias_ptr);
+ bias_ptr += bias_ptr_block_increment;
+
+ // Adjustments common across columns.
+ const std::int32_t rhs_zero_point = params.rhs_zero_point;
+ if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
+ const __m256i lhs_sums_offset = _mm256_mullo_epi32(
+ _mm256_set1_epi32(rhs_zero_point),
+ _mm256_loadu_si256(
+ reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
+ initial_accum_data =
+ _mm256_sub_epi32(initial_accum_data, lhs_sums_offset);
+ }
+ const std::int32_t prod_zp_depth = params.prod_zp_depth;
+ if (prod_zp_depth) {
+ initial_accum_data = _mm256_add_epi32(initial_accum_data,
+ _mm256_set1_epi32(prod_zp_depth));
+ }
+
+ // Adjustments differing across columns.
+ if (has_rhs_sums_offsets) {
+ accum_data_v0 = _mm256_sub_epi32(initial_accum_data,
+ _mm256_set1_epi32(rhs_sums_offsets[0]));
+ } else {
+ accum_data_v0 = initial_accum_data;
+ }
+
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
+ const __m256i lhs_data =
+ _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
+ const __m128i rhs_data_8bit = _mm_loadu_si32(rhs_ptr);
+
+ // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
+ // We load just the 4 RHS bytes needed for this depth step; the
+ // sign-extension below produces twice the data that we need, and we
+ // store only the values we actually use.
+ std::int32_t rhs_data[2];
+ const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
+ // Now that we have cast the RHS data, we store it so that each value
+ // can be separately loaded in the accumulation loop.
+ _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
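+ // Reading aid (a description of the layout produced above, for intuition
+ // only): rhs_data[0] now holds the sign-extended 16-bit pair for depths
+ // d + 0 and d + 1 of this column, and rhs_data[1] holds the pair for
+ // depths d + 2 and d + 3, so each int32 can be broadcast as one lane-pair
+ // in the accumulation below.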
+
+ // NOTE: There may be opportunities for permuting the data in the packing
+ // code instead of here.
+ const __m256i lhs_data_split =
+ _mm256_shuffle_epi8(lhs_data, splitter_idx);
+ const __m256i lhs_data_split_expand_bottom =
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0));
+ const __m256i lhs_data_split_expand_top =
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1));
+
+ // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
+ const __m256i lhs_16_bit_low = _mm256_permute2x128_si256(
+ lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
+ // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
+ const __m256i lhs_16_bit_high = _mm256_permute2x128_si256(
+ lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
+ // Accumulate for column 0.
+ const std::int32_t low_rhs_value = rhs_data[0];
+ const std::int32_t high_rhs_value = rhs_data[1];
+
+ const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
+ const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
+
+ accum_data_v0 = _mm256_add_epi32(
+ accum_data_v0, _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_data_v0 = _mm256_add_epi32(
+ accum_data_v0,
+ _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
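+ // Reading aid (scalar view of the two madd accumulations above, with
+ // illustrative indexing only): for each of the 8 rows r in this block,
+ //   accum[r] += lhs[r][d + 0] * rhs[d + 0] + lhs[r][d + 1] * rhs[d + 1]
+ //             + lhs[r][d + 2] * rhs[d + 2] + lhs[r][d + 3] * rhs[d + 3];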
+
+ lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ }
+
+ if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+ __m256i m_vector;
+ __m256i e_vector;
+ // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+ if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) {
+ m_vector = intrin_utils::mm256_n_loadu_epi32(
+ residual_rows, &params.multiplier_fixedpoint[row]);
+ e_vector = intrin_utils::mm256_n_loadu_epi32(
+ residual_rows, &params.multiplier_exponent[row]);
+ } else {
+ // These arrays have size LhsCols, and are pre-filled.
+ m_vector = _mm256_set1_epi32(params.multiplier_fixedpoint[0]);
+ e_vector = _mm256_set1_epi32(params.multiplier_exponent[0]);
+ }
+
+ const __m256i m_64bit_low =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0));
+ const __m256i m_64bit_high =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1));
+
+ const __m256i zero_vector = _mm256_setzero_si256();
+ const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector);
+ const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector);
+ const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector);
+ const __m256i final_right_shift =
+ _mm256_add_epi32(right_shift, _mm256_set1_epi32(31));
+ const __m256i final_right_shift_low =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 0));
+ const __m256i final_right_shift_high =
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 1));
+ // Really we want 0x100000000, but use half to avoid overflowing.
+ const __m256i convert_to_signed_halved =
+ _mm256_srlv_epi32(_mm256_set1_epi32(0x80000000), right_shift);
+ const __m256i convert_to_unsigned_64 =
+ _mm256_set1_epi64x(0x8000000000000000);
+
+ __m256i post_scaling_offset =
+ _mm256_add_epi32(convert_to_signed_halved, convert_to_signed_halved);
+
+ const __m256i offset_vector =
+ _mm256_slli_epi64(_mm256_set1_epi64x(1), 30);
+ // Really these should be shifted by neg_e_vector, but tests pass when
+ // using right_shift.
+ const __m256i offset_vector_low = _mm256_add_epi64(
+ _mm256_sllv_epi64(
+ offset_vector,
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_shift, 0))),
+ convert_to_unsigned_64);
+ const __m256i offset_vector_high = _mm256_add_epi64(
+ _mm256_sllv_epi64(
+ offset_vector,
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_shift, 1))),
+ convert_to_unsigned_64);
+
+ if (params.dst_zero_point) {
+ const __m256i dst_zero_point = _mm256_set1_epi32(params.dst_zero_point);
+ // The post-scaling offset is subtracted later, so this has the effect
+ // of adding the zero point.
+ post_scaling_offset =
+ _mm256_sub_epi32(post_scaling_offset, dst_zero_point);
+ }
+
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
+ RUY_DCHECK(false);
+#endif
+ const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
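+ // Reading aid for the block below (scalar view using the quantities
+ // defined above; illustrative only, m = multiplier_fixedpoint):
+ //   dst = dst_zero_point +
+ //         (((accum << left_shift) * m + (1 << (30 + right_shift)))
+ //          >> (31 + right_shift))  // arithmetic shift, round-to-nearest
+ // AVX2 lacks a variable 64-bit arithmetic right shift, so the code adds
+ // a 2^63 bias before the logical shift and removes it afterwards via
+ // post_scaling_offset (which also folds in dst_zero_point).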
+
+ // See GEMM version for details of this process.
+ {
+ __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m256i scaled_v_low = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
+ m_64bit_low);
+ __m256i scaled_v_high = _mm256_mul_epi32(
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
+
+ scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+ __m256i results = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+ results = _mm256_permutevar8x32_epi32(results, repack_perm);
+
+ accum_data_v0 = _mm256_sub_epi32(results, post_scaling_offset);
+ }
+ }
+ const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
+ const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
+
+ if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+ std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
+ __m256i result = accum_data_v0;
+ result = _mm256_min_epi32(result, clamp_max_v);
+ result = _mm256_max_epi32(result, clamp_min_v);
+ intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows,
+ result);
+ dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+ std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
+ __m256i result = accum_data_v0;
+ result = _mm256_min_epi32(result, clamp_max_v);
+ result = _mm256_max_epi32(result, clamp_min_v);
+ intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows,
+ result);
+ dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+ __m256i result = accum_data_v0;
+ result = _mm256_min_epi32(result, clamp_max_v);
+ result = _mm256_max_epi32(result, clamp_min_v);
+ intrin_utils::mm256_n_storeu_cvtepi32_epi16(tmp_ptr, residual_rows,
+ result);
+ dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
+ intrin_utils::mm256_n_storeu_epi32(dst_block_ptr, residual_rows,
+ accum_data_v0);
+ dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
+ } // End row-block loop.
+
+ dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
+ kAvx8bitBlockSize * params.dst_stride);
+ rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
+} // NOLINT(readability/fn_size)
+
+void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel kAvx2 float");
+
+ // The strides in params are given in bytes; convert to float-element strides.
+ const std::int64_t lhs_stride = params.lhs_stride >> 2;
+ const std::int64_t dst_stride = params.dst_stride >> 2;
+ const std::int64_t rhs_stride = params.rhs_stride >> 2;
+ //
+ int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
+ // AVX2 float block size = 8.
+ const int end_row = std::min(params.dst_rows, params.last_row + 8);
+ const int end_col = std::min(params.dst_cols, params.last_col + 8);
+ //
+ const float* adj_rhs_col_ptr =
+ params.rhs_base_ptr - params.start_col * rhs_stride;
+ float* adj_dst_col_ptr =
+ params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
+ const float* adj_lhs_col_ptr =
+ params.lhs_base_ptr - params.start_row * lhs_stride;
+ const float* bias_col_ptr = params.bias;
+
+ const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
+ const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
+
+ int col = params.start_col;
+ // Loop through cols in steps of the float block size; any incomplete
+ // remainder block is handled after this loop.
+ for (; col <= end_col - 8; col += 8) {
+ __m256 accum_data_v[8];
+
+ const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+ float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+
+ for (int row = params.start_row; row < end_row; row += 8) {
+ const int residual_rows = std::min(end_row - row, 8);
+
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+ const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+ // Initialize with bias.
+ const __m256 initial_accum_data =
+ intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr);
+
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = initial_accum_data;
+ }
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; ++d) {
+ const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+ // In this version, RHS values are loaded individually rather than
+ // loaded together and then extracted with broadcasting. This is because
+ // combinations of AVX flavours, intrinsics and compilers do not handle
+ // that extraction pattern very well.
+ const float* rhs_data = rhs_ptr;
+
+ for (int j = 0; j < 8; ++j) {
+ const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
+ accum_data_v[j] =
+ _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
+ }
+ lhs_ptr += 8;
+ rhs_ptr += 8;
+ }
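+ // Reading aid (scalar view of the loop above, illustrative indexing
+ // only): for each destination column j and each of the 8 rows r in this
+ // block,
+ //   accum[r][j] += lhs[r][d] * rhs[d][j];
+ // with the clamped result stored below.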
+
+ if (residual_rows == 8) {
+ for (int j = 0; j < 8; ++j) {
+ float* block_ptr = dst_ptr + j * dst_stride;
+ accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
+ accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
+ _mm256_storeu_ps(block_ptr, accum_data_v[j]);
+ }
+ } else {
+ for (int j = 0; j < 8; ++j) {
+ float* block_ptr = dst_ptr + j * dst_stride;
+ accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
+ accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
+ intrin_utils::mm256_n_storeu_ps(block_ptr, residual_rows,
+ accum_data_v[j]);
+ }
+ }
+ } // End row-block loop.
+ } // End col-block loop.
+
+ if (col < end_col) {
+ // Remaining cols in [0, float block size).
+ RUY_DCHECK_GE(end_col - col, 0);
+ RUY_DCHECK_LT(end_col - col, 8);
+
+ __m256 accum_data_v[8];
+
+ const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+ float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+ const int residual_cols = std::min(end_col - col, 8);
+
+ for (int row = params.start_row; row < end_row; row += 8) {
+ const int residual_rows = std::min(end_row - row, 8);
+
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+ const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+ // Initialize with bias.
+ const __m256 initial_accum_data =
+ intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr);
+
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = initial_accum_data;
+ }
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; ++d) {
+ const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+
+ for (int j = 0; j < 8; ++j) {
+ const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
+ accum_data_v[j] =
+ _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
+ }
+ lhs_ptr += 8;
+ rhs_ptr += 8;
+ }
+
+ for (int j = 0; j < residual_cols; ++j) {
+ float* block_ptr = dst_ptr + j * dst_stride;
+ accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
+ accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
+ intrin_utils::mm256_n_storeu_ps(block_ptr, residual_rows,
+ accum_data_v[j]);
+ }
+ } // End row-block loop.
+ } // End col-block terminal conditional.
+}
+
+void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel kAvx2 float GEMV");
+
+ RUY_DCHECK_EQ(params.dst_cols, 1);
+ RUY_DCHECK_EQ(params.last_col, 0);
+ RUY_DCHECK_EQ(params.start_col, 0);
+
+ // The stride in params is given in bytes; convert to a float-element stride.
+ const std::int64_t lhs_stride = params.lhs_stride >> 2;
+ //
+ int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
+ // AVX2 float block size = 8.
+ const int end_row = std::min(params.dst_rows, params.last_row + 8);
+
+ float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row;
+ const float* adj_lhs_col_ptr =
+ params.lhs_base_ptr - params.start_row * lhs_stride;
+ const float* bias_col_ptr = params.bias;
+
+ const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
+ const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
+
+ __m256 accum_data_v;
+
+ const float* rhs_col_ptr = params.rhs_base_ptr;
+ float* dst_col_ptr = adj_dst_col_ptr;
+
+ int row = params.start_row;
+ for (; row <= end_row - 8; row += 8) {
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+ const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+ // Initialize with bias.
+ accum_data_v = _mm256_loadu_ps(bias_ptr);
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ int d = 0;
+ for (; d <= params.depth - 4; d += 4) {
+ const __m256 lhs_data_0 = _mm256_loadu_ps(lhs_ptr);
+ const __m256 dup_rhs_element_0 = _mm256_set1_ps(rhs_ptr[0]);
+ accum_data_v =
+ _mm256_fmadd_ps(lhs_data_0, dup_rhs_element_0, accum_data_v);
+ const __m256 dup_rhs_element_1 = _mm256_set1_ps(rhs_ptr[8]);
+ const __m256 lhs_data_1 = _mm256_loadu_ps(lhs_ptr + 8);
+ accum_data_v =
+ _mm256_fmadd_ps(lhs_data_1, dup_rhs_element_1, accum_data_v);
+
+ const __m256 lhs_data_2 = _mm256_loadu_ps(lhs_ptr + 16);
+ const __m256 dup_rhs_element_2 = _mm256_set1_ps(rhs_ptr[16]);
+ accum_data_v =
+ _mm256_fmadd_ps(lhs_data_2, dup_rhs_element_2, accum_data_v);
+ const __m256 dup_rhs_element_3 = _mm256_set1_ps(rhs_ptr[24]);
+ const __m256 lhs_data_3 = _mm256_loadu_ps(lhs_ptr + 24);
+ accum_data_v =
+ _mm256_fmadd_ps(lhs_data_3, dup_rhs_element_3, accum_data_v);
+ lhs_ptr += 32; // Loaded 8 * 4 floats.
+ rhs_ptr += 32;
+ }
+ for (; d < params.depth; ++d) {
+ const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+
+ const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
+ accum_data_v = _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
+ lhs_ptr += 8;
+ rhs_ptr += 8;
+ }
+
+ accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
+ accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
+ _mm256_storeu_ps(dst_ptr, accum_data_v);
+ } // End row-block loop.
+
+ if (row < end_row) {
+ const int residual_rows = end_row - row;
+ RUY_CHECK_GE(residual_rows, 1);
+ RUY_CHECK_LT(residual_rows, 8);
+
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+ const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+ // Initialize with bias.
+ accum_data_v = intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr);
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; ++d) {
+ const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+
+ const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
+ accum_data_v = _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
+ lhs_ptr += 8;
+ rhs_ptr += 8;
+ }
+
+ accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
+ accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
+ intrin_utils::mm256_n_storeu_ps(dst_ptr, residual_rows, accum_data_v);
+ } // End handling of residual rows.
+}
+
+#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+} // namespace ruy
diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc
new file mode 100644
index 0000000..5e771a5
--- /dev/null
+++ b/ruy/kernel_avx512.cc
@@ -0,0 +1,1820 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cstdint>
+
+#include "ruy/check_macros.h"
+#include "ruy/kernel.h"
+#include "ruy/opt_set.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+#include <immintrin.h> // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+#else // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
+ profiler::ScopeLabel label("Kernel kAvx512 8-bit");
+
+ std::int32_t dst_stride;
+ if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
+ (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
+ dst_stride = params.dst_stride;
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ dst_stride = params.dst_stride / sizeof(std::int16_t);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ dst_stride = params.dst_stride / sizeof(std::int32_t);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;
+
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ const std::int32_t* bias_col_ptr = params.bias;
+ if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+ bias_col_ptr += params.start_row;
+ }
+
+ for (int col = params.start_col; col <= params.last_col; col += 16) {
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ const std::int32_t* bias_ptr = bias_col_ptr;
+
+ const std::int32_t lhs_zero_point = params.lhs_zero_point;
+ const bool has_rhs_sums_offsets =
+ (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
+ std::int32_t rhs_sums_offsets[16];
+ if (has_rhs_sums_offsets) {
+ const __m512i rhs_sums_offset_v =
+ _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
+ _mm512_loadu_epi32(&params.rhs_sums[col]));
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
+ rhs_sums_offset_v);
+ }
+
+ for (int row = params.start_row; row <= params.last_row; row += 16) {
+ const int residual_rows = std::min(params.dst_rows - row, 16);
+ const int residual_cols = std::min(params.dst_cols - col, 16);
+
+ __m512i accum_data_v0;
+ __m512i accum_data_v1;
+ __m512i accum_data_v2;
+ __m512i accum_data_v3;
+ __m512i accum_data_v4;
+ __m512i accum_data_v5;
+ __m512i accum_data_v6;
+ __m512i accum_data_v7;
+ __m512i accum_data_v8;
+ __m512i accum_data_v9;
+ __m512i accum_data_va;
+ __m512i accum_data_vb;
+ __m512i accum_data_vc;
+ __m512i accum_data_vd;
+ __m512i accum_data_ve;
+ __m512i accum_data_vf;
+
+ // Initialize with bias.
+ const __mmask16 row_mask =
+ (static_cast<std::uint32_t>(1) << residual_rows) - 1;
+ __m512i initial_accum_data = _mm512_maskz_loadu_epi32(row_mask, bias_ptr);
+ bias_ptr += bias_ptr_block_increment;
+
+ const std::int32_t rhs_zero_point = params.rhs_zero_point;
+ if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
+ const __m512i lhs_sums_offset =
+ _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
+ _mm512_loadu_epi32(&params.lhs_sums[row]));
+ initial_accum_data =
+ _mm512_sub_epi32(initial_accum_data, lhs_sums_offset);
+ }
+
+ const std::int32_t prod_zp_depth = params.prod_zp_depth;
+ if (prod_zp_depth != 0) {
+ initial_accum_data = _mm512_add_epi32(initial_accum_data,
+ _mm512_set1_epi32(prod_zp_depth));
+ }
+
+ // Adjustments differing across columns.
+ if (has_rhs_sums_offsets) {
+ accum_data_v0 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[0]));
+ accum_data_v1 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[1]));
+ accum_data_v2 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[2]));
+ accum_data_v3 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[3]));
+ accum_data_v4 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[4]));
+ accum_data_v5 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[5]));
+ accum_data_v6 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[6]));
+ accum_data_v7 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[7]));
+ accum_data_v8 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[8]));
+ accum_data_v9 = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[9]));
+ accum_data_va = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[10]));
+ accum_data_vb = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[11]));
+ accum_data_vc = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[12]));
+ accum_data_vd = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[13]));
+ accum_data_ve = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[14]));
+ accum_data_vf = _mm512_sub_epi32(
+ initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[15]));
+ } else {
+ accum_data_v0 = initial_accum_data;
+ accum_data_v1 = initial_accum_data;
+ accum_data_v2 = initial_accum_data;
+ accum_data_v3 = initial_accum_data;
+ accum_data_v4 = initial_accum_data;
+ accum_data_v5 = initial_accum_data;
+ accum_data_v6 = initial_accum_data;
+ accum_data_v7 = initial_accum_data;
+ accum_data_v8 = initial_accum_data;
+ accum_data_v9 = initial_accum_data;
+ accum_data_va = initial_accum_data;
+ accum_data_vb = initial_accum_data;
+ accum_data_vc = initial_accum_data;
+ accum_data_vd = initial_accum_data;
+ accum_data_ve = initial_accum_data;
+ accum_data_vf = initial_accum_data;
+ }
+
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; d += 4) {
+ const __m512i lhs_data = _mm512_loadu_epi8(lhs_ptr);
+ __m512i rhs_data_8bit = _mm512_loadu_epi8(rhs_ptr);
+
+ // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
+ std::int32_t rhs_data[32];
+ const __m256i rhs_data_bottom_lane =
+ _mm512_castsi512_si256(rhs_data_8bit);
+ const __m256i rhs_data_top_lane =
+ _mm512_extracti32x8_epi32(rhs_data_8bit, 1);
+ const __m512i rhs_16_bit_dup_low =
+ _mm512_cvtepi8_epi16(rhs_data_bottom_lane);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_cvtepi8_epi16(rhs_data_top_lane);
+ // Now that we have cast the RHS data, we store it so that each value
+ // can be separately loaded in the accumulation loop.
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_data),
+ rhs_16_bit_dup_low);
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_data + 16),
+ rhs_16_bit_dup_high);
+
+ // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
+ const __m512i lhs_16_bit_low =
+ _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
+ // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
+ const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
+ _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));
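+ // Reading aid (same scheme as the AVX2 kernel; scalar view with
+ // illustrative indexing only): for each of the 16 rows r in this block
+ // and each column c, the two madd_epi16 calls below accumulate
+ //   accum[r][c] += lhs[r][d + 0] * rhs[c][d + 0] + ...
+ //                + lhs[r][d + 3] * rhs[c][d + 3];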
+
+ // Process column 0.
+ {
+ __m512i accum_v = accum_data_v0;
+ constexpr int index = 0;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_v0 = accum_v;
+ }
+ // Process column 1.
+ {
+ __m512i accum_v = accum_data_v1;
+ constexpr int index = 2;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_v1 = accum_v;
+ }
+ // Process column 2.
+ {
+ __m512i accum_v = accum_data_v2;
+ constexpr int index = 4;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_v2 = accum_v;
+ }
+ // Process column 3.
+ {
+ __m512i accum_v = accum_data_v3;
+ constexpr int index = 6;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_v3 = accum_v;
+ }
+ // Process column 4.
+ {
+ __m512i accum_v = accum_data_v4;
+ constexpr int index = 8;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_v4 = accum_v;
+ }
+ // Process column 5.
+ {
+ __m512i accum_v = accum_data_v5;
+ constexpr int index = 10;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_v5 = accum_v;
+ }
+ // Process column 6.
+ {
+ __m512i accum_v = accum_data_v6;
+ constexpr int index = 12;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_v6 = accum_v;
+ }
+ // Process column 7.
+ {
+ __m512i accum_v = accum_data_v7;
+ constexpr int index = 14;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_v7 = accum_v;
+ }
+ // Process column 8.
+ {
+ __m512i accum_v = accum_data_v8;
+ constexpr int index = 16;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_v8 = accum_v;
+ }
+ // Process column 9.
+ {
+ __m512i accum_v = accum_data_v9;
+ constexpr int index = 18;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_v9 = accum_v;
+ }
+ // Process column 10.
+ {
+ __m512i accum_v = accum_data_va;
+ constexpr int index = 20;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_va = accum_v;
+ }
+ // Process column 11.
+ {
+ __m512i accum_v = accum_data_vb;
+ constexpr int index = 22;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_vb = accum_v;
+ }
+ // Process column 12.
+ {
+ __m512i accum_v = accum_data_vc;
+ constexpr int index = 24;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_vc = accum_v;
+ }
+ // Process column 13.
+ {
+ __m512i accum_v = accum_data_vd;
+ constexpr int index = 26;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_vd = accum_v;
+ }
+ // Process column 14.
+ {
+ __m512i accum_v = accum_data_ve;
+ constexpr int index = 28;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_ve = accum_v;
+ }
+ // Process column 15.
+ {
+ __m512i accum_v = accum_data_vf;
+ constexpr int index = 30;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_vf = accum_v;
+ }
+
+ lhs_ptr += 16 * 4;
+ rhs_ptr += 16 * 4;
+ }
+
+ if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+ __m512i m_vector;
+ __m512i e_vector;
+ // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+ if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) {
+ m_vector = _mm512_maskz_loadu_epi32(
+ row_mask, &params.multiplier_fixedpoint[row]);
+ e_vector = _mm512_maskz_loadu_epi32(row_mask,
+ &params.multiplier_exponent[row]);
+ } else {
+ // These arrays have size LhsCols, and are pre-filled.
+ m_vector = _mm512_set1_epi32(params.multiplier_fixedpoint[0]);
+ e_vector = _mm512_set1_epi32(params.multiplier_exponent[0]);
+ }
+
+ const __m512i m_64bit_low =
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
+ const __m512i m_64bit_high =
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));
+
+ const __m512i zero_vector = _mm512_setzero_epi32();
+ const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
+ const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
+ const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
+ const __m512i final_right_shift =
+ _mm512_add_epi32(right_shift, _mm512_set1_epi32(31));
+ const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(final_right_shift, 0));
+ const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(final_right_shift, 1));
+
+ const __m512i offset_vector =
+ _mm512_slli_epi64(_mm512_set1_epi64(1), 30);
+ // Really these should be shifted by neg_e_vector, but tests pass when
+ // using right_shift.
+ const __m512i offset_vector_low = _mm512_sllv_epi64(
+ offset_vector,
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0)));
+ const __m512i offset_vector_high = _mm512_sllv_epi64(
+ offset_vector,
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1)));
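+ // Reading aid for the per-column blocks below (scalar view; illustrative
+ // only, m = multiplier_fixedpoint):
+ //   accum = ((accum << left_shift) * m + (1 << (30 + right_shift)))
+ //           >> (31 + right_shift)  // arithmetic shift, round-to-nearest
+ // Unlike the AVX2 path, _mm512_srav_epi64 provides a 64-bit arithmetic
+ // shift directly, so no unsigned-bias trick is needed here.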
+
+ // Shift and round column 0.
+ {
+ accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v0, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v0, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ accum_data_v0 =
+ _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum_data_v0 = _mm512_inserti32x8(
+ accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ }
+ // Shift and round column 1.
+ {
+ accum_data_v1 = _mm512_sllv_epi32(accum_data_v1, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v1, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v1, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ accum_data_v1 =
+ _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum_data_v1 = _mm512_inserti32x8(
+ accum_data_v1, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ }
+ // Shift and round column 2.
+ {
+ accum_data_v2 = _mm512_sllv_epi32(accum_data_v2, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v2, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v2, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ accum_data_v2 =
+ _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum_data_v2 = _mm512_inserti32x8(
+ accum_data_v2, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ }
+ // Shift and round column 3.
+ {
+ accum_data_v3 = _mm512_sllv_epi32(accum_data_v3, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v3, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v3, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ accum_data_v3 =
+ _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum_data_v3 = _mm512_inserti32x8(
+ accum_data_v3, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ }
+ // Shift and round column 4.
+ {
+ accum_data_v4 = _mm512_sllv_epi32(accum_data_v4, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v4, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v4, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ accum_data_v4 =
+ _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum_data_v4 = _mm512_inserti32x8(
+ accum_data_v4, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ }
+ // Shift and round column 5.
+ {
+ accum_data_v5 = _mm512_sllv_epi32(accum_data_v5, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v5, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v5, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ accum_data_v5 =
+ _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum_data_v5 = _mm512_inserti32x8(
+ accum_data_v5, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ }
+ // Shift and round column 6.
+ {
+ accum_data_v6 = _mm512_sllv_epi32(accum_data_v6, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v6, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v6, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ accum_data_v6 =
+ _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum_data_v6 = _mm512_inserti32x8(
+ accum_data_v6, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ }
+ // Shift and round column 7.
+ {
+ accum_data_v7 = _mm512_sllv_epi32(accum_data_v7, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v7, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v7, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ accum_data_v7 =
+ _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum_data_v7 = _mm512_inserti32x8(
+ accum_data_v7, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ }
+ // Shift and round column 8.
+ {
+ accum_data_v8 = _mm512_sllv_epi32(accum_data_v8, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v8, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v8, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ accum_data_v8 =
+ _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum_data_v8 = _mm512_inserti32x8(
+ accum_data_v8, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ }
+ // Shift and round column 9.
+ {
+ accum_data_v9 = _mm512_sllv_epi32(accum_data_v9, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v9, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_v9, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ accum_data_v9 =
+ _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum_data_v9 = _mm512_inserti32x8(
+ accum_data_v9, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ }
+ // Shift and round column 10.
+ {
+ accum_data_va = _mm512_sllv_epi32(accum_data_va, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_va, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_va, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ accum_data_va =
+ _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum_data_va = _mm512_inserti32x8(
+ accum_data_va, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ }
+ // Shift and round column 11.
+ {
+ accum_data_vb = _mm512_sllv_epi32(accum_data_vb, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_vb, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_vb, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ accum_data_vb =
+ _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum_data_vb = _mm512_inserti32x8(
+ accum_data_vb, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ }
+ // Shift and round column 12.
+ {
+ accum_data_vc = _mm512_sllv_epi32(accum_data_vc, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_vc, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_vc, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ accum_data_vc =
+ _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum_data_vc = _mm512_inserti32x8(
+ accum_data_vc, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ }
+ // Shift and round column 13.
+ {
+ accum_data_vd = _mm512_sllv_epi32(accum_data_vd, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_vd, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_vd, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ accum_data_vd =
+ _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum_data_vd = _mm512_inserti32x8(
+ accum_data_vd, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ }
+ // Shift and round column 14.
+ {
+ accum_data_ve = _mm512_sllv_epi32(accum_data_ve, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_ve, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_ve, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ accum_data_ve =
+ _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum_data_ve = _mm512_inserti32x8(
+ accum_data_ve, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ }
+ // Shift and round column 15.
+ {
+ accum_data_vf = _mm512_sllv_epi32(accum_data_vf, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_vf, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high =
+ _mm512_mul_epi32(_mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(accum_data_vf, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ accum_data_vf =
+ _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum_data_vf = _mm512_inserti32x8(
+ accum_data_vf, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ }
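+    // The rescaling above implements only the "native rounding" behaviour, so
+    // this kernel path should not be selected when RUY_OPT_NATIVE_ROUNDING is
+    // disabled; assert that in debug builds.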
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
+ RUY_DCHECK(false);
+#endif
+
+ if (params.dst_zero_point != 0) {
+ __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
+ accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point);
+ accum_data_v1 = _mm512_add_epi32(accum_data_v1, dst_zero_point);
+ accum_data_v2 = _mm512_add_epi32(accum_data_v2, dst_zero_point);
+ accum_data_v3 = _mm512_add_epi32(accum_data_v3, dst_zero_point);
+ accum_data_v4 = _mm512_add_epi32(accum_data_v4, dst_zero_point);
+ accum_data_v5 = _mm512_add_epi32(accum_data_v5, dst_zero_point);
+ accum_data_v6 = _mm512_add_epi32(accum_data_v6, dst_zero_point);
+ accum_data_v7 = _mm512_add_epi32(accum_data_v7, dst_zero_point);
+ accum_data_v8 = _mm512_add_epi32(accum_data_v8, dst_zero_point);
+ accum_data_v9 = _mm512_add_epi32(accum_data_v9, dst_zero_point);
+ accum_data_va = _mm512_add_epi32(accum_data_va, dst_zero_point);
+ accum_data_vb = _mm512_add_epi32(accum_data_vb, dst_zero_point);
+ accum_data_vc = _mm512_add_epi32(accum_data_vc, dst_zero_point);
+ accum_data_vd = _mm512_add_epi32(accum_data_vd, dst_zero_point);
+ accum_data_ve = _mm512_add_epi32(accum_data_ve, dst_zero_point);
+ accum_data_vf = _mm512_add_epi32(accum_data_vf, dst_zero_point);
+ }
+ }
+
+ const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
+ const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);
+
+ const bool store_full_block =
+ (residual_rows == 16) && (residual_cols == 16);
+
+ __m512i accum_data_v[16];
+
+    // In most cases we would make this copy conditional on (!store_full_block)
+    // and unroll the clamp-and-store loop for the full-block case, but the
+    // benefit appears small.
+ {
+ accum_data_v[0] = accum_data_v0;
+ accum_data_v[1] = accum_data_v1;
+ accum_data_v[2] = accum_data_v2;
+ accum_data_v[3] = accum_data_v3;
+ accum_data_v[4] = accum_data_v4;
+ accum_data_v[5] = accum_data_v5;
+ accum_data_v[6] = accum_data_v6;
+ accum_data_v[7] = accum_data_v7;
+ accum_data_v[8] = accum_data_v8;
+ accum_data_v[9] = accum_data_v9;
+ accum_data_v[10] = accum_data_va;
+ accum_data_v[11] = accum_data_vb;
+ accum_data_v[12] = accum_data_vc;
+ accum_data_v[13] = accum_data_vd;
+ accum_data_v[14] = accum_data_ve;
+ accum_data_v[15] = accum_data_vf;
+ }
+
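+    // Store the 16x16 block. For 8-bit and 16-bit destinations, each column is
+    // clamped to [clamp_min, clamp_max] and narrowed from the 32-bit
+    // accumulators; for int32 destinations the accumulators are stored
+    // directly. Masked stores (row_mask) handle partial rows, and the
+    // residual_cols loop bound handles partial columns.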
+ if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+ std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
+ const int block_col_offset = dst_stride;
+ if (store_full_block) {
+ for (int j = 0; j < 16; ++j) {
+ __m512i result = accum_data_v[j];
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm_storeu_epi8(tmp_ptr + j * block_col_offset,
+ _mm512_cvtepi32_epi8(result));
+ }
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ __m512i result = accum_data_v[j];
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask,
+ _mm512_cvtepi32_epi8(result));
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
+ } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+ std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
+ const int block_col_offset = dst_stride;
+ if (store_full_block) {
+      for (int j = 0; j < 16; ++j) {
+ __m512i result = accum_data_v[j];
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm_storeu_epi8(tmp_ptr + j * block_col_offset,
+ _mm512_cvtepi32_epi8(result));
+ }
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ __m512i result = accum_data_v[j];
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask,
+ _mm512_cvtepi32_epi8(result));
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+ const int block_col_offset = dst_stride;
+ if (store_full_block) {
+ for (int j = 0; j < 16; ++j) {
+ __m512i result = accum_data_v[j];
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm256_storeu_epi16(tmp_ptr + j * block_col_offset,
+ _mm512_cvtepi32_epi16(result));
+ }
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ __m512i result = accum_data_v[j];
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm256_mask_storeu_epi16(tmp_ptr + j * block_col_offset, row_mask,
+ _mm512_cvtepi32_epi16(result));
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ if (store_full_block) {
+ std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+ for (int j = 0; j < 16; ++j) {
+ _mm512_storeu_epi32(tmp_ptr + j * dst_stride, accum_data_v[j]);
+ }
+ } else {
+ std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+ for (int j = 0; j < residual_cols; ++j) {
+ _mm512_mask_storeu_epi32(tmp_ptr + j * dst_stride, row_mask,
+ accum_data_v[j]);
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ lhs_col_ptr += 16 * params.lhs_stride;
+ } // End row-block loop.
+
+ dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
+ 16 * params.dst_stride);
+ rhs_col_ptr += 16 * params.rhs_stride;
+ } // End col-block loop.
+} // NOLINT(readability/fn_size)
+
+void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
+ profiler::ScopeLabel label("Kernel kAvx512 8-bit GEMV");
+
+ RUY_DCHECK_EQ(params.dst_cols, 1);
+ RUY_DCHECK_EQ(params.last_col, 0);
+ RUY_DCHECK_EQ(params.start_col, 0);
+
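+  // params.dst_stride is given in bytes; convert it to destination-element
+  // units.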
+ std::int32_t dst_stride;
+ if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
+ (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
+ dst_stride = params.dst_stride;
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ dst_stride = params.dst_stride / sizeof(std::int16_t);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ dst_stride = params.dst_stride / sizeof(std::int32_t);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;
+
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ const std::int32_t* bias_col_ptr = params.bias;
+ if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+ bias_col_ptr += params.start_row;
+ }
+
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ const std::int32_t* bias_ptr = bias_col_ptr;
+
+ const std::int32_t lhs_zero_point = params.lhs_zero_point;
+ const bool has_rhs_sums_offsets =
+ (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
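+  // With a nonzero LHS zero point, the correct result requires subtracting
+  // lhs_zero_point * sum(RHS column) from each accumulator; precompute those
+  // per-column corrections here.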
+ std::int32_t rhs_sums_offsets[16];
+ if (has_rhs_sums_offsets) {
+ const __m512i rhs_sums_offset_v =
+ _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
+ _mm512_loadu_epi32(&params.rhs_sums[0]));
+ _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
+ rhs_sums_offset_v);
+ }
+
+ for (int row = params.start_row; row <= params.last_row; row += 16) {
+ const int residual_rows = std::min(params.dst_rows - row, 16);
+
+ __m512i accum_data_v0;
+
+ // Initialize with bias.
+ const __mmask16 row_mask =
+ (static_cast<std::uint32_t>(1) << residual_rows) - 1;
+ __m512i initial_accum_data = _mm512_maskz_loadu_epi32(row_mask, bias_ptr);
+ bias_ptr += bias_ptr_block_increment;
+
+ const std::int32_t rhs_zero_point = params.rhs_zero_point;
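+    // Symmetrically, a nonzero RHS zero point requires subtracting
+    // rhs_zero_point * sum(LHS row) from each accumulator.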
+ if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
+ const __m512i lhs_sums_offset =
+ _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
+ _mm512_loadu_epi32(&params.lhs_sums[row]));
+ initial_accum_data =
+ _mm512_sub_epi32(initial_accum_data, lhs_sums_offset);
+ }
+
+ const std::int32_t prod_zp_depth = params.prod_zp_depth;
+ if (prod_zp_depth != 0) {
+ initial_accum_data = _mm512_add_epi32(initial_accum_data,
+ _mm512_set1_epi32(prod_zp_depth));
+ }
+
+ // Adjustments differing across columns.
+ if (has_rhs_sums_offsets) {
+ accum_data_v0 = _mm512_sub_epi32(initial_accum_data,
+ _mm512_set1_epi32(rhs_sums_offsets[0]));
+ } else {
+ accum_data_v0 = initial_accum_data;
+ }
+
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; d += 4) {
+ const __m512i lhs_data = _mm512_loadu_epi8(lhs_ptr);
+ const __m128i rhs_data_8bit = _mm_loadu_epi8(rhs_ptr);
+
+      // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
+      // For simplicity we load 4x the data we need, process 2x what we need,
+      // and store only what we need.
+ std::int32_t rhs_data[2];
+ const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
+ // Now that we have cast the RHS data, we store it so that each value
+ // can be separately loaded in the accumulation loop.
+ _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
+
+ // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
+ const __m512i lhs_16_bit_low =
+ _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
+ // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
+ const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
+ _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));
+
+ // Process column 0.
+ __m512i accum_v = accum_data_v0;
+ constexpr int index = 0;
+
+ const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_set1_epi32(rhs_data[index + 1]);
+
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum_v = _mm512_add_epi32(
+ accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ accum_data_v0 = accum_v;
+
+ lhs_ptr += 16 * 4;
+ rhs_ptr += 16 * 4;
+ }
+
+ if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+ __m512i m_vector;
+ __m512i e_vector;
+ // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+ if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) {
+ m_vector = _mm512_maskz_loadu_epi32(row_mask,
+ &params.multiplier_fixedpoint[row]);
+ e_vector = _mm512_maskz_loadu_epi32(row_mask,
+ &params.multiplier_exponent[row]);
+ } else {
+ // These arrays have size LhsCols, and are pre-filled.
+ m_vector = _mm512_set1_epi32(params.multiplier_fixedpoint[0]);
+ e_vector = _mm512_set1_epi32(params.multiplier_exponent[0]);
+ }
+
+ const __m512i m_64bit_low =
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
+ const __m512i m_64bit_high =
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));
+
+ const __m512i zero_vector = _mm512_setzero_epi32();
+ const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
+ const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
+ const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
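+      // The fixed-point multiplier has 31 fractional bits, so the total
+      // arithmetic right shift after the 64-bit multiply below is
+      // 31 + right_shift.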
+ const __m512i final_right_shift =
+ _mm512_add_epi32(right_shift, _mm512_set1_epi32(31));
+ const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(final_right_shift, 0));
+ const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
+ _mm512_extracti32x8_epi32(final_right_shift, 1));
+
+ const __m512i offset_vector = _mm512_slli_epi64(_mm512_set1_epi64(1), 30);
+      // This offset, added before the arithmetic shift below, makes that shift
+      // round to nearest. Strictly it should be scaled by neg_e_vector, but
+      // tests pass when using right_shift (== max(neg_e_vector, 0)).
+ const __m512i offset_vector_low = _mm512_sllv_epi64(
+ offset_vector,
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0)));
+ const __m512i offset_vector_high = _mm512_sllv_epi64(
+ offset_vector,
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1)));
+
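+      // Overall, each accumulator acc is rescaled as
+      //   acc <- round((acc << left_shift) * multiplier_fixedpoint
+      //                / 2^(31 + right_shift)),
+      // evaluated in two 64-bit halves (low and high 8 lanes) below.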
+ // Shift and round column 0.
+ accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low = _mm512_mul_epi32(
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high = _mm512_mul_epi32(
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high = _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ accum_data_v0 =
+ _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum_data_v0 = _mm512_inserti32x8(
+ accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
+ RUY_DCHECK(false);
+#endif
+
+ if (params.dst_zero_point != 0) {
+ __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
+ accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point);
+ }
+ }
+
+ const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
+ const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);
+
+ if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+ std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
+ __m512i result = accum_data_v0;
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result));
+ dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
+ } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+ std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
+ __m512i result = accum_data_v0;
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result));
+ dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+ __m512i result = accum_data_v0;
+ result = _mm512_min_epi32(result, clamp_max_v);
+ result = _mm512_max_epi32(result, clamp_min_v);
+ _mm256_mask_storeu_epi16(tmp_ptr, row_mask,
+ _mm512_cvtepi32_epi16(result));
+ dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+ _mm512_mask_storeu_epi32(tmp_ptr, row_mask, accum_data_v0);
+ dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ lhs_col_ptr += 16 * params.lhs_stride;
+ } // End row-block loop.
+} // NOLINT(readability/fn_size)
+
+void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
+ profiler::ScopeLabel label("Kernel kAvx512 float");
+
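+  // This kernel computes a 16x16 destination block per (row, col) step:
+  // 16 LHS rows by 16 RHS columns, with the 16 columns processed as two
+  // unrolled halves of 8 columns each.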
+  // The strides in params are given in bytes; convert them to units of float
+  // elements by dividing by sizeof(float).
+ const std::int64_t lhs_stride = params.lhs_stride >> 2;
+ const std::int64_t dst_stride = params.dst_stride >> 2;
+ const std::int64_t rhs_stride = params.rhs_stride >> 2;
+
+ int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
+ const int end_row = std::min(params.dst_rows, params.last_row + 16);
+ const int end_col = std::min(params.dst_cols, params.last_col + 16);
+
+ const float* adj_rhs_col_ptr =
+ params.rhs_base_ptr - params.start_col * rhs_stride;
+ float* adj_dst_col_ptr =
+ params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
+ const float* adj_lhs_col_ptr =
+ params.lhs_base_ptr - params.start_row * lhs_stride;
+ const float* bias_col_ptr = params.bias;
+
+ const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
+ const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
+
+ int col = params.start_col;
+ for (; col <= end_col - 16; col += 16) {
+ const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+ float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+
+ int row = params.start_row;
+ for (; row <= end_row - 16; row += 16) {
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+ const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+ // Initialize with bias.
+ const __m512 initial_accum_data = _mm512_loadu_ps(bias_ptr);
+
+ // Process block in two halves, split by columns.
+ {
+ constexpr int mmm = 0;
+
+ __m512 accum_data_v0 = initial_accum_data;
+ __m512 accum_data_v1 = initial_accum_data;
+ __m512 accum_data_v2 = initial_accum_data;
+ __m512 accum_data_v3 = initial_accum_data;
+ __m512 accum_data_v4 = initial_accum_data;
+ __m512 accum_data_v5 = initial_accum_data;
+ __m512 accum_data_v6 = initial_accum_data;
+ __m512 accum_data_v7 = initial_accum_data;
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
+ for (int d = 0; d < (params.depth - 1); ++d) {
+ const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+          // In this version, RHS values are loaded individually rather than
+          // loaded together and then extracted with broadcasting. This is
+          // because the combination of AVX flavours, intrinsics, and compilers
+          // does not handle that extraction pattern very well.
+ const float* rhs_data = rhs_ptr;
+ lhs_ptr += 16;
+ rhs_ptr += 16;
+
+ {
+ const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
+ accum_data_v0 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
+ const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
+ accum_data_v1 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
+ const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
+ accum_data_v2 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
+ const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
+ accum_data_v3 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
+ const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
+ accum_data_v4 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
+ const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
+ accum_data_v5 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
+ const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
+ accum_data_v6 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
+ const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
+ accum_data_v7 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
+ }
+ }
+ {
+ const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+ {
+ const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
+ accum_data_v0 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
+ const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
+ accum_data_v1 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
+ const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
+ accum_data_v2 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
+ const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
+ accum_data_v3 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
+ const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
+ accum_data_v4 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
+ const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
+ accum_data_v5 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
+ const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
+ accum_data_v6 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
+ const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
+ accum_data_v7 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
+ }
+ {
+ float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
+ accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
+ accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
+ accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
+ accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
+ accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
+ accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
+ accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
+ accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
+ accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
+ accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
+ accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
+ accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
+ accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
+ accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
+ accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
+ accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
+ }
+ }
+ } // Inner half-block loop, unrolled, first iteration.
+ {
+ constexpr int mmm = 1;
+
+ __m512 accum_data_v0 = initial_accum_data;
+ __m512 accum_data_v1 = initial_accum_data;
+ __m512 accum_data_v2 = initial_accum_data;
+ __m512 accum_data_v3 = initial_accum_data;
+ __m512 accum_data_v4 = initial_accum_data;
+ __m512 accum_data_v5 = initial_accum_data;
+ __m512 accum_data_v6 = initial_accum_data;
+ __m512 accum_data_v7 = initial_accum_data;
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
+ for (int d = 0; d < (params.depth - 1); ++d) {
+ const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+ lhs_ptr += 16;
+ rhs_ptr += 16;
+ {
+ const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
+ accum_data_v0 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
+ const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
+ accum_data_v1 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
+ const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
+ accum_data_v2 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
+ const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
+ accum_data_v3 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
+ const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
+ accum_data_v4 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
+ const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
+ accum_data_v5 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
+ const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
+ accum_data_v6 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
+ const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
+ accum_data_v7 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
+ }
+ }
+ {
+ const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+ {
+ const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
+ accum_data_v0 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
+ const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
+ accum_data_v1 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
+ const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
+ accum_data_v2 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
+ const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
+ accum_data_v3 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
+ const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
+ accum_data_v4 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
+ const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
+ accum_data_v5 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
+ const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
+ accum_data_v6 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
+ const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
+ accum_data_v7 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
+ }
+ {
+ float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
+ accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
+ accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
+ accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
+ accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
+ accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
+ accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
+ accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
+ accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
+ accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
+ accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
+ accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
+ accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
+ accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
+ accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
+ accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
+ accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
+ _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
+ }
+ }
+ } // Inner half-block loop, unrolled, second iteration.
+ } // End row-block loop.
+
+    // The unrolling within this conditional may be of limited benefit; whether
+    // it pays off depends on the kinds of models being run.
+ if (row < end_row) {
+ const int residual_rows = end_row - row;
+
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+ const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+ // Initialize with bias.
+ const __mmask16 row_mask =
+ (static_cast<std::uint32_t>(1) << residual_rows) - 1;
+ const __m512 initial_accum_data =
+ _mm512_maskz_loadu_ps(row_mask, bias_ptr);
+
+ // Process block in two halves, split by columns.
+ for (int mmm = 0; mmm < 2; ++mmm) {
+ __m512 accum_data_v0 = initial_accum_data;
+ __m512 accum_data_v1 = initial_accum_data;
+ __m512 accum_data_v2 = initial_accum_data;
+ __m512 accum_data_v3 = initial_accum_data;
+ __m512 accum_data_v4 = initial_accum_data;
+ __m512 accum_data_v5 = initial_accum_data;
+ __m512 accum_data_v6 = initial_accum_data;
+ __m512 accum_data_v7 = initial_accum_data;
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
+ for (int d = 0; d < (params.depth - 1); ++d) {
+ const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+ lhs_ptr += 16;
+ rhs_ptr += 16;
+ {
+ const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
+ accum_data_v0 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
+ const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
+ accum_data_v1 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
+ const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
+ accum_data_v2 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
+ const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
+ accum_data_v3 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
+ const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
+ accum_data_v4 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
+ const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
+ accum_data_v5 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
+ const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
+ accum_data_v6 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
+ const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
+ accum_data_v7 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
+ }
+ }
+ {
+ const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+ {
+ const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]);
+ accum_data_v0 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
+ const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]);
+ accum_data_v1 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
+ const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]);
+ accum_data_v2 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
+ const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]);
+ accum_data_v3 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
+ const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]);
+ accum_data_v4 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
+ const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]);
+ accum_data_v5 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
+ const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]);
+ accum_data_v6 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
+ const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]);
+ accum_data_v7 =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
+ }
+ {
+ float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
+ accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
+ accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
+ _mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask,
+ accum_data_v0);
+ accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
+ accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
+ _mm512_mask_storeu_ps(block_ptr + 1 * dst_stride, row_mask,
+ accum_data_v1);
+ accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
+ accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
+ _mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask,
+ accum_data_v2);
+ accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
+ accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
+ _mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask,
+ accum_data_v3);
+ accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
+ accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
+ _mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask,
+ accum_data_v4);
+ accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
+ accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
+ _mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask,
+ accum_data_v5);
+ accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
+ accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
+ _mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask,
+ accum_data_v6);
+ accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
+ accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
+ _mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask,
+ accum_data_v7);
+ }
+ }
+ } // Inner half-block loop.
+ } // Residual rows, main col-block loop.
+ } // End col-block loop.
+
+ if (col < end_col) {
+ RUY_DCHECK_GE(end_col - col, 0);
+ RUY_DCHECK_LT(end_col - col, 16);
+
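+    // Residual columns (fewer than 16 remaining): same computation as the main
+    // blocks above, with looped accumulators and masked stores where needed.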
+ __m512 accum_data_v[8];
+
+ const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+ float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+
+ for (int row = params.start_row; row < end_row; row += 16) {
+ const int residual_rows = std::min(end_row - row, 16);
+
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+ const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+ // Initialize with bias.
+ const __mmask16 row_mask =
+ (static_cast<std::uint32_t>(1) << residual_rows) - 1;
+ const __m512 initial_accum_data =
+ _mm512_maskz_loadu_ps(row_mask, bias_ptr);
+
+ // Process block in two halves, split by columns.
+ for (int mmm = 0; mmm < 2; ++mmm) {
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = initial_accum_data;
+ }
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
+ for (int d = 0; d < params.depth; ++d) {
+ const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+ const float* rhs_data = rhs_ptr;
+
+ for (int j = 0; j < 8; ++j) {
+ const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]);
+ accum_data_v[j] =
+ _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
+ }
+ lhs_ptr += 16;
+ rhs_ptr += 16;
+ }
+
+ const int residual_cols = std::min(end_col - col - 8 * mmm, 8);
+
+ if (residual_rows == 16) {
+ if (residual_cols == 8) {
+ for (int j = 0; j < 8; ++j) {
+ float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
+ accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
+ accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
+ _mm512_storeu_ps(block_ptr, accum_data_v[j]);
+ }
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
+ accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
+ accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
+ _mm512_storeu_ps(block_ptr, accum_data_v[j]);
+ }
+ }
+ } else {
+ for (int j = 0; j < residual_cols; ++j) {
+ float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
+ accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
+ accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
+ _mm512_mask_storeu_ps(block_ptr, row_mask, accum_data_v[j]);
+ }
+ }
+ } // Inner half-block loop.
+ } // End row-block loop.
+ } // Residual cols.
+}
+
+void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) {
+ profiler::ScopeLabel label("Kernel kAvx512 float GEMV");
+
+ RUY_DCHECK_EQ(params.dst_cols, 1);
+ RUY_DCHECK_EQ(params.last_col, 0);
+ RUY_DCHECK_EQ(params.start_col, 0);
+
+  // The stride in params is given in bytes; convert it to units of float
+  // elements by dividing by sizeof(float).
+ const std::int64_t lhs_stride = params.lhs_stride >> 2;
+
+ int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
+ const int end_row = std::min(params.dst_rows, params.last_row + 16);
+
+ float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row;
+ const float* adj_lhs_col_ptr =
+ params.lhs_base_ptr - params.start_row * lhs_stride;
+ const float* bias_col_ptr = params.bias;
+
+ const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
+ const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
+
+ __m512 accum_data_v;
+
+ const float* rhs_col_ptr = params.rhs_base_ptr;
+ float* dst_col_ptr = adj_dst_col_ptr;
+
+ int row = params.start_row;
+ for (; row <= end_row - 16; row += 16) {
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+ const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+ // Initialize with bias.
+ accum_data_v = _mm512_loadu_ps(bias_ptr);
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
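+    // The packed RHS is laid out in 16-wide column blocks, so rhs_ptr advances
+    // by 16 floats per depth step even though only one value is used per step.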
+ for (int d = 0; d < params.depth; ++d) {
+ const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+ const float rhs_data = *rhs_ptr;
+
+ const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data);
+ accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
+ lhs_ptr += 16;
+ rhs_ptr += 16;
+ }
+
+ accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v);
+ accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v);
+ _mm512_storeu_ps(dst_ptr, accum_data_v);
+ } // End row-block loop.
+
+ if (row < end_row) {
+ const int residual_rows = end_row - row;
+ RUY_CHECK_GE(residual_rows, 1);
+ RUY_CHECK_LT(residual_rows, 16);
+
+ const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+ float* dst_ptr = dst_col_ptr + row;
+ const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+ // Initialize with bias.
+ const __mmask16 row_mask =
+ (static_cast<std::uint32_t>(1) << residual_rows) - 1;
+ accum_data_v = _mm512_maskz_loadu_ps(row_mask, bias_ptr);
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; ++d) {
+ const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+ const float rhs_data = *rhs_ptr;
+
+ const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data);
+ accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
+ lhs_ptr += 16;
+ rhs_ptr += 16;
+ }
+
+ accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v);
+ accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v);
+ _mm512_mask_storeu_ps(dst_ptr, row_mask, accum_data_v);
+ } // End handling of residual rows.
+}
+
+#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+} // namespace ruy
diff --git a/ruy/kernel_avxvnni.cc b/ruy/kernel_avxvnni.cc
new file mode 100644
index 0000000..4513b20
--- /dev/null
+++ b/ruy/kernel_avxvnni.cc
@@ -0,0 +1,435 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cstdint>
+
+#include "ruy/check_macros.h"
+#include "ruy/kernel.h"
+#include "ruy/opt_set.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+#include <immintrin.h> // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+void Kernel8bitAvxVnni(const KernelParams8bit<16, 16>& params) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void KernelFloatAvxVnni(const KernelParamsFloat<16, 16>& params) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+#else // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+static constexpr int kAvxFloatBlockSize = 16;
+static constexpr int kAvx8bitBlockSize = 16;
+static constexpr int kAvx8bitInnerSize = 4;
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+// When removing this comment, update profiling label below.
+void Kernel8bitAvxVnni(const KernelParams8bit<16, 16>& params) {
+ profiler::ScopeLabel label("Kernel kAvxVnni 8-bit (UNFINISHED)");
+
+ std::int32_t accum_data[kAvx8bitBlockSize][kAvx8bitBlockSize];
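+  // accum_data[j][i] is the 32-bit accumulator for destination column j,
+  // row i. This placeholder is plain scalar C++ and does not yet use VNNI
+  // (or any other SIMD) instructions.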
+
+ int bias_ptr_block_increment =
+ params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;
+
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ const std::int32_t* bias_col_ptr = params.bias;
+ if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+ bias_col_ptr += params.start_row;
+ }
+
+ for (int col = params.start_col; col <= params.last_col;
+ col += kAvx8bitBlockSize) {
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ const std::int32_t* bias_ptr = bias_col_ptr;
+
+ for (int row = params.start_row; row <= params.last_row;
+ row += kAvx8bitBlockSize) {
+ const int residual_rows =
+ std::min(params.dst_rows - row, kAvx8bitBlockSize);
+ const int residual_cols =
+ std::min(params.dst_cols - col, kAvx8bitBlockSize);
+
+ // Initialize with bias.
+ std::int32_t initial_accum_data[kAvx8bitBlockSize];
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ initial_accum_data[i] = 0;
+ }
+ for (int i = 0; i < residual_rows; ++i) {
+ initial_accum_data[i] = bias_ptr[i];
+ }
+
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] = initial_accum_data[i];
+ }
+ }
+ bias_ptr += bias_ptr_block_increment;
+
+ std::int8_t lhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize];
+ std::int8_t rhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize];
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ for (int x = 0; x < kAvx8bitInnerSize; ++x) {
+ lhs_data[i][x] = lhs_ptr[i * kAvx8bitInnerSize + x];
+ rhs_data[i][x] = rhs_ptr[i * kAvx8bitInnerSize + x];
+ }
+ }
+
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ for (int x = 0; x < kAvx8bitInnerSize; ++x) {
+ accum_data[j][i] += lhs_data[i][x] * rhs_data[j][x];
+ }
+ }
+ }
+
+ lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ }
+
+ if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && params.rhs_zero_point) {
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] -=
+ params.rhs_zero_point * params.lhs_sums[row + i];
+ }
+ }
+ }
+ if ((params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && params.lhs_zero_point) {
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] -=
+ params.lhs_zero_point * params.rhs_sums[col + j];
+ }
+ }
+ }
+ if (params.lhs_zero_point && params.rhs_zero_point) {
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] += params.prod_zp_depth;
+ }
+ }
+ }
+
+ if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+ std::int32_t m_vector[kAvx8bitBlockSize];
+ std::int32_t e_vector[kAvx8bitBlockSize];
+ // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+ if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) {
+ int i = 0;
+ for (; i < residual_rows; ++i) {
+ m_vector[i] = params.multiplier_fixedpoint[row + i];
+ e_vector[i] = params.multiplier_exponent[row + i];
+ }
+ for (; i < kAvx8bitBlockSize; ++i) {
+ m_vector[i] = m_vector[0];
+ e_vector[i] = e_vector[0];
+ }
+ } else {
+ // These arrays have size LhsCols, and are pre-filled.
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ m_vector[i] = params.multiplier_fixedpoint[i];
+ e_vector[i] = params.multiplier_exponent[i];
+ }
+ }
+
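+        // Rescale each accumulator by its per-channel fixed-point multiplier
+        // and exponent.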
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] = MultiplyByQuantizedMultiplier(
+ accum_data[j][i], m_vector[i], e_vector[i]);
+ }
+ }
+
+ if (params.dst_zero_point) {
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] += params.dst_zero_point;
+ }
+ }
+ }
+
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] =
+ std::min<std::int32_t>(accum_data[j][i], params.clamp_max);
+ accum_data[j][i] =
+ std::max<std::int32_t>(accum_data[j][i], params.clamp_min);
+ }
+ }
+ }
+
+ const bool store_full_block = (residual_rows == kAvx8bitBlockSize) &&
+ (residual_cols == kAvx8bitBlockSize);
+
+ if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+ std::int8_t* tmp_ptr =
+ store_full_block
+ ? static_cast<std::int8_t*>(dst_ptr)
+ : const_cast<std::int8_t*>(
+ reinterpret_cast<const std::int8_t*>(params.dst_tmp_buf));
+ const int block_col_offset =
+ store_full_block ? params.dst_stride / sizeof(std::int8_t)
+ : kAvx8bitBlockSize;
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ tmp_ptr[i] = accum_data[j][i];
+ }
+ tmp_ptr += block_col_offset;
+ }
+
+ if (!store_full_block) {
+ const std::int8_t* block_ptr =
+ reinterpret_cast<const std::int8_t*>(params.dst_tmp_buf);
+ for (int j = 0; j < residual_cols; ++j) {
+ for (int i = 0; i < residual_rows; ++i) {
+ static_cast<std::int8_t*>(
+ dst_ptr)[j * params.dst_stride / sizeof(std::int8_t) + i] =
+ block_ptr[i];
+ }
+ block_ptr += kAvx8bitBlockSize;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+ std::uint8_t* tmp_ptr = store_full_block
+ ? static_cast<std::uint8_t*>(dst_ptr)
+ : const_cast<std::uint8_t*>(
+ reinterpret_cast<const std::uint8_t*>(
+ params.dst_tmp_buf));
+ const int block_col_offset =
+ store_full_block ? params.dst_stride : kAvx8bitBlockSize;
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ tmp_ptr[i] = accum_data[j][i];
+ }
+ tmp_ptr += block_col_offset;
+ }
+
+ if (!store_full_block) {
+ const std::uint8_t* block_ptr =
+ reinterpret_cast<const std::uint8_t*>(params.dst_tmp_buf);
+ for (int j = 0; j < residual_cols; ++j) {
+ for (int i = 0; i < residual_rows; ++i) {
+ static_cast<std::uint8_t*>(
+ dst_ptr)[j * params.dst_stride / sizeof(std::uint8_t) + i] =
+ block_ptr[i];
+ }
+ block_ptr += kAvx8bitBlockSize;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ if (store_full_block) {
+ std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+ const int block_col_offset = params.dst_stride / sizeof(std::int16_t);
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ tmp_ptr[i] = accum_data[j][i];
+ }
+ tmp_ptr += block_col_offset;
+ }
+ } else {
+ std::int16_t* tmp_ptr = const_cast<std::int16_t*>(
+ reinterpret_cast<const std::int16_t*>(params.dst_tmp_buf));
+ const int block_col_offset = kAvx8bitBlockSize;
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ tmp_ptr[i] = accum_data[j][i];
+ }
+ tmp_ptr += block_col_offset;
+ }
+ const std::int16_t* block_ptr =
+ reinterpret_cast<const std::int16_t*>(params.dst_tmp_buf);
+ std::int16_t* dst_block_ptr = static_cast<std::int16_t*>(dst_ptr);
+ for (int j = 0; j < residual_cols; ++j) {
+ for (int i = 0; i < residual_rows; ++i) {
+ dst_block_ptr[i] = block_ptr[i];
+ }
+ dst_block_ptr += params.dst_stride / sizeof(std::int16_t);
+ block_ptr += kAvx8bitBlockSize;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ if (store_full_block) {
+ std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+ const int block_col_offset = params.dst_stride / sizeof(std::int32_t);
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ tmp_ptr[i] = accum_data[j][i];
+ }
+ tmp_ptr += block_col_offset;
+ }
+ } else {
+ std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
+ for (int j = 0; j < residual_cols; ++j) {
+ for (int i = 0; i < residual_rows; ++i) {
+ dst_block_ptr[i] = accum_data[j][i];
+ }
+ dst_block_ptr += params.dst_stride / sizeof(std::int32_t);
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
+ } // End row-block loop.
+
+ dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
+ kAvx8bitBlockSize * params.dst_stride);
+ rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
+ } // End col-block loop.
+} // NOLINT(readability/fn_size)
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+// When removing this comment, update profiling label below.
+void KernelFloatAvxVnni(const KernelParamsFloat<16, 16>& params) {
+ profiler::ScopeLabel label("Kernel kAvxVnni float (UNFINISHED)");
+
+ float lhs_data[kAvxFloatBlockSize];
+ float rhs_data[kAvxFloatBlockSize];
+ float accum_data[kAvxFloatBlockSize][kAvxFloatBlockSize];
+ int bias_ptr_block_increment =
+ params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvxFloatBlockSize : 0;
+
+ const float* rhs_col_ptr = params.rhs_base_ptr;
+ float* dst_col_ptr = params.dst_base_ptr;
+ const float* bias_col_ptr = params.bias;
+ if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+ bias_col_ptr += params.start_row;
+ }
+
+ for (int col = params.start_col; col <= params.last_col;
+ col += kAvxFloatBlockSize) {
+ const float* lhs_col_ptr = params.lhs_base_ptr;
+ float* dst_ptr = dst_col_ptr;
+ const float* bias_ptr = bias_col_ptr;
+
+ for (int row = params.start_row; row <= params.last_row;
+ row += kAvxFloatBlockSize) {
+ const int residual_rows =
+ std::min(params.dst_rows - row, kAvxFloatBlockSize);
+ const int residual_cols =
+ std::min(params.dst_cols - col, kAvxFloatBlockSize);
+
+ // Initialize with bias.
+ float initial_accum_data[kAvxFloatBlockSize];
+ for (int i = 0; i < kAvxFloatBlockSize; ++i) {
+ initial_accum_data[i] = 0.0f;
+ }
+ for (int i = 0; i < residual_rows; ++i) {
+ initial_accum_data[i] = bias_ptr[i];
+ }
+ for (int j = 0; j < kAvxFloatBlockSize; ++j) {
+ for (int i = 0; i < kAvxFloatBlockSize; ++i) {
+ accum_data[j][i] = initial_accum_data[i];
+ }
+ }
+ bias_ptr += bias_ptr_block_increment;
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; ++d) {
+ for (int i = 0; i < kAvxFloatBlockSize; ++i) {
+ lhs_data[i] = lhs_ptr[i];
+ rhs_data[i] = rhs_ptr[i];
+ }
+
+ for (int j = 0; j < kAvxFloatBlockSize; ++j) {
+ for (int i = 0; i < kAvxFloatBlockSize; ++i) {
+ accum_data[j][i] += lhs_data[i] * rhs_data[j];
+ }
+ }
+
+ lhs_ptr += kAvxFloatBlockSize;
+ rhs_ptr += kAvxFloatBlockSize;
+ }
+
+ for (int j = 0; j < kAvxFloatBlockSize; ++j) {
+ for (int i = 0; i < kAvxFloatBlockSize; ++i) {
+ accum_data[j][i] =
+ std::min<float>(accum_data[j][i], params.clamp_max);
+ accum_data[j][i] =
+ std::max<float>(accum_data[j][i], params.clamp_min);
+ }
+ }
+
+ const bool store_full_block = (residual_rows == kAvxFloatBlockSize) &&
+ (residual_cols == kAvxFloatBlockSize);
+
+ {
+ float* block_ptr =
+ store_full_block ? dst_ptr : const_cast<float*>(params.dst_tmp_buf);
+ const int block_col_offset = store_full_block
+ ? params.dst_stride / sizeof(float)
+ : kAvxFloatBlockSize;
+ for (int j = 0; j < kAvxFloatBlockSize; ++j) {
+ for (int i = 0; i < kAvxFloatBlockSize; ++i) {
+ block_ptr[i] = accum_data[j][i];
+ }
+ block_ptr += block_col_offset;
+ }
+ }
+ if (!store_full_block) {
+ const float* block_ptr = params.dst_tmp_buf;
+ for (int j = 0; j < residual_cols; ++j) {
+ for (int i = 0; i < residual_rows; ++i) {
+ dst_ptr[j * params.dst_stride / sizeof(float) + i] = block_ptr[i];
+ }
+ block_ptr += kAvxFloatBlockSize;
+ }
+ }
+
+ lhs_col_ptr += kAvxFloatBlockSize * params.lhs_stride / sizeof(float);
+ dst_ptr += kAvxFloatBlockSize;
+ } // End row-block loop.
+
+ dst_col_ptr += kAvxFloatBlockSize * params.dst_stride / sizeof(float);
+ rhs_col_ptr += kAvxFloatBlockSize * params.rhs_stride / sizeof(float);
+ } // End col-block loop.
+}
+
+#endif // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+} // namespace ruy
diff --git a/ruy/kernel_common.h b/ruy/kernel_common.h
new file mode 100644
index 0000000..0cd123f
--- /dev/null
+++ b/ruy/kernel_common.h
@@ -0,0 +1,481 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <limits>
+#include <type_traits>
+
+#include "ruy/check_macros.h"
+#include "ruy/common.h"
+#include "ruy/internal_matrix.h"
+#include "ruy/matrix.h"
+#include "ruy/opt_set.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/side_pair.h"
+#include "ruy/size_util.h"
+#include "ruy/spec.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+template <Path ThePath, typename LhsScalar, typename RhsScalar,
+ typename DstScalar, typename Spec>
+struct Kernel {};
+
+template <Path ThePath, typename LhsScalar, typename RhsScalar,
+ typename DstScalar, typename Spec>
+void RunKernelTyped(Tuning tuning, const PackedMatrix<LhsScalar>& lhs,
+ const PackedMatrix<RhsScalar>& rhs, const Spec& spec,
+ int start_row, int start_col, int end_row, int end_col,
+ Matrix<DstScalar>* dst) {
+ using Kernel = Kernel<ThePath, LhsScalar, RhsScalar, DstScalar, Spec>;
+ Kernel kernel(tuning);
+#if !defined(NDEBUG) || !RUY_OPT_ENABLED(RUY_OPT_FAT_KERNEL)
+ using LhsLayout = typename Kernel::LhsLayout;
+ using RhsLayout = typename Kernel::RhsLayout;
+#endif
+  // end_row and end_col may be larger than the dst dimensions. That is
+  // because kernels write directly to the destination matrix, whose
+  // dimensions may not be a multiple of the kernel dimensions, and we try to
+  // keep this annoyance localized as an implementation detail in kernels,
+  // by allowing rounded-up values to be passed down as far as possible.
+  // For example, with dst->layout.rows == 13 and LhsLayout::kCols == 4,
+  // end_row may be 16. These assertions encode the contract.
+ RUY_DCHECK_LE(0, start_row);
+ RUY_DCHECK_LE(start_row, end_row);
+ RUY_DCHECK_LT(end_row, dst->layout.rows + LhsLayout::kCols);
+ RUY_DCHECK_EQ((end_row - start_row) % LhsLayout::kCols, 0);
+ RUY_DCHECK_LE(0, start_col);
+ RUY_DCHECK_LE(start_col, end_col);
+ RUY_DCHECK_LT(end_col, dst->layout.cols + RhsLayout::kCols);
+ RUY_DCHECK_EQ((end_col - start_col) % RhsLayout::kCols, 0);
+#if RUY_OPT_ENABLED(RUY_OPT_FAT_KERNEL)
+ kernel.Run(lhs, rhs, spec, start_row, start_col, end_row, end_col, dst);
+#else
+ for (int col = start_col; col < end_col; col += RhsLayout::kCols) {
+ int block_end_col = std::min(col + RhsLayout::kCols, end_col);
+ for (int row = start_row; row < end_row; row += LhsLayout::kCols) {
+ int block_end_row = std::min(row + LhsLayout::kCols, end_row);
+ kernel.Run(lhs, rhs, spec, row, col, block_end_row, block_end_col, dst);
+ }
+ }
+#endif
+}
+
+// Main entry point for kernels.
+template <Path ThePath, typename LhsScalar, typename RhsScalar,
+ typename DstScalar, typename Spec>
+void RunKernel(Tuning tuning, const SidePair<PMatrix>& src, void* spec,
+ const SidePair<int>& start, const SidePair<int>& end,
+ DMatrix* dst) {
+ Matrix<DstScalar> mdst = ToMatrix<DstScalar>(*dst);
+ RunKernelTyped<ThePath, LhsScalar, RhsScalar, DstScalar, Spec>(
+ tuning, ToPackedMatrix<LhsScalar>(src[Side::kLhs]),
+ ToPackedMatrix<RhsScalar>(src[Side::kRhs]),
+ *static_cast<const Spec*>(spec), start[Side::kLhs], start[Side::kRhs],
+ end[Side::kLhs], end[Side::kRhs], &mdst);
+}
+
+// Copied from gemmlowp/fixedpoint.
+inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
+ std::int32_t b) {
+ bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
+ std::int64_t a_64(a);
+ std::int64_t b_64(b);
+ std::int64_t ab_64 = a_64 * b_64;
+ std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
+ std::int32_t ab_x2_high32 =
+ static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31));
+ return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
+}
+
+inline std::int32_t RoundingDivideByPOT(std::int32_t numerator, int exponent) {
+ std::int32_t sign = numerator >= 0 ? 1 : -1;
+ std::int32_t abs_numerator = std::abs(numerator);
+ std::int32_t mask = (1LL << exponent) - 1;
+ std::int32_t remainder = abs_numerator & mask;
+ std::int32_t threshold = mask >> 1;
+ std::int32_t abs_result =
+ (abs_numerator >> exponent) + (remainder > threshold ? 1 : 0);
+ return sign * abs_result;
+}
+
+// Copied from TF Lite code.
+inline std::int32_t MultiplyByQuantizedMultiplier(
+ std::int32_t x, std::int32_t quantized_multiplier, int shift) {
+ int left_shift = shift > 0 ? shift : 0;
+ int right_shift = shift > 0 ? 0 : -shift;
+ return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
+ x * (1 << left_shift), quantized_multiplier),
+ right_shift);
+}
+
+// Helper to apply a fixed-point multiplier. Only 'applicable' if AccumScalar
+// is int32 (i.e. in all cases except floating-point) and if the destination is
+// not int32 (i.e. unless the user wants to get raw accumulators).
+template <typename Spec,
+ bool IsApplicable =
+ std::is_same<typename Spec::AccumScalar, std::int32_t>::value &&
+ !std::is_same<typename Spec::DstScalar, std::int32_t>::value>
+struct ApplyMultiplierImpl {};
+
+// Specialization in non-applicable case: do nothing, just check that values
+// are default.
+template <typename Spec>
+struct ApplyMultiplierImpl<Spec, false> {
+ using AccumScalar = typename Spec::AccumScalar;
+ using DstScalar = typename Spec::DstScalar;
+ static void Run(const Spec& spec, int row, AccumScalar* accum) {
+ RUY_DCHECK_EQ(spec.multiplier_fixedpoint, 0);
+ RUY_DCHECK_EQ(spec.multiplier_exponent, 0);
+ }
+};
+
+template <typename Spec>
+struct ApplyMultiplierImpl<Spec, true> {
+ using AccumScalar = typename Spec::AccumScalar;
+ using DstScalar = typename Spec::DstScalar;
+ static void Run(const Spec& spec, int row, AccumScalar* accum) {
+ AccumScalar m = spec.multiplier_fixedpoint_perchannel
+ ? spec.multiplier_fixedpoint_perchannel[row]
+ : spec.multiplier_fixedpoint;
+ int e = spec.multiplier_exponent_perchannel
+ ? spec.multiplier_exponent_perchannel[row]
+ : spec.multiplier_exponent;
+ *accum = MultiplyByQuantizedMultiplier(*accum, m, e);
+ }
+};
+
+template <typename Spec>
+void ApplyMultiplier(const Spec& spec, int row,
+ typename Spec::AccumScalar* accum) {
+ ApplyMultiplierImpl<Spec>::Run(spec, row, accum);
+}
+
+template <typename LhsScalar, typename RhsScalar, typename DstScalar,
+ typename Spec>
+struct Kernel<Path::kStandardCpp, LhsScalar, RhsScalar, DstScalar, Spec> {
+ using AccumScalar = typename Spec::AccumScalar;
+ using LhsLayout = typename Spec::StandardCppKernelLhsLayout;
+ using RhsLayout = typename Spec::StandardCppKernelRhsLayout;
+ explicit Kernel(Tuning) {}
+ void Run(const PackedMatrix<LhsScalar>& lhs,
+ const PackedMatrix<RhsScalar>& rhs, const Spec& spec, int start_row,
+ int start_col, int end_row, int end_col,
+ Matrix<DstScalar>* dst) const {
+ // See the comment in RunKernelTyped. end_row may be larger than
+ // dst->layout.rows. It's the responsibility of the kernel to avoid
+ // overrunning dst boundaries, which we do here by computing
+ // clamped_end_row.
+ int clamped_end_row = std::min(end_row, dst->layout.rows);
+ int clamped_end_col = std::min(end_col, dst->layout.cols);
+ RUY_DCHECK_LE(0, start_row);
+ RUY_DCHECK_LE(start_row, clamped_end_row);
+ RUY_DCHECK_LE(clamped_end_row, dst->layout.rows);
+ RUY_DCHECK_LE(clamped_end_row, end_row);
+ RUY_DCHECK_LE(end_row - clamped_end_row, LhsLayout::kCols);
+ RUY_DCHECK_LE(0, start_col);
+ RUY_DCHECK_LE(start_col, clamped_end_col);
+ RUY_DCHECK_LE(clamped_end_col, dst->layout.cols);
+ RUY_DCHECK_LE(clamped_end_col, end_col);
+ RUY_DCHECK_LE(end_col - clamped_end_col, RhsLayout::kCols);
+ profiler::ScopeLabel label("Kernel (Standard Cpp)");
+ const int depth = lhs.layout.rows;
+ for (int i = start_row; i < clamped_end_row; i++) {
+ for (int j = start_col; j < clamped_end_col; j++) {
+ using AccumScalar = typename Spec::AccumScalar;
+ AccumScalar accum = 0;
+ for (int k = 0; k < depth; k++) {
+ AccumScalar lhs_val = Element(lhs, k, i);
+ AccumScalar rhs_val = Element(rhs, k, j);
+ accum += lhs_val * rhs_val;
+ }
+ if (spec.bias) {
+ accum += spec.bias[i];
+ }
+ if (lhs.zero_point) {
+ accum -= lhs.zero_point * rhs.sums[j];
+ }
+ if (rhs.zero_point) {
+ accum -= rhs.zero_point * lhs.sums[i];
+ }
+ if (lhs.zero_point && rhs.zero_point) {
+ accum += lhs.zero_point * rhs.zero_point * depth;
+ }
+ ApplyMultiplier(spec, i, &accum);
+ accum += dst->zero_point;
+ accum = std::min<AccumScalar>(accum, spec.clamp_max);
+ accum = std::max<AccumScalar>(accum, spec.clamp_min);
+ *ElementPtr(dst, i, j) = static_cast<DstScalar>(accum);
+ }
+ }
+ }
+};
+
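+// RUY_INHERIT_KERNEL(PARENT, CHILD) makes the kernels for path CHILD fall
+// back to the kernels for path PARENT, for any combination of operand and
+// destination types that does not have a more specialized Kernel for CHILD.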
+#define RUY_INHERIT_KERNEL(PARENT, CHILD) \
+ template <typename LhsScalar, typename RhsScalar, typename DstScalar, \
+ typename Spec> \
+ struct Kernel<CHILD, LhsScalar, RhsScalar, DstScalar, Spec> \
+ : Kernel<PARENT, LhsScalar, RhsScalar, DstScalar, Spec> { \
+ explicit Kernel(Tuning tuning) \
+ : Kernel<PARENT, LhsScalar, RhsScalar, DstScalar, Spec>(tuning) {} \
+ };
+
+#if RUY_PLATFORM(NEON)
+RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon)
+RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod)
+#elif RUY_PLATFORM(X86)
+RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kSse42)
+RUY_INHERIT_KERNEL(Path::kSse42, Path::kAvx2)
+RUY_INHERIT_KERNEL(Path::kAvx2, Path::kAvx512)
+RUY_INHERIT_KERNEL(Path::kAvx512, Path::kAvxVnni)
+#endif
+
+// KernelParams structs are shared across 32-bit and 64-bit NEON code as well
+// as x86 code.
+//
+// In all other cases (other platforms, or builds with inline asm disabled),
+// we still define (empty) versions, so that dummy kernels can use the
+// classes in function signatures.
+#if ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \
+ RUY_OPT_ENABLED(RUY_OPT_ASM)) || \
+ RUY_PLATFORM(X86)
+
+#define RUY_ASM_FLAG_HAS_BIAS 0x1
+#define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2
+#define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4
+#define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8
+#define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10
+
+#define RUY_ASM_TYPE_ID_UINT8 1
+#define RUY_ASM_TYPE_ID_INT8 2
+#define RUY_ASM_TYPE_ID_INT16 3
+#define RUY_ASM_TYPE_ID_INT32 4
+
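+// Maps a destination scalar type to the RUY_ASM_TYPE_ID_* tag that is passed
+// to the assembly/intrinsics kernels via KernelParams8bit::dst_type_id.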
+template <typename DstScalar>
+struct DstTypeId {};
+
+template <>
+struct DstTypeId<std::uint8_t> {
+ static constexpr int kValue = RUY_ASM_TYPE_ID_UINT8;
+};
+
+template <>
+struct DstTypeId<std::int8_t> {
+ static constexpr int kValue = RUY_ASM_TYPE_ID_INT8;
+};
+
+template <>
+struct DstTypeId<std::int16_t> {
+ static constexpr int kValue = RUY_ASM_TYPE_ID_INT16;
+};
+
+template <>
+struct DstTypeId<std::int32_t> {
+ static constexpr int kValue = RUY_ASM_TYPE_ID_INT32;
+};
+
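+// Plain-old-data parameter block consumed by the 8-bit quantized
+// assembly/intrinsics kernels. Filled in by MakeKernelParams8bit() below.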
+template <int LhsCols, int RhsCols>
+struct KernelParams8bit {
+ static constexpr int kMaxDstTypeSize = 4;
+
+ const std::int32_t* bias;
+ const std::int32_t* lhs_sums;
+ const std::int32_t* rhs_sums;
+ const std::int8_t* lhs_base_ptr;
+ const std::int32_t* multiplier_fixedpoint;
+ const std::int32_t* multiplier_exponent;
+ const std::int8_t* rhs_base_ptr;
+ void* dst_base_ptr;
+ std::int32_t lhs_zero_point;
+ std::int32_t rhs_zero_point;
+ std::int32_t dst_zero_point;
+ std::int32_t prod_zp_depth;
+ std::int32_t start_row;
+ std::int32_t start_col;
+ std::int32_t last_row;
+ std::int32_t last_col;
+ std::int32_t dst_rows;
+ std::int32_t dst_cols;
+ std::int32_t lhs_stride;
+ std::int32_t rhs_stride;
+ std::int32_t dst_stride;
+ std::int32_t depth;
+ std::int32_t clamp_min;
+ std::int32_t clamp_max;
+ std::uint8_t flags;
+ std::uint8_t dst_type_id;
+ const std::int32_t zero_data[LhsCols] = {0};
+ std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize];
+ std::int32_t multiplier_fixedpoint_buf[LhsCols];
+ std::int32_t multiplier_exponent_buf[LhsCols];
+};
+
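+// Fills a KernelParams8bit from the packed operands, the spec, the
+// destination matrix, and the block bounds given by
+// start_row/start_col/end_row/end_col.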
+template <typename DstScalar, int LhsCols, int RhsCols>
+void MakeKernelParams8bit(const PackedMatrix<std::int8_t>& lhs,
+ const PackedMatrix<std::int8_t>& rhs,
+ const BasicSpec<std::int32_t, DstScalar>& spec,
+ int start_row, int start_col, int end_row,
+ int end_col, Matrix<DstScalar>* dst,
+ KernelParams8bit<LhsCols, RhsCols>* params) {
+ using Params = KernelParams8bit<LhsCols, RhsCols>;
+
+ static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, "");
+
+ const int depth = lhs.layout.rows;
+ RUY_DCHECK_EQ(start_row % LhsCols, 0);
+ RUY_DCHECK_EQ(start_col % RhsCols, 0);
+ RUY_DCHECK_EQ(end_row % LhsCols, 0);
+ RUY_DCHECK_EQ(end_col % RhsCols, 0);
+
+ params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
+ params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
+ params->flags = 0;
+ params->bias = params->zero_data;
+ if (spec.bias) {
+ params->bias = spec.bias;
+ params->flags |= RUY_ASM_FLAG_HAS_BIAS;
+ }
+ if (lhs.sums) {
+ params->lhs_sums = lhs.sums;
+ params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS;
+ }
+ if (rhs.sums) {
+ params->rhs_sums = rhs.sums;
+ params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS;
+ }
+ params->start_row = start_row;
+ params->start_col = start_col;
+ params->last_row = end_row - LhsCols;
+ params->last_col = end_col - RhsCols;
+ params->lhs_stride = lhs.layout.stride;
+ params->rhs_stride = rhs.layout.stride;
+ params->dst_stride = sizeof(DstScalar) * dst->layout.stride;
+ params->lhs_zero_point = lhs.zero_point;
+ params->rhs_zero_point = rhs.zero_point;
+ params->dst_zero_point = dst->zero_point;
+ params->depth = depth;
+ params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth;
+ if (spec.multiplier_fixedpoint_perchannel) {
+ params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
+ params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL;
+ params->multiplier_fixedpoint = spec.multiplier_fixedpoint_perchannel;
+ params->multiplier_exponent = spec.multiplier_exponent_perchannel;
+ } else {
+ if (spec.multiplier_exponent > 0) {
+ params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
+ }
+ params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf;
+ params->multiplier_exponent = params->multiplier_exponent_buf;
+ for (int i = 0; i < LhsCols; i++) {
+ params->multiplier_fixedpoint_buf[i] = spec.multiplier_fixedpoint;
+ params->multiplier_exponent_buf[i] = spec.multiplier_exponent;
+ }
+ }
+ params->clamp_min = spec.clamp_min;
+ params->clamp_max = spec.clamp_max;
+ params->dst_rows = dst->layout.rows;
+ params->dst_cols = dst->layout.cols;
+
+ RUY_DCHECK_LT(params->last_row, params->dst_rows);
+ RUY_DCHECK_LT(params->last_col, params->dst_cols);
+
+ params->dst_type_id = DstTypeId<DstScalar>::kValue;
+ params->dst_base_ptr =
+ dst->data.get() + start_col * dst->layout.stride + start_row;
+}
+
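+// Plain-old-data parameter block consumed by the float assembly/intrinsics
+// kernels. Filled in by MakeKernelParamsFloat() below.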
+template <int LhsCols, int RhsCols>
+struct KernelParamsFloat {
+ const float* lhs_base_ptr;
+ const float* rhs_base_ptr;
+ float* dst_base_ptr;
+ const float* bias;
+ std::int32_t start_row;
+ std::int32_t start_col;
+ std::int32_t last_row;
+ std::int32_t last_col;
+ std::int32_t dst_rows;
+ std::int32_t dst_cols;
+ std::int32_t lhs_stride;
+ std::int32_t rhs_stride;
+ std::int32_t dst_stride;
+ std::int32_t depth;
+ float clamp_min;
+ float clamp_max;
+ std::uint8_t flags;
+ const float zero_data[LhsCols] = {0};
+ float dst_tmp_buf[LhsCols * RhsCols];
+};
+
+template <int LhsCols, int RhsCols>
+inline void MakeKernelParamsFloat(const PackedMatrix<float>& lhs,
+ const PackedMatrix<float>& rhs,
+ const BasicSpec<float, float>& spec,
+ int start_row, int start_col, int end_row,
+ int end_col, Matrix<float>* dst,
+ KernelParamsFloat<LhsCols, RhsCols>* params) {
+ const int depth = lhs.layout.rows;
+ RUY_DCHECK_EQ(start_row % LhsCols, 0);
+ RUY_DCHECK_EQ(start_col % RhsCols, 0);
+ RUY_DCHECK_EQ(end_row % LhsCols, 0);
+ RUY_DCHECK_EQ(end_col % RhsCols, 0);
+
+ params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
+ params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
+ params->dst_base_ptr =
+ dst->data.get() + start_col * dst->layout.stride + start_row;
+
+ std::uint8_t flags = 0;
+ params->bias = params->zero_data;
+ if (spec.bias) {
+ params->bias = spec.bias;
+ flags |= RUY_ASM_FLAG_HAS_BIAS;
+ }
+ params->flags = flags;
+ params->start_row = start_row;
+ params->start_col = start_col;
+ params->last_row = end_row - LhsCols;
+ params->last_col = end_col - RhsCols;
+ params->lhs_stride = sizeof(float) * lhs.layout.stride;
+ params->rhs_stride = sizeof(float) * rhs.layout.stride;
+ params->dst_stride = sizeof(float) * dst->layout.stride;
+ params->depth = depth;
+ params->clamp_min = spec.clamp_min;
+ params->clamp_max = spec.clamp_max;
+ params->dst_rows = dst->layout.rows;
+ params->dst_cols = dst->layout.cols;
+
+ RUY_DCHECK_LT(params->last_row, params->dst_rows);
+ RUY_DCHECK_LT(params->last_col, params->dst_cols);
+}
+
+#else // ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) &&
+ // RUY_OPT_ENABLED(RUY_OPT_ASM)) || RUY_PLATFORM(X86)
+
+template <int LhsCols, int RhsCols>
+struct KernelParams8bit {};
+
+template <int LhsCols, int RhsCols>
+struct KernelParamsFloat {};
+
+#endif // ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) &&
+ // RUY_OPT_ENABLED(RUY_OPT_ASM)) || RUY_PLATFORM(X86)
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_
diff --git a/ruy/kernel_sse42.cc b/ruy/kernel_sse42.cc
new file mode 100644
index 0000000..747ca1c
--- /dev/null
+++ b/ruy/kernel_sse42.cc
@@ -0,0 +1,428 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cstdint>
+
+#include "ruy/check_macros.h"
+#include "ruy/kernel.h"
+#include "ruy/opt_set.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+#include <immintrin.h> // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+void Kernel8bitSse42(const KernelParams8bit<8, 8>& params) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void KernelFloatSse42(const KernelParamsFloat<8, 8>& params) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+static constexpr int kAvxFloatBlockSize = 8;
+static constexpr int kAvx8bitBlockSize = 8;
+static constexpr int kAvx8bitInnerSize = 4;
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+// When removing this comment, update profiling label below.
+void Kernel8bitSse42(const KernelParams8bit<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel kSse42 8-bit (UNFINISHED)");
+ std::int32_t accum_data[kAvx8bitBlockSize][kAvx8bitBlockSize];
+
+ int bias_ptr_block_increment =
+ params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;
+
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ const std::int32_t* bias_col_ptr = params.bias;
+ if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+ bias_col_ptr += params.start_row;
+ }
+
+ for (int col = params.start_col; col <= params.last_col;
+ col += kAvx8bitBlockSize) {
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ const std::int32_t* bias_ptr = bias_col_ptr;
+
+ for (int row = params.start_row; row <= params.last_row;
+ row += kAvx8bitBlockSize) {
+ const int residual_rows =
+ std::min(params.dst_rows - row, kAvx8bitBlockSize);
+ const int residual_cols =
+ std::min(params.dst_cols - col, kAvx8bitBlockSize);
+
+ // Initialize with bias.
+ std::int32_t initial_accum_data[kAvx8bitBlockSize];
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ initial_accum_data[i] = 0;
+ }
+ for (int i = 0; i < residual_rows; ++i) {
+ initial_accum_data[i] = bias_ptr[i];
+ }
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] = initial_accum_data[i];
+ }
+ }
+ bias_ptr += bias_ptr_block_increment;
+
+ std::int8_t lhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize];
+ std::int8_t rhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize];
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ for (int x = 0; x < kAvx8bitInnerSize; ++x) {
+ lhs_data[i][x] = lhs_ptr[i * kAvx8bitInnerSize + x];
+ rhs_data[i][x] = rhs_ptr[i * kAvx8bitInnerSize + x];
+ }
+ }
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ for (int x = 0; x < kAvx8bitInnerSize; ++x) {
+ accum_data[j][i] += lhs_data[i][x] * rhs_data[j][x];
+ }
+ }
+ }
+ lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ }
+
+ if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && params.rhs_zero_point) {
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] -=
+ params.rhs_zero_point * params.lhs_sums[row + i];
+ }
+ }
+ }
+ if ((params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && params.lhs_zero_point) {
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] -=
+ params.lhs_zero_point * params.rhs_sums[col + j];
+ }
+ }
+ }
+ if (params.lhs_zero_point && params.rhs_zero_point) {
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] += params.prod_zp_depth;
+ }
+ }
+ }
+
+ if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+ std::int32_t m_vector[kAvx8bitBlockSize];
+ std::int32_t e_vector[kAvx8bitBlockSize];
+ // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+ if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) {
+ int i = 0;
+ for (; i < residual_rows; ++i) {
+ m_vector[i] = params.multiplier_fixedpoint[row + i];
+ e_vector[i] = params.multiplier_exponent[row + i];
+ }
+ for (; i < kAvx8bitBlockSize; ++i) {
+ m_vector[i] = m_vector[0];
+ e_vector[i] = e_vector[0];
+ }
+ } else {
+ // These arrays have size LhsCols, and are pre-filled.
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ m_vector[i] = params.multiplier_fixedpoint[i];
+ e_vector[i] = params.multiplier_exponent[i];
+ }
+ }
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] = MultiplyByQuantizedMultiplier(
+ accum_data[j][i], m_vector[i], e_vector[i]);
+ }
+ }
+
+ if (params.dst_zero_point) {
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] += params.dst_zero_point;
+ }
+ }
+ }
+
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] =
+ std::min<std::int32_t>(accum_data[j][i], params.clamp_max);
+ accum_data[j][i] =
+ std::max<std::int32_t>(accum_data[j][i], params.clamp_min);
+ }
+ }
+ }
+
+ const bool store_full_block = (residual_rows == kAvx8bitBlockSize) &&
+ (residual_cols == kAvx8bitBlockSize);
+
+ if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+ std::int8_t* tmp_ptr =
+ store_full_block
+ ? static_cast<std::int8_t*>(dst_ptr)
+ : const_cast<std::int8_t*>(
+ reinterpret_cast<const std::int8_t*>(params.dst_tmp_buf));
+ const int block_col_offset =
+ store_full_block ? params.dst_stride / sizeof(std::int8_t)
+ : kAvx8bitBlockSize;
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ tmp_ptr[i] = accum_data[j][i];
+ }
+ tmp_ptr += block_col_offset;
+ }
+
+ if (!store_full_block) {
+ const std::int8_t* block_ptr =
+ reinterpret_cast<const std::int8_t*>(params.dst_tmp_buf);
+ for (int j = 0; j < residual_cols; ++j) {
+ for (int i = 0; i < residual_rows; ++i) {
+ static_cast<std::int8_t*>(
+ dst_ptr)[j * params.dst_stride / sizeof(std::int8_t) + i] =
+ block_ptr[i];
+ }
+ block_ptr += kAvx8bitBlockSize;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+ std::uint8_t* tmp_ptr = store_full_block
+ ? static_cast<std::uint8_t*>(dst_ptr)
+ : const_cast<std::uint8_t*>(
+ reinterpret_cast<const std::uint8_t*>(
+ params.dst_tmp_buf));
+ const int block_col_offset =
+ store_full_block ? params.dst_stride : kAvx8bitBlockSize;
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ tmp_ptr[i] = accum_data[j][i];
+ }
+ tmp_ptr += block_col_offset;
+ }
+
+ if (!store_full_block) {
+ const std::uint8_t* block_ptr =
+ reinterpret_cast<const std::uint8_t*>(params.dst_tmp_buf);
+ for (int j = 0; j < residual_cols; ++j) {
+ for (int i = 0; i < residual_rows; ++i) {
+ static_cast<std::uint8_t*>(
+ dst_ptr)[j * params.dst_stride / sizeof(std::uint8_t) + i] =
+ block_ptr[i];
+ }
+ block_ptr += kAvx8bitBlockSize;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ if (store_full_block) {
+ std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+ const int block_col_offset = params.dst_stride / sizeof(std::int16_t);
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ tmp_ptr[i] = accum_data[j][i];
+ }
+ tmp_ptr += block_col_offset;
+ }
+ } else {
+ std::int16_t* tmp_ptr = const_cast<std::int16_t*>(
+ reinterpret_cast<const std::int16_t*>(params.dst_tmp_buf));
+ const int block_col_offset = kAvx8bitBlockSize;
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ tmp_ptr[i] = accum_data[j][i];
+ }
+ tmp_ptr += block_col_offset;
+ }
+ const std::int16_t* block_ptr =
+ reinterpret_cast<const std::int16_t*>(params.dst_tmp_buf);
+ std::int16_t* dst_block_ptr = static_cast<std::int16_t*>(dst_ptr);
+ for (int j = 0; j < residual_cols; ++j) {
+ for (int i = 0; i < residual_rows; ++i) {
+ dst_block_ptr[i] = block_ptr[i];
+ }
+ dst_block_ptr += params.dst_stride / sizeof(std::int16_t);
+ block_ptr += kAvx8bitBlockSize;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ if (store_full_block) {
+ std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+ const int block_col_offset = params.dst_stride / sizeof(std::int32_t);
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ tmp_ptr[i] = accum_data[j][i];
+ }
+ tmp_ptr += block_col_offset;
+ }
+ } else {
+ std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
+ for (int j = 0; j < residual_cols; ++j) {
+ for (int i = 0; i < residual_rows; ++i) {
+ dst_block_ptr[i] = accum_data[j][i];
+ }
+ dst_block_ptr += params.dst_stride / sizeof(std::int32_t);
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
+ } // End row-block loop.
+
+ dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
+ kAvx8bitBlockSize * params.dst_stride);
+ rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
+ } // End col-block loop.
+} // NOLINT(readability/fn_size)
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+// When removing this comment, update profiling label below.
+void KernelFloatSse42(const KernelParamsFloat<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel kSse42 float (UNFINISHED)");
+
+ float lhs_data[kAvxFloatBlockSize];
+ float rhs_data[kAvxFloatBlockSize];
+ float accum_data[kAvxFloatBlockSize][kAvxFloatBlockSize];
+ int bias_ptr_block_increment =
+ params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvxFloatBlockSize : 0;
+
+ const float* rhs_col_ptr = params.rhs_base_ptr;
+ float* dst_col_ptr = params.dst_base_ptr;
+ const float* bias_col_ptr = params.bias;
+ if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+ bias_col_ptr += params.start_row;
+ }
+
+ for (int col = params.start_col; col <= params.last_col;
+ col += kAvxFloatBlockSize) {
+ const float* lhs_col_ptr = params.lhs_base_ptr;
+ float* dst_ptr = dst_col_ptr;
+ const float* bias_ptr = bias_col_ptr;
+
+ for (int row = params.start_row; row <= params.last_row;
+ row += kAvxFloatBlockSize) {
+ const int residual_rows =
+ std::min(params.dst_rows - row, kAvxFloatBlockSize);
+ const int residual_cols =
+ std::min(params.dst_cols - col, kAvxFloatBlockSize);
+
+ // Initialize with bias.
+ float initial_accum_data[kAvxFloatBlockSize];
+ for (int i = 0; i < kAvxFloatBlockSize; ++i) {
+ initial_accum_data[i] = 0.0f;
+ }
+ for (int i = 0; i < residual_rows; ++i) {
+ initial_accum_data[i] = bias_ptr[i];
+ }
+ for (int j = 0; j < kAvxFloatBlockSize; ++j) {
+ for (int i = 0; i < kAvxFloatBlockSize; ++i) {
+ accum_data[j][i] = initial_accum_data[i];
+ }
+ }
+ bias_ptr += bias_ptr_block_increment;
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
+ for (int d = 0; d < params.depth; ++d) {
+ for (int i = 0; i < kAvxFloatBlockSize; ++i) {
+ lhs_data[i] = lhs_ptr[i];
+ rhs_data[i] = rhs_ptr[i];
+ }
+ for (int j = 0; j < kAvxFloatBlockSize; ++j) {
+ for (int i = 0; i < kAvxFloatBlockSize; ++i) {
+ accum_data[j][i] += lhs_data[i] * rhs_data[j];
+ }
+ }
+ lhs_ptr += kAvxFloatBlockSize;
+ rhs_ptr += kAvxFloatBlockSize;
+ }
+
+ for (int j = 0; j < kAvxFloatBlockSize; ++j) {
+ for (int i = 0; i < kAvxFloatBlockSize; ++i) {
+ accum_data[j][i] =
+ std::min<float>(accum_data[j][i], params.clamp_max);
+ accum_data[j][i] =
+ std::max<float>(accum_data[j][i], params.clamp_min);
+ }
+ }
+
+ const bool store_full_block = (residual_rows == kAvxFloatBlockSize) &&
+ (residual_cols == kAvxFloatBlockSize);
+
+ {
+ float* block_ptr =
+ store_full_block ? dst_ptr : const_cast<float*>(params.dst_tmp_buf);
+ const int block_col_offset = store_full_block
+ ? params.dst_stride / sizeof(float)
+ : kAvxFloatBlockSize;
+ for (int j = 0; j < kAvxFloatBlockSize; ++j) {
+ for (int i = 0; i < kAvxFloatBlockSize; ++i) {
+ block_ptr[i] = accum_data[j][i];
+ }
+ block_ptr += block_col_offset;
+ }
+ }
+ if (!store_full_block) {
+ const float* block_ptr = params.dst_tmp_buf;
+ for (int j = 0; j < residual_cols; ++j) {
+ for (int i = 0; i < residual_rows; ++i) {
+ dst_ptr[j * params.dst_stride / sizeof(float) + i] = block_ptr[i];
+ }
+ block_ptr += kAvxFloatBlockSize;
+ }
+ }
+
+ lhs_col_ptr += kAvxFloatBlockSize * params.lhs_stride / sizeof(float);
+ dst_ptr += kAvxFloatBlockSize;
+ } // End row-block loop.
+
+ dst_col_ptr += kAvxFloatBlockSize * params.dst_stride / sizeof(float);
+ rhs_col_ptr += kAvxFloatBlockSize * params.rhs_stride / sizeof(float);
+ } // End col-block loop.
+}
+
+#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+} // namespace ruy
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h
new file mode 100644
index 0000000..dbcf42b
--- /dev/null
+++ b/ruy/kernel_x86.h
@@ -0,0 +1,222 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_X86_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_X86_H_
+
+#include <cstdint>
+
+#include "ruy/common.h"
+#include "ruy/internal_matrix.h"
+#include "ruy/kernel_common.h"
+#include "ruy/matrix.h"
+#include "ruy/opt_set.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/spec.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(X86)
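+// Each Kernel specialization below pairs a Path with operand and destination
+// types. Its Run() method packs its arguments into a KernelParams struct and
+// forwards to the corresponding out-of-line kernel function implemented in
+// the per-path kernel_*.cc files.
+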
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+void Kernel8bitSse42(const KernelParams8bit<8, 8>& params);
+
+template <typename DstScalar>
+struct Kernel<Path::kSse42, std::int8_t, std::int8_t, DstScalar,
+ BasicSpec<std::int32_t, DstScalar>> {
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+ using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PackedMatrix<std::int8_t>& lhs,
+ const PackedMatrix<std::int8_t>& rhs,
+ const BasicSpec<std::int32_t, DstScalar>& spec, int start_row,
+ int start_col, int end_row, int end_col,
+ Matrix<DstScalar>* dst) const {
+ KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col,
+ dst, &params);
+ Kernel8bitSse42(params);
+ }
+};
+
+void KernelFloatSse42(const KernelParamsFloat<8, 8>& params);
+
+template <>
+struct Kernel<Path::kSse42, float, float, float, BasicSpec<float, float>> {
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+ const BasicSpec<float, float>& spec, int start_row, int start_col,
+ int end_row, int end_col, Matrix<float>* dst) const {
+ KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
+ end_col, dst, &params);
+ KernelFloatSse42(params);
+ }
+};
+
+void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params);
+void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params);
+
+template <typename DstScalar>
+struct Kernel<Path::kAvx512, std::int8_t, std::int8_t, DstScalar,
+ BasicSpec<std::int32_t, DstScalar>> {
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+ using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PackedMatrix<std::int8_t>& lhs,
+ const PackedMatrix<std::int8_t>& rhs,
+ const BasicSpec<std::int32_t, DstScalar>& spec, int start_row,
+ int start_col, int end_row, int end_col,
+ Matrix<DstScalar>* dst) const {
+ KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col,
+ dst, &params);
+ if (dst->layout.cols == 1) {
+ Kernel8bitAvx512SingleCol(params);
+ } else {
+ Kernel8bitAvx512(params);
+ }
+ }
+};
+
+void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params);
+void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params);
+
+template <>
+struct Kernel<Path::kAvx512, float, float, float, BasicSpec<float, float>> {
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
+ using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+ const BasicSpec<float, float>& spec, int start_row, int start_col,
+ int end_row, int end_col, Matrix<float>* dst) const {
+ KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
+ end_col, dst, &params);
+ if (dst->layout.cols == 1) {
+ KernelFloatAvx512SingleCol(params);
+ } else {
+ KernelFloatAvx512(params);
+ }
+ }
+};
+
+void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params);
+void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params);
+
+template <typename DstScalar>
+struct Kernel<Path::kAvx2, std::int8_t, std::int8_t, DstScalar,
+ BasicSpec<std::int32_t, DstScalar>> {
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+ using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PackedMatrix<std::int8_t>& lhs,
+ const PackedMatrix<std::int8_t>& rhs,
+ const BasicSpec<std::int32_t, DstScalar>& spec, int start_row,
+ int start_col, int end_row, int end_col,
+ Matrix<DstScalar>* dst) const {
+ KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col,
+ dst, &params);
+ if (dst->layout.cols == 1) {
+ Kernel8bitAvx2SingleCol(params);
+ } else {
+ Kernel8bitAvx2(params);
+ }
+ }
+};
+
+void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params);
+void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params);
+
+template <>
+struct Kernel<Path::kAvx2, float, float, float, BasicSpec<float, float>> {
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+ const BasicSpec<float, float>& spec, int start_row, int start_col,
+ int end_row, int end_col, Matrix<float>* dst) const {
+ KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
+ end_col, dst, &params);
+ if (dst->layout.cols == 1) {
+ KernelFloatAvx2SingleCol(params);
+ } else {
+ KernelFloatAvx2(params);
+ }
+ }
+};
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+void Kernel8bitAvxVnni(const KernelParams8bit<16, 16>& params);
+
+template <typename DstScalar>
+struct Kernel<Path::kAvxVnni, std::int8_t, std::int8_t, DstScalar,
+ BasicSpec<std::int32_t, DstScalar>> {
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+ using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PackedMatrix<std::int8_t>& lhs,
+ const PackedMatrix<std::int8_t>& rhs,
+ const BasicSpec<std::int32_t, DstScalar>& spec, int start_row,
+ int start_col, int end_row, int end_col,
+ Matrix<DstScalar>* dst) const {
+ KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col,
+ dst, &params);
+ Kernel8bitAvxVnni(params);
+ }
+};
+
+void KernelFloatAvxVnni(const KernelParamsFloat<16, 16>& params);
+
+template <>
+struct Kernel<Path::kAvxVnni, float, float, float, BasicSpec<float, float>> {
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
+ using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+ const BasicSpec<float, float>& spec, int start_row, int start_col,
+ int end_row, int end_col, Matrix<float>* dst) const {
+ KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
+ end_col, dst, &params);
+ KernelFloatAvxVnni(params);
+ }
+};
+
+#endif // RUY_PLATFORM(X86)
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_X86_H_
diff --git a/ruy/matrix.h b/ruy/matrix.h
new file mode 100644
index 0000000..2dcb081
--- /dev/null
+++ b/ruy/matrix.h
@@ -0,0 +1,182 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_MATRIX_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_MATRIX_H_
+
+#include <cstddef>
+#include <cstdint> // IWYU pragma: keep
+#include <type_traits>
+
+#include "ruy/check_macros.h"
+
+namespace ruy {
+
+// Layout storage order. Here and elsewhere, 'col' is short for 'column'.
+// 'column-major' means that each column is contiguous in memory.
+enum class Order : std::uint8_t { kColMajor, kRowMajor };
+
+// Describes the shape and storage layout of a matrix.
+struct Layout final {
+ std::int32_t rows = 0;
+ std::int32_t cols = 0;
+ // Stride is the offset between two adjacent matrix elements
+ // in the non-contiguous direction.
+ std::int32_t stride = 0;
+ Order order = Order::kColMajor;
+};
+
+namespace detail {
+
+// Thin wrapper around a pointer that tracks its constness dynamically.
+//
+// This is our take on the C++ problem of enforcing constness of data
+// wrapped in a container class: it's not worth the hassle of trying to
+// make it fully work at compile-time.
+// Instead, we only enforce constness at runtime, and to make it
+// zero-overhead, we only enforce it in debug builds.
+template <typename T>
+class ConstCheckingPtr final {
+ public:
+ using element_type = T;
+
+ // Convenience methods. Most `set` calls go through these.
+ ConstCheckingPtr& operator=(T* ptr) {
+ set(ptr);
+ return *this;
+ }
+ ConstCheckingPtr& operator=(const T* ptr) {
+ set(ptr);
+ return *this;
+ }
+ ConstCheckingPtr& operator=(std::nullptr_t) {
+ set(static_cast<T*>(nullptr));
+ return *this;
+ }
+
+ // Core accessors. These encapsulate the main logic:
+ // - for `set`, the constness of the argument determines whether internal
+ // pointer should be tracked as const/mutable.
+ // - for `get`, the constness of `this` determines whether the call
+ // counts as a const or mutable use of the internal pointer.
+ void set(T* ptr) {
+ ptr_ = ptr;
+ set_mutable(true);
+ }
+ void set(const T* ptr) {
+ ptr_ = ptr;
+ set_mutable(false);
+ }
+ T* get() /* NOT const */ {
+ assert_mutable();
+ return const_cast<T*>(ptr_);
+ }
+ const T* get() const { return ptr_; }
+
+ private:
+ static_assert(!std::is_const<T>::value, "");
+ const T* ptr_ = nullptr;
+#ifndef NDEBUG
+ bool is_mutable_ = true;
+ void set_mutable(bool val) { is_mutable_ = val; }
+ void assert_mutable() { RUY_DCHECK(is_mutable_); }
+#else
+ void set_mutable(bool) {}
+ void assert_mutable() {}
+#endif
+};
+
+} // namespace detail
+
+// A Matrix is really what Eigen and gemmlowp would have called a 'matrix map':
+// it merely wraps existing data as a matrix. It doesn't own any buffer.
+// Scalar may be any floating-point or integral type. When integral, it may be
+// signed or unsigned.
+template <typename Scalar>
+struct Matrix final {
+ Matrix& operator=(const Matrix& other) {
+ data = other.data;
+ cacheable = other.cacheable;
+ layout = other.layout;
+ zero_point = other.zero_point;
+ return *this;
+ }
+
+ // The underlying buffer wrapped by this matrix.
+ detail::ConstCheckingPtr<Scalar> data;
+ // The shape and data layout of this matrix.
+ Layout layout;
+ // The zero_point, i.e. which Scalar value is to be interpreted as zero.
+ // When Scalar is floating-point, this must be 0.
+ Scalar zero_point = 0;
+ // Clients of Ruy must set this flag to enable any caching behavior. Doesn't
+ // impact numerical results, but caching can impact observable metrics like
+ // latency, memory usage, power, etc.
+ bool cacheable = false;
+};
+
+inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) {
+ layout->rows = rows;
+ layout->cols = cols;
+ layout->order = order;
+ layout->stride = order == Order::kColMajor ? rows : cols;
+}
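+
+// Minimal usage sketch of the types above (the buffer contents and
+// dimensions here are hypothetical, chosen only to illustrate this API):
+//
+//   float lhs_data[6] = {1, 2, 3, 4, 5, 6};
+//   ruy::Matrix<float> lhs;
+//   ruy::MakeSimpleLayout(2, 3, ruy::Order::kRowMajor, &lhs.layout);
+//   lhs.data = lhs_data;     // tracked as mutable data
+//   const float* read_only = lhs_data;
+//   lhs.data = read_only;    // tracked as const data: a subsequent mutable
+//                            // lhs.data.get() would RUY_DCHECK-fail in
+//                            // debug builds.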
+
+// Opaque data structure representing a pre-packed matrix, as obtained from
+// Ruy's advanced API.
+struct PrepackedMatrix {
+ void* data = nullptr;
+ std::size_t data_size = 0;
+ void* sums = nullptr;
+ std::size_t sums_size = 0;
+};
+
+template <typename StreamType, typename Scalar>
+StreamType& operator<<(StreamType& stream, const Matrix<Scalar>& mat) {
+ for (int row = 0; row < mat.layout.rows; row++) {
+ for (int col = 0; col < mat.layout.cols; col++) {
+ stream << static_cast<double>(Element(mat, row, col)) << " ";
+ }
+ stream << "\n";
+ }
+ return stream;
+}
+
+// Compile-time version of KernelLayout, used to declare kernel layouts in a
+// way that can be consumed by compile-time logic.
+// See how partial specializations of Kernel use it to declare their layouts.
+// The only reason why this is currently part of the public API is to
+// allow testing various layouts for the Path::kStandardCpp kernel, as a
+// testing-only feature. See Spec::StandardCppKernelLhsLayout.
+template <Order tOrder, int tRows, int tCols>
+struct FixedKernelLayout {
+ static constexpr Order kOrder = tOrder;
+ static constexpr int kRows = tRows;
+ static constexpr int kCols = tCols;
+};
+
+#if (__cplusplus < 201703L)
+// Before C++17, an ODR-used static constexpr data member needs an
+// out-of-class definition (without an initializer) such as the ones below;
+// Clang at -O0 can otherwise fail to link. From C++17 onwards, such members
+// are implicitly inline and this redeclaration is deprecated.
+template <Order tOrder, int tRows, int tCols>
+constexpr int FixedKernelLayout<tOrder, tRows, tCols>::kCols;
+template <Order tOrder, int tRows, int tCols>
+constexpr int FixedKernelLayout<tOrder, tRows, tCols>::kRows;
+#endif
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_MATRIX_H_
diff --git a/ruy/opt_set.h b/ruy/opt_set.h
new file mode 100644
index 0000000..fef0107
--- /dev/null
+++ b/ruy/opt_set.h
@@ -0,0 +1,51 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_OPT_SET_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_OPT_SET_H_
+
+// RUY_OPT_SET is a compile-time API that Ruy provides for enabling/disabling
+// certain optimizations. It should be used by defining that macro on the
+// compiler command line.
+//
+// Each bit in RUY_OPT_SET controls a particular optimization done in Ruy.
+#define RUY_OPT_INTRINSICS 0x1
+#define RUY_OPT_ASM 0x2
+#define RUY_OPT_TUNING 0x4
+#define RUY_OPT_FAT_KERNEL 0x8
+#define RUY_OPT_NATIVE_ROUNDING 0x10
+#define RUY_OPT_AVOID_ALIASING 0x20
+#define RUY_OPT_MAX_STREAMING 0x40
+#define RUY_OPT_PACK_AHEAD 0x80
+#define RUY_OPT_PREFETCH_LOAD 0x100
+#define RUY_OPT_PREFETCH_STORE 0x200
+#define RUY_OPT_FRACTAL_Z 0x400
+#define RUY_OPT_FRACTAL_U 0x800
+#define RUY_OPT_FRACTAL_HILBERT 0x1000
+
+#if !defined(RUY_OPT_SET)
+#ifdef RUY_OPTIMIZE_FOR_MATMUL_BENCHMARK
+// Load prefetching is detrimental in matrix multiplication benchmarks.
+// Store prefetching is not.
+#define RUY_OPT_SET (~RUY_OPT_PREFETCH_LOAD)
+#else
+// Default to all optimizations.
+#define RUY_OPT_SET (~0)
+#endif
+#endif
+
+#define RUY_OPT_ENABLED(ruy_opt) ((RUY_OPT_SET & ruy_opt) != 0)
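+
+// For example, building with -DRUY_OPT_SET=0x6 enables only RUY_OPT_ASM and
+// RUY_OPT_TUNING; building with -DRUY_OPT_SET=0 disables all of the above
+// optimizations.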
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_OPT_SET_H_
diff --git a/ruy/pack.h b/ruy/pack.h
new file mode 100644
index 0000000..e066663
--- /dev/null
+++ b/ruy/pack.h
@@ -0,0 +1,98 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// # What is "packing"?
+//
+// Before feeding data to the gemm kernels (the parts of Ruy that do lots
+// of multiply-add operations), Ruy first performs a data transformation (which
+// we call "packing") on the input matrices. This transformation has two main
+// goals:
+// - rearrange data into blocks that are a convenient size/layout for the gemm
+// kernels to consume. This helps make the memory access pattern of the gemm
+// kernel simpler and more contiguous, and puts the data in a layout most
+// convenient for specific arithmetic instructions in the gemm kernel.
+// - compute row/column sums needed for handling quantization with non-symmetric
+// zero points.
+//
+// # Simplified algorithmic analysis of packing
+//
+// Packing is a relatively simple transformation which does a small constant
+// amount of work on each element of an input matrix, and hence for an NxM
+// matrix performs O(N*M) work. If N and M are of the same order, then this is
+// O(N^2) work.
+//
+// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations.
+// Note that if N, K, and M are all the same order, then the number of
+// multiply-accumulate operations is O(N^3).
+//
+// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the
+// case of all dimensions being roughly the same order.
+//
+// # Packing cost can be significant
+//
+// When matrix * matrix multiplications begin to look more like matrix * vector
+// multiplications, packing cost can become significant. We sometimes call these
+// cases "gemv-like".
+//
+// Continuing the algorithmic analysis above, if we consider a case where an
+// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the
+// situation is different. In this case, the multiply-accumulate work is only
+// quadratic, so the quadratic cost of packing can become significant.
+//
+// Another way to say this is that the cost of packing an input matrix (either
+// the LHS or RHS) is amortized across the non-depth dimension of the opposite
+// input matrix. Thus, when the LHS has very few rows or the RHS has very few
+// columns, the cost of packing the opposite input matrix can become
+// significant.
+//
+// As a rough rule of thumb, the cost of packing starts to become significant
+// when either N or M is below 32 (and other dimensions are hundreds), with very
+// significant packing costs at 8 or below. This varies by data type, Path, and
+// tuning, so these numbers are only rough guides.
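+//
+// As a concrete, purely illustrative instance of the analysis above: with
+// N = 8, K = 512, M = 512, the multiplication performs N*K*M = 2,097,152
+// multiply-accumulate operations, while packing the RHS alone touches
+// K*M = 262,144 elements, already 1/8 of the multiply-accumulate count.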
+//
+// One practical use case that is affected by this is inference of
+// fully connected neural network layers with a low batch size. The weight
+// matrix (which is a constant for inference) is the one affected by significant
+// packing cost.
+//
+// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack
+// input matrices that are affected by significant packing costs.
+//
+// # Implementation notes
+//
+// Ruy's packing routines always operate on a range of columns and can be
+// applied to either the LHS or RHS. This is possible because Ruy internally
+// implements a TrMul, so the accumulation along depth is done along columns of
+// both the LHS and RHS (whereas for a normal Mul the accumulation along depth
+// for the LHS is along rows). As another example, we are always computing
+// column sums for quantization (and never row sums, since the LHS is
+// transposed).
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_H_
+
+#include "ruy/platform.h"
+
+// IWYU pragma: begin_exports
+#if RUY_PLATFORM(NEON)
+#include "ruy/pack_arm.h"
+#elif RUY_PLATFORM(X86)
+#include "ruy/pack_x86.h"
+#else
+#include "ruy/pack_common.h"
+#endif
+// IWYU pragma: end_exports
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_H_
diff --git a/ruy/pack_arm.cc b/ruy/pack_arm.cc
new file mode 100644
index 0000000..8b68a39
--- /dev/null
+++ b/ruy/pack_arm.cc
@@ -0,0 +1,1936 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstddef>
+#include <cstdint>
+
+#include "ruy/common.h"
+#include "ruy/opt_set.h"
+#include "ruy/pack.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+void Pack8bitNeonOutOfOrder(const void* src_ptr0, const void* src_ptr1,
+ const void* src_ptr2, const void* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows, int src_zero_point,
+ std::int8_t* packed_ptr, int start_col, int end_col,
+ std::int32_t* sums_ptr, int input_xor) {
+ profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)");
+ asm volatile(
+ // clang-format off
+ "dup v26.16b, %w[input_xor]\n"
+ "mov w1, #0\n"
+ "dup v28.4s, wzr\n"
+ "dup v29.4s, wzr\n"
+ "dup v30.4s, wzr\n"
+ "dup v31.4s, wzr\n"
+
+ "and w2, %w[rows], #-16\n"
+ "cmp w1, w2\n"
+ "beq 3f\n"
+
+ "add w1, w1, #16\n"
+ "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "cmp w1, w2\n"
+ "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "beq 2f\n"
+
+ "1:\n"
+
+ "add w1, w1, #16\n"
+ "eor v4.16b, v0.16b, v26.16b\n"
+ "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "eor v5.16b, v1.16b, v26.16b\n"
+ "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "eor v6.16b, v2.16b, v26.16b\n"
+ "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "eor v7.16b, v3.16b, v26.16b\n"
+ "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
+
+ "saddlp v16.8h, v4.16b\n"
+ "str q4, [%[packed_ptr], #0]\n"
+ "saddlp v17.8h, v5.16b\n"
+ "str q5, [%[packed_ptr], #16]\n"
+ "saddlp v18.8h, v6.16b\n"
+ "str q6, [%[packed_ptr], #32]\n"
+ "saddlp v19.8h, v7.16b\n"
+ "str q7, [%[packed_ptr], #48]\n"
+ "sadalp v28.4s, v16.8h\n"
+ "cmp w1, w2\n"
+ "sadalp v29.4s, v17.8h\n"
+ "add %[packed_ptr], %[packed_ptr], #64\n"
+ "sadalp v30.4s, v18.8h\n"
+ "sadalp v31.4s, v19.8h\n"
+
+ "bne 1b\n"
+
+ "2:\n"
+
+ "eor v4.16b, v0.16b, v26.16b\n"
+ "eor v5.16b, v1.16b, v26.16b\n"
+ "eor v6.16b, v2.16b, v26.16b\n"
+ "eor v7.16b, v3.16b, v26.16b\n"
+
+ "saddlp v16.8h, v4.16b\n"
+ "str q4, [%[packed_ptr], #0]\n"
+ "saddlp v17.8h, v5.16b\n"
+ "str q5, [%[packed_ptr], #16]\n"
+ "saddlp v18.8h, v6.16b\n"
+ "str q6, [%[packed_ptr], #32]\n"
+ "saddlp v19.8h, v7.16b\n"
+ "str q7, [%[packed_ptr], #48]\n"
+ "sadalp v28.4s, v16.8h\n"
+ "sadalp v29.4s, v17.8h\n"
+ "sadalp v30.4s, v18.8h\n"
+ "sadalp v31.4s, v19.8h\n"
+
+ "add %[packed_ptr], %[packed_ptr], #64\n"
+
+ "3:\n"
+
+ "ands w2, %w[rows], #15\n"
+ "beq 4f\n"
+ "dup v0.16b, %w[src_zero_point]\n"
+ "dup v1.16b, %w[src_zero_point]\n"
+ "dup v2.16b, %w[src_zero_point]\n"
+ "dup v3.16b, %w[src_zero_point]\n"
+#define RUY_LOAD_ONE_ROW(R) \
+ "cmp w2, #" #R "\n" \
+ "beq 5f\n" \
+ "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
+ "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
+ "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
+ "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
+
+ RUY_LOAD_ONE_ROW(0)
+ RUY_LOAD_ONE_ROW(1)
+ RUY_LOAD_ONE_ROW(2)
+ RUY_LOAD_ONE_ROW(3)
+ RUY_LOAD_ONE_ROW(4)
+ RUY_LOAD_ONE_ROW(5)
+ RUY_LOAD_ONE_ROW(6)
+ RUY_LOAD_ONE_ROW(7)
+ RUY_LOAD_ONE_ROW(8)
+ RUY_LOAD_ONE_ROW(9)
+ RUY_LOAD_ONE_ROW(10)
+ RUY_LOAD_ONE_ROW(11)
+ RUY_LOAD_ONE_ROW(12)
+ RUY_LOAD_ONE_ROW(13)
+ RUY_LOAD_ONE_ROW(14)
+ RUY_LOAD_ONE_ROW(15)
+#undef RUY_LOAD_ONE_ROW
+ "5:\n"
+
+ "eor v4.16b, v0.16b, v26.16b\n"
+ "eor v5.16b, v1.16b, v26.16b\n"
+ "eor v6.16b, v2.16b, v26.16b\n"
+ "eor v7.16b, v3.16b, v26.16b\n"
+
+ "saddlp v16.8h, v4.16b\n"
+ "saddlp v17.8h, v5.16b\n"
+ "saddlp v18.8h, v6.16b\n"
+ "saddlp v19.8h, v7.16b\n"
+ "sadalp v28.4s, v16.8h\n"
+ "sadalp v29.4s, v17.8h\n"
+ "sadalp v30.4s, v18.8h\n"
+ "sadalp v31.4s, v19.8h\n"
+
+ "str q4, [%[packed_ptr], #0]\n"
+ "str q5, [%[packed_ptr], #16]\n"
+ "str q6, [%[packed_ptr], #32]\n"
+ "str q7, [%[packed_ptr], #48]\n"
+ "add %[packed_ptr], %[packed_ptr], #64\n"
+
+ "4:\n"
+
+ "addp v28.4s, v28.4s, v29.4s\n"
+ "addp v30.4s, v30.4s, v31.4s\n"
+ "addp v28.4s, v28.4s, v30.4s\n"
+
+ "cmp %[sums_ptr], #0\n"
+ "beq 6f\n"
+ "st1 {v28.4s}, [%[sums_ptr]], #16\n"
+ "6:\n"
+ // clang-format on
+
+ : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
+ [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
+ [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr)
+ : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)),
+ [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
+ [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)),
+ [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
+ [ rows ] "r"(src_rows), [ src_zero_point ] "r"(src_zero_point),
+ [ input_xor ] "r"(input_xor)
+ : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
+ "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
+ "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
+ "v27", "v28", "v29", "v30", "v31");
+}
+#endif
+
+#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
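+// Byte offsets of the fields of PackParams8bit. The 32-bit assembly below is
+// given a single pointer to the params struct and loads individual fields with
+// "ldr rN, [%[params], #offset]", so these constants must stay in sync with
+// the struct layout; CheckOffsetsInPackParams8bit static_asserts that they do.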
+#define RUY_OFFSET_SRC_PTR0 0
+#define RUY_OFFSET_SRC_PTR1 4
+#define RUY_OFFSET_SRC_PTR2 8
+#define RUY_OFFSET_SRC_PTR3 12
+#define RUY_OFFSET_SUMS_PTR 16
+#define RUY_OFFSET_PACKED_PTR 20
+#define RUY_OFFSET_SRC_INC0 24
+#define RUY_OFFSET_SRC_INC1 28
+#define RUY_OFFSET_SRC_INC2 32
+#define RUY_OFFSET_SRC_INC3 36
+#define RUY_OFFSET_SRC_ROWS 40
+#define RUY_OFFSET_SRC_ZERO_POINT 44
+#define RUY_OFFSET_INPUT_XOR 48
+
+template <typename Params>
+void CheckOffsetsInPackParams8bit(const Params&) {
+ static_assert(offsetof(Params, src_ptr0) == RUY_OFFSET_SRC_PTR0, "");
+ static_assert(offsetof(Params, src_ptr1) == RUY_OFFSET_SRC_PTR1, "");
+ static_assert(offsetof(Params, src_ptr2) == RUY_OFFSET_SRC_PTR2, "");
+ static_assert(offsetof(Params, src_ptr3) == RUY_OFFSET_SRC_PTR3, "");
+ static_assert(offsetof(Params, sums_ptr) == RUY_OFFSET_SUMS_PTR, "");
+ static_assert(offsetof(Params, packed_ptr) == RUY_OFFSET_PACKED_PTR, "");
+ static_assert(offsetof(Params, src_inc0) == RUY_OFFSET_SRC_INC0, "");
+ static_assert(offsetof(Params, src_inc1) == RUY_OFFSET_SRC_INC1, "");
+ static_assert(offsetof(Params, src_inc2) == RUY_OFFSET_SRC_INC2, "");
+ static_assert(offsetof(Params, src_inc3) == RUY_OFFSET_SRC_INC3, "");
+ static_assert(offsetof(Params, src_rows) == RUY_OFFSET_SRC_ROWS, "");
+ static_assert(offsetof(Params, src_zero_point) == RUY_OFFSET_SRC_ZERO_POINT,
+ "");
+ static_assert(offsetof(Params, input_xor) == RUY_OFFSET_INPUT_XOR, "");
+}
+
+// Packing code for out-of-order ARMv7 CPUs like the Krait 400 or A9.
+// No attempt made at making this code efficient on in-order cores yet.
+void Pack8bitNeonOutOfOrder4Cols(const PackParams8bit& params) {
+ CheckOffsetsInPackParams8bit(params);
+ profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)");
+ const void* src_ptr0 = params.src_ptr0;
+ const void* src_ptr1 = params.src_ptr1;
+ const void* src_ptr2 = params.src_ptr2;
+ const void* src_ptr3 = params.src_ptr3;
+ const int src_inc0 = params.src_inc0;
+ const int src_inc1 = params.src_inc1;
+ const int src_inc2 = params.src_inc2;
+ const int src_inc3 = params.src_inc3;
+ const std::int8_t* packed_ptr = params.packed_ptr;
+
+ asm volatile(
+ // clang-format off
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n"
+ "vdup.8 q11, r2\n"
+ "mov r1, #0\n"
+ // Zero-out the accumulators
+ "vmov.i32 q12, #0\n"
+ "vmov.i32 q13, #0\n"
+ "vmov.i32 q14, #0\n"
+ "vmov.i32 q15, #0\n"
+
+ // Round down src_rows to nearest multiple of 16.
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n"
+ "and r2, r3, #-16\n"
+ "cmp r1, r2\n"
+ "beq 3f\n"
+
+ "1:\n"
+ "add r1, r1, #16\n"
+ /* Load q0 */
+ "vld1.8 {d0, d1}, [%[src_ptr0]]\n"
+ "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n"
+ RUY_PREFETCH_LOAD("pld [%[src_ptr0]]\n")
+
+ /* Load q1 */
+ "vld1.8 {d2, d3}, [%[src_ptr1]]\n"
+ "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n"
+ RUY_PREFETCH_LOAD("pld [%[src_ptr1]]\n")
+
+ "veor.8 q4, q0, q11\n"
+ "veor.8 q5, q1, q11\n"
+
+ // Pairwise add in to 16b accumulators.
+ "vpaddl.s8 q8, q4\n"
+ "vpaddl.s8 q9, q5\n"
+
+ "vst1.32 {q4}, [%[packed_ptr]]!\n"
+ "vst1.32 {q5}, [%[packed_ptr]]!\n"
+
+ // Pairwise add accumulate into 32b accumulators.
+ // q12 and q13 contain 4x32b accumulators
+ "vpadal.s16 q12, q8\n"
+ "vpadal.s16 q13, q9\n"
+
+ // Now do the same for src_ptr2 and src_ptr3.
+ "vld1.8 {d0, d1}, [%[src_ptr2]]\n"
+ "add %[src_ptr2], %[src_ptr2], %[src_inc2]\n"
+ RUY_PREFETCH_LOAD("pld [%[src_ptr2]]\n")
+
+ "vld1.8 {d2, d3}, [%[src_ptr3]]\n"
+ "add %[src_ptr3], %[src_ptr3], %[src_inc3]\n"
+ RUY_PREFETCH_LOAD("pld [%[src_ptr3]]\n")
+
+ "veor.8 q4, q0, q11\n"
+ "veor.8 q5, q1, q11\n"
+
+ "vpaddl.s8 q8, q4\n"
+ "vpaddl.s8 q9, q5\n"
+
+ "vst1.32 {q4}, [%[packed_ptr]]!\n"
+ "vst1.32 {q5}, [%[packed_ptr]]!\n"
+
+ // Pairwise add accumulate into 32b accumulators.
+ // q14 and q15 contain 4x32b accumulators
+ "vpadal.s16 q14, q8\n"
+ "vpadal.s16 q15, q9\n"
+
+ "cmp r1, r2\n"
+ "bne 1b\n"
+
+ "3:\n"
+
+ // Now pack the last (num_rows % 16) rows.
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n"
+ "ands r2, r3, #15\n"
+ "beq 4f\n"
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n"
+ "vdup.8 q0, r3\n"
+ "vdup.8 q1, r3\n"
+
+// First, read/accumulate/write for src_ptr0 and src_ptr1.
+#define RUY_LOAD_ONE_ROW1(I, R) \
+ "cmp r2, #" #I "\n" \
+ "beq 5f\n" \
+ "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \
+ "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \
+
+ RUY_LOAD_ONE_ROW1(0, 0)
+ RUY_LOAD_ONE_ROW1(1, 1)
+ RUY_LOAD_ONE_ROW1(2, 2)
+ RUY_LOAD_ONE_ROW1(3, 3)
+ RUY_LOAD_ONE_ROW1(4, 4)
+ RUY_LOAD_ONE_ROW1(5, 5)
+ RUY_LOAD_ONE_ROW1(6, 6)
+ RUY_LOAD_ONE_ROW1(7, 7)
+#undef RUY_LOAD_ONE_ROW1
+
+#define RUY_LOAD_ONE_ROW2(I, R) \
+ "cmp r2, #" #I "\n" \
+ "beq 5f\n" \
+ "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \
+ "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \
+
+ RUY_LOAD_ONE_ROW2(8, 0)
+ RUY_LOAD_ONE_ROW2(9, 1)
+ RUY_LOAD_ONE_ROW2(10, 2)
+ RUY_LOAD_ONE_ROW2(11, 3)
+ RUY_LOAD_ONE_ROW2(12, 4)
+ RUY_LOAD_ONE_ROW2(13, 5)
+ RUY_LOAD_ONE_ROW2(14, 6)
+ RUY_LOAD_ONE_ROW2(15, 7)
+#undef RUY_LOAD_ONE_ROW2
+
+ "5:\n"
+
+ "veor.16 q4, q0, q11\n"
+ "veor.16 q5, q1, q11\n"
+
+ "vpaddl.s8 q8, q4\n"
+ "vpaddl.s8 q9, q5\n"
+
+ // Pairwise add accumulate to 4x32b accumulators.
+ "vpadal.s16 q12, q8\n"
+ "vpadal.s16 q13, q9\n"
+
+ "vst1.32 {q4}, [%[packed_ptr]]!\n"
+ "vst1.32 {q5}, [%[packed_ptr]]!\n"
+
+ // Reset to src_zero for src_ptr2 and src_ptr3.
+ "vdup.8 q0, r3\n"
+ "vdup.8 q1, r3\n"
+
+// Next, read/accumulate/write for src_ptr2 and src_ptr3.
+#define RUY_LOAD_ONE_ROW1(I, R) \
+ "cmp r2, #" #I "\n" \
+ "beq 5f\n" \
+ "vld1.8 { d0[" #R "]}, [%[src_ptr2]]!\n" \
+ "vld1.8 { d2[" #R "]}, [%[src_ptr3]]!\n" \
+
+ RUY_LOAD_ONE_ROW1(0, 0)
+ RUY_LOAD_ONE_ROW1(1, 1)
+ RUY_LOAD_ONE_ROW1(2, 2)
+ RUY_LOAD_ONE_ROW1(3, 3)
+ RUY_LOAD_ONE_ROW1(4, 4)
+ RUY_LOAD_ONE_ROW1(5, 5)
+ RUY_LOAD_ONE_ROW1(6, 6)
+ RUY_LOAD_ONE_ROW1(7, 7)
+#undef RUY_LOAD_ONE_ROW1
+
+#define RUY_LOAD_ONE_ROW2(I, R) \
+ "cmp r2, #" #I "\n" \
+ "beq 5f\n" \
+ "vld1.8 { d1[" #R "]}, [%[src_ptr2]]!\n" \
+ "vld1.8 { d3[" #R "]}, [%[src_ptr3]]!\n" \
+
+ RUY_LOAD_ONE_ROW2(8, 0)
+ RUY_LOAD_ONE_ROW2(9, 1)
+ RUY_LOAD_ONE_ROW2(10, 2)
+ RUY_LOAD_ONE_ROW2(11, 3)
+ RUY_LOAD_ONE_ROW2(12, 4)
+ RUY_LOAD_ONE_ROW2(13, 5)
+ RUY_LOAD_ONE_ROW2(14, 6)
+ RUY_LOAD_ONE_ROW2(15, 7)
+#undef RUY_LOAD_ONE_ROW2
+
+ "5:\n"
+
+ "veor.16 q4, q0, q11\n"
+ "veor.16 q5, q1, q11\n"
+
+ "vpaddl.s8 q8, q4\n"
+ "vpaddl.s8 q9, q5\n"
+
+ // Pairwise add accumulate to 4x32b accumulators.
+ "vpadal.s16 q14, q8\n"
+ "vpadal.s16 q15, q9\n"
+
+ "vst1.32 {q4}, [%[packed_ptr]]!\n"
+ "vst1.32 {q5}, [%[packed_ptr]]!\n"
+
+ "4:\n"
+ // Pairwise add 32-bit accumulators
+ "vpadd.i32 d24, d24, d25\n"
+ "vpadd.i32 d26, d26, d27\n"
+ "vpadd.i32 d28, d28, d29\n"
+ "vpadd.i32 d30, d30, d31\n"
+ // Final 32-bit values per row
+ "vpadd.i32 d25, d24, d26\n"
+ "vpadd.i32 d27, d28, d30\n"
+
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n"
+ "cmp r3, #0\n"
+ "beq 6f\n"
+ "vst1.32 {d25}, [r3]!\n"
+ "vst1.32 {d27}, [r3]!\n"
+ "6:\n"
+ // clang-format on
+
+ : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
+ [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3)
+ : [ src_inc0 ] "r"(src_inc0), [ src_inc1 ] "r"(src_inc1),
+ [ src_inc2 ] "r"(src_inc2), [ src_inc3 ] "r"(src_inc3),
+ [ packed_ptr ] "r"(packed_ptr), [ params ] "r"(&params)
+      : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3",
+        "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13",
+        "q14", "q15");
+}
+
+// Packing code for out-of-order ARMv7 CPUs like the Krait 400 or A9.
+// No attempt made at making this code efficient on in-order cores yet.
+// This version differs from the above in that we only handle two columns
+// at a time.
+void Pack8bitNeonOutOfOrder2Cols(const PackParams8bit& params) {
+ CheckOffsetsInPackParams8bit(params);
+ profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)");
+ const void* src_ptr0 = params.src_ptr0;
+ const void* src_ptr1 = params.src_ptr1;
+ const int src_inc0 = params.src_inc0;
+ const int src_inc1 = params.src_inc1;
+ const std::int8_t* packed_ptr = params.packed_ptr;
+
+ asm volatile(
+ // clang-format off
+
+ "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n"
+ "vdup.8 q11, r2\n"
+ "mov r1, #0\n"
+ // Zero-out the accumulators
+ "vmov.i32 q12, #0\n"
+ "vmov.i32 q13, #0\n"
+
+ // Round down src_rows to nearest multiple of 16.
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n"
+ "and r2, r3, #-16\n"
+ "cmp r1, r2\n"
+ "beq 3f\n"
+
+ "1:\n"
+ "add r1, r1, #16\n"
+ /* Load q0 */
+ "vld1.8 {d0, d1}, [%[src_ptr0]]\n"
+ "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n"
+
+ /* Load q1 */
+ "vld1.8 {d2, d3}, [%[src_ptr1]]\n"
+ "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n"
+
+ "veor.8 q4, q0, q11\n"
+ "veor.8 q5, q1, q11\n"
+
+ // Pairwise add in to 16b accumulators.
+ "vpaddl.s8 q8, q4\n"
+ "vpaddl.s8 q9, q5\n"
+
+ "vst1.32 {q4}, [%[packed_ptr]]!\n"
+ "vst1.32 {q5}, [%[packed_ptr]]!\n"
+
+ // Pairwise add accumulate into 32b accumulators.
+ // q12 and q13 contain 4x32b accumulators
+ "vpadal.s16 q12, q8\n"
+ "vpadal.s16 q13, q9\n"
+
+ "cmp r1, r2\n"
+
+ "bne 1b\n"
+
+ "3:\n"
+
+ // Now pack the last (num_rows % 16) rows.
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n"
+ "ands r2, r3, #15\n"
+ "beq 4f\n"
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n"
+ "vdup.8 q0, r3\n"
+ "vdup.8 q1, r3\n"
+
+// Read/accumulate/write for src_ptr0 and src_ptr1.
+#define RUY_LOAD_ONE_ROW1(I, R) \
+ "cmp r2, #" #I "\n" \
+ "beq 5f\n" \
+ "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \
+ "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \
+
+ RUY_LOAD_ONE_ROW1(0, 0)
+ RUY_LOAD_ONE_ROW1(1, 1)
+ RUY_LOAD_ONE_ROW1(2, 2)
+ RUY_LOAD_ONE_ROW1(3, 3)
+ RUY_LOAD_ONE_ROW1(4, 4)
+ RUY_LOAD_ONE_ROW1(5, 5)
+ RUY_LOAD_ONE_ROW1(6, 6)
+ RUY_LOAD_ONE_ROW1(7, 7)
+#undef RUY_LOAD_ONE_ROW1
+
+#define RUY_LOAD_ONE_ROW2(I, R) \
+ "cmp r2, #" #I "\n" \
+ "beq 5f\n" \
+ "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \
+ "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \
+
+ RUY_LOAD_ONE_ROW2(8, 0)
+ RUY_LOAD_ONE_ROW2(9, 1)
+ RUY_LOAD_ONE_ROW2(10, 2)
+ RUY_LOAD_ONE_ROW2(11, 3)
+ RUY_LOAD_ONE_ROW2(12, 4)
+ RUY_LOAD_ONE_ROW2(13, 5)
+ RUY_LOAD_ONE_ROW2(14, 6)
+ RUY_LOAD_ONE_ROW2(15, 7)
+#undef RUY_LOAD_ONE_ROW2
+
+ "5:\n"
+
+ "veor.16 q4, q0, q11\n"
+ "veor.16 q5, q1, q11\n"
+
+ "vpaddl.s8 q8, q4\n"
+ "vpaddl.s8 q9, q5\n"
+
+
+ // Pairwise add accumulate to 4x32b accumulators.
+ "vpadal.s16 q12, q8\n"
+ "vpadal.s16 q13, q9\n"
+
+ "vst1.32 {q4}, [%[packed_ptr]]!\n"
+ "vst1.32 {q5}, [%[packed_ptr]]!\n"
+
+ "4:\n"
+
+ // Pairwise add 32-bit accumulators
+ "vpadd.i32 d24, d24, d25\n"
+ "vpadd.i32 d26, d26, d27\n"
+ // Final 32-bit values per row
+ "vpadd.i32 d25, d24, d26\n"
+
+ "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n"
+ "cmp r3, #0\n"
+ "beq 6f\n"
+ "vst1.32 {d25}, [r3]!\n"
+ "6:\n"
+ // clang-format on
+
+ : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1)
+ : [ src_inc0 ] "r"(src_inc0), [ src_inc1 ] "r"(src_inc1),
+ [ packed_ptr ] "r"(packed_ptr), [ params ] "r"(&params)
+ : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3",
+ "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13");
+}
+
+#undef RUY_OFFSET_SRC_PTR0
+#undef RUY_OFFSET_SRC_PTR1
+#undef RUY_OFFSET_SRC_PTR2
+#undef RUY_OFFSET_SRC_PTR3
+#undef RUY_OFFSET_SUMS_PTR
+#undef RUY_OFFSET_PACKED_PTR
+#undef RUY_OFFSET_SRC_INC0
+#undef RUY_OFFSET_SRC_INC1
+#undef RUY_OFFSET_SRC_INC2
+#undef RUY_OFFSET_SRC_INC3
+#undef RUY_OFFSET_SRC_ROWS
+#undef RUY_OFFSET_SRC_ZERO_POINT
+#undef RUY_OFFSET_INPUT_XOR
+
+#endif // RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+void Pack8bitNeonInOrder(const void* src_ptr0, const void* src_ptr1,
+ const void* src_ptr2, const void* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2, int src_inc3,
+ int src_rows, int src_zero_point,
+ std::int8_t* packed_ptr, int start_col, int end_col,
+ std::int32_t* sums_ptr, int input_xor) {
+ profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)");
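+  // In-order tuning: each 16-byte chunk of a source column is loaded as an
+  // 8-byte NEON load (ld1 {vN.8b}) plus a 64-bit GPR load (ldr x10..x13) that
+  // is later merged with "ins vN.d[1]". Splitting the loads this way, together
+  // with the interleaved prefetches, is intended to keep in-order pipelines
+  // busy; the packed output and the sums are the same as in the out-of-order
+  // variant above.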
+ asm volatile(
+ // clang-format off
+ "dup v26.16b, %w[input_xor]\n"
+ "mov w1, #0\n"
+ "dup v28.4s, wzr\n"
+ "dup v29.4s, wzr\n"
+ "dup v30.4s, wzr\n"
+ "dup v31.4s, wzr\n"
+
+ "and w2, %w[rows], #-16\n"
+ "cmp w1, w2\n"
+ "beq 3f\n"
+ "ldr x10, [%[src_ptr0], #8]\n"
+ "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ldr x11, [%[src_ptr1], #8]\n"
+ "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ldr x12, [%[src_ptr2], #8]\n"
+ "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ldr x13, [%[src_ptr3], #8]\n"
+ "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n")
+ "add w1, w1, #16\n"
+ "cmp w1, w2\n"
+
+ "beq 2f\n"
+
+ "1:\n"
+ "add w1, w1, #16\n"
+ "ins v0.d[1], x10\n"
+ "ldr x10, [%[src_ptr0], #8]\n"
+ "ins v1.d[1], x11\n"
+ "ldr x11, [%[src_ptr1], #8]\n"
+ "ins v2.d[1], x12\n"
+ "ldr x12, [%[src_ptr2], #8]\n"
+ "ins v3.d[1], x13\n"
+ "ldr x13, [%[src_ptr3], #8]\n"
+ "eor v4.16b, v0.16b, v26.16b\n"
+ "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
+ "eor v5.16b, v1.16b, v26.16b\n"
+ "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
+ "eor v6.16b, v2.16b, v26.16b\n"
+ "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
+ "eor v7.16b, v3.16b, v26.16b\n"
+ "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
+ "saddlp v16.8h, v4.16b\n"
+ "str q4, [%[packed_ptr], #0]\n"
+ "saddlp v17.8h, v5.16b\n"
+ "str q5, [%[packed_ptr], #16]\n"
+ "saddlp v18.8h, v6.16b\n"
+ "str q6, [%[packed_ptr], #32]\n"
+ "saddlp v19.8h, v7.16b\n"
+ "str q7, [%[packed_ptr], #48]\n"
+ "sadalp v28.4s, v16.8h\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n")
+ "cmp w1, w2\n"
+ "sadalp v29.4s, v17.8h\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n")
+ "add %[packed_ptr], %[packed_ptr], #64\n"
+ "sadalp v30.4s, v18.8h\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n")
+ "sadalp v31.4s, v19.8h\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n")
+
+ "bne 1b\n"
+
+ "2:\n"
+ "ins v0.d[1], x10\n"
+ "ins v1.d[1], x11\n"
+ "ins v2.d[1], x12\n"
+ "ins v3.d[1], x13\n"
+ "eor v4.16b, v0.16b, v26.16b\n"
+ "eor v5.16b, v1.16b, v26.16b\n"
+ "eor v6.16b, v2.16b, v26.16b\n"
+ "eor v7.16b, v3.16b, v26.16b\n"
+
+ "saddlp v16.8h, v4.16b\n"
+ "str q4, [%[packed_ptr], #0]\n"
+ "saddlp v17.8h, v5.16b\n"
+ "str q5, [%[packed_ptr], #16]\n"
+ "saddlp v18.8h, v6.16b\n"
+ "str q6, [%[packed_ptr], #32]\n"
+ "saddlp v19.8h, v7.16b\n"
+ "str q7, [%[packed_ptr], #48]\n"
+ "sadalp v28.4s, v16.8h\n"
+ "sadalp v29.4s, v17.8h\n"
+ "sadalp v30.4s, v18.8h\n"
+ "sadalp v31.4s, v19.8h\n"
+
+ "add %[packed_ptr], %[packed_ptr], #64\n"
+
+ "3:\n"
+
+ "ands w2, %w[rows], #15\n"
+ "beq 4f\n"
+ "dup v0.16b, %w[src_zero_point]\n"
+ "dup v1.16b, %w[src_zero_point]\n"
+ "dup v2.16b, %w[src_zero_point]\n"
+ "dup v3.16b, %w[src_zero_point]\n"
+#define RUY_LOAD_ONE_ROW(R) \
+ "cmp w2, #" #R "\n" \
+ "beq 5f\n" \
+ "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
+ "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
+ "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
+ "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
+
+ RUY_LOAD_ONE_ROW(0)
+ RUY_LOAD_ONE_ROW(1)
+ RUY_LOAD_ONE_ROW(2)
+ RUY_LOAD_ONE_ROW(3)
+ RUY_LOAD_ONE_ROW(4)
+ RUY_LOAD_ONE_ROW(5)
+ RUY_LOAD_ONE_ROW(6)
+ RUY_LOAD_ONE_ROW(7)
+ RUY_LOAD_ONE_ROW(8)
+ RUY_LOAD_ONE_ROW(9)
+ RUY_LOAD_ONE_ROW(10)
+ RUY_LOAD_ONE_ROW(11)
+ RUY_LOAD_ONE_ROW(12)
+ RUY_LOAD_ONE_ROW(13)
+ RUY_LOAD_ONE_ROW(14)
+ RUY_LOAD_ONE_ROW(15)
+#undef RUY_LOAD_ONE_ROW
+ "5:\n"
+
+ "eor v4.16b, v0.16b, v26.16b\n"
+ "eor v5.16b, v1.16b, v26.16b\n"
+ "eor v6.16b, v2.16b, v26.16b\n"
+ "eor v7.16b, v3.16b, v26.16b\n"
+
+ "saddlp v16.8h, v4.16b\n"
+ "saddlp v17.8h, v5.16b\n"
+ "saddlp v18.8h, v6.16b\n"
+ "saddlp v19.8h, v7.16b\n"
+ "sadalp v28.4s, v16.8h\n"
+ "sadalp v29.4s, v17.8h\n"
+ "sadalp v30.4s, v18.8h\n"
+ "sadalp v31.4s, v19.8h\n"
+
+ "str q4, [%[packed_ptr], #0]\n"
+ "str q5, [%[packed_ptr], #16]\n"
+ "str q6, [%[packed_ptr], #32]\n"
+ "str q7, [%[packed_ptr], #48]\n"
+ "add %[packed_ptr], %[packed_ptr], #64\n"
+
+ "4:\n"
+
+ "addp v28.4s, v28.4s, v29.4s\n"
+ "addp v30.4s, v30.4s, v31.4s\n"
+ "addp v28.4s, v28.4s, v30.4s\n"
+
+ "cmp %[sums_ptr], #0\n"
+ "beq 6f\n"
+ "st1 {v28.4s}, [%[sums_ptr]], #16\n"
+ "6:\n"
+ // clang-format on
+
+ : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
+ [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
+ [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr)
+ : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)), [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
+ [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)), [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
+ [ rows ] "r"(src_rows),
+ [ src_zero_point ] "r"(src_zero_point),
+        [ input_xor ] "r"(input_xor)
+ : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5",
+ "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
+ "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
+ "v25", "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
+void Pack8bitNeonDotprodInOrder(const void* src_ptr0, const void* src_ptr1,
+ const void* src_ptr2, const void* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows, int src_zero_point,
+ std::int8_t* packed_ptr, int start_col,
+ int end_col, std::int32_t* sums_ptr,
+ int input_xor) {
+ profiler::ScopeLabel label(
+ "Pack (kNeonDotprod, optimized for in-order cores)");
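+  // Note: v27 is filled with the constant 1 so that "sdot vS, vD, v27"
+  // accumulates sums of groups of four packed bytes into the per-column sums.
+  // The sdot instructions are emitted as raw ".word" encodings, presumably so
+  // that this file also assembles with toolchains that do not know the dotprod
+  // extension.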
+ asm volatile(
+ // clang-format off
+ "dup v26.16b, %w[input_xor]\n"
+ "mov w1, #1\n"
+ "dup v27.16b, w1\n"
+ "mov w1, #0\n"
+ "dup v28.4s, wzr\n"
+ "dup v29.4s, wzr\n"
+ "dup v30.4s, wzr\n"
+ "dup v31.4s, wzr\n"
+
+ "and w2, %w[rows], #-16\n"
+ "cmp w1, w2\n"
+ "beq 3f\n"
+ "ldr x10, [%[src_ptr0], #8]\n"
+ "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ldr x11, [%[src_ptr1], #8]\n"
+ "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ldr x12, [%[src_ptr2], #8]\n"
+ "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ldr x13, [%[src_ptr3], #8]\n"
+ "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n")
+ "add w1, w1, #16\n"
+ "cmp w1, w2\n"
+
+ "beq 2f\n"
+
+ "1:\n"
+ "add w1, w1, #16\n"
+ "ins v0.d[1], x10\n"
+ "ldr x10, [%[src_ptr0], #8]\n"
+ "ins v1.d[1], x11\n"
+ "ldr x11, [%[src_ptr1], #8]\n"
+ "ins v2.d[1], x12\n"
+ "ldr x12, [%[src_ptr2], #8]\n"
+ "ins v3.d[1], x13\n"
+ "ldr x13, [%[src_ptr3], #8]\n"
+
+ "eor v4.16b, v0.16b, v26.16b\n"
+ "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
+ "eor v5.16b, v1.16b, v26.16b\n"
+ "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
+ "eor v6.16b, v2.16b, v26.16b\n"
+ "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
+ "eor v7.16b, v3.16b, v26.16b\n"
+ "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
+
+ "trn1 v16.4s, v4.4s, v5.4s\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n")
+ "trn2 v17.4s, v4.4s, v5.4s\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n")
+ "trn1 v18.4s, v6.4s, v7.4s\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n")
+ "trn2 v19.4s, v6.4s, v7.4s\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n")
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+ "cmp w1, w2\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ "str q20, [%[packed_ptr], #0]\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+ "str q23, [%[packed_ptr], #96]\n"
+
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "bne 1b\n"
+
+ "2:\n"
+ "ins v0.d[1], x10\n"
+ "ins v1.d[1], x11\n"
+ "ins v2.d[1], x12\n"
+ "ins v3.d[1], x13\n"
+ "eor v0.16b, v0.16b, v26.16b\n"
+ "eor v1.16b, v1.16b, v26.16b\n"
+ "eor v2.16b, v2.16b, v26.16b\n"
+ "eor v3.16b, v3.16b, v26.16b\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ "str q20, [%[packed_ptr], #0]\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "3:\n"
+
+ "ands w2, %w[rows], #15\n"
+ "beq 4f\n"
+ "dup v0.16b, %w[src_zero_point]\n"
+ "dup v1.16b, %w[src_zero_point]\n"
+ "dup v2.16b, %w[src_zero_point]\n"
+ "dup v3.16b, %w[src_zero_point]\n"
+#define RUY_LOAD_ONE_ROW(R) \
+ "cmp w2, #" #R "\n" \
+ "beq 5f\n" \
+ "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
+ "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
+ "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
+ "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
+
+ RUY_LOAD_ONE_ROW(0)
+ RUY_LOAD_ONE_ROW(1)
+ RUY_LOAD_ONE_ROW(2)
+ RUY_LOAD_ONE_ROW(3)
+ RUY_LOAD_ONE_ROW(4)
+ RUY_LOAD_ONE_ROW(5)
+ RUY_LOAD_ONE_ROW(6)
+ RUY_LOAD_ONE_ROW(7)
+ RUY_LOAD_ONE_ROW(8)
+ RUY_LOAD_ONE_ROW(9)
+ RUY_LOAD_ONE_ROW(10)
+ RUY_LOAD_ONE_ROW(11)
+ RUY_LOAD_ONE_ROW(12)
+ RUY_LOAD_ONE_ROW(13)
+ RUY_LOAD_ONE_ROW(14)
+ RUY_LOAD_ONE_ROW(15)
+#undef RUY_LOAD_ONE_ROW
+ "5:\n"
+
+ "eor v0.16b, v0.16b, v26.16b\n"
+ "eor v1.16b, v1.16b, v26.16b\n"
+ "eor v2.16b, v2.16b, v26.16b\n"
+ "eor v3.16b, v3.16b, v26.16b\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ "str q20, [%[packed_ptr], #0]\n"
+ "cmp w2, #4\n"
+ "ble 4f\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "cmp w2, #8\n"
+ "ble 4f\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "cmp w2, #12\n"
+ "ble 4f\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "4:\n"
+
+ "add v28.4s, v28.4s, v29.4s\n"
+ "add v30.4s, v30.4s, v31.4s\n"
+ "add v28.4s, v28.4s, v30.4s\n"
+
+ "cmp %[sums_ptr], #0\n"
+ "beq 6f\n"
+ "st1 {v28.4s}, [%[sums_ptr]], #16\n"
+ "6:\n"
+ // clang-format on
+
+      : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
+        [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
+        [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr)
+      : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)),
+        [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
+        [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)),
+        [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
+        [ rows ] "r"(src_rows),
+        [ src_zero_point ] "r"(static_cast<int>(src_zero_point)),
+        [ input_xor ] "r"(input_xor)
+ : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+ "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
+ "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
+void Pack8bitNeonDotprodOutOfOrder(const void* src_ptr0, const void* src_ptr1,
+ const void* src_ptr2, const void* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows,
+ int src_zero_point, std::int8_t* packed_ptr,
+ int start_col, int end_col,
+ std::int32_t* sums_ptr, int input_xor) {
+ profiler::ScopeLabel label(
+ "Pack (kNeonDotprod, optimized for out-of-order cores)");
+ asm volatile(
+ // clang-format off
+ "dup v26.16b, %w[input_xor]\n"
+ "mov w1, #1\n"
+ "dup v27.16b, w1\n"
+ "mov w1, #0\n"
+ "dup v28.4s, wzr\n"
+ "dup v29.4s, wzr\n"
+ "dup v30.4s, wzr\n"
+ "dup v31.4s, wzr\n"
+
+#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING)
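+    // With RUY_OPT_MAX_STREAMING enabled, blocks of 64 rows are first consumed
+    // by a 4x-unrolled loop (labels 7/8/9) that keeps 16 source vectors in
+    // flight; any remaining rows fall through to the regular 16-row loop below.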
+ "and w2, %w[rows], #-64\n"
+ "cmp w1, w2\n"
+ "beq 9f\n"
+
+ "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "add w1, w1, #64\n"
+ "cmp w1, w2\n"
+ "beq 8f\n"
+
+ "7:\n"
+ "eor v0.16b, v0.16b, v26.16b\n"
+ "eor v1.16b, v1.16b, v26.16b\n"
+ "eor v2.16b, v2.16b, v26.16b\n"
+ "eor v3.16b, v3.16b, v26.16b\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "add w1, w1, #16\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "eor v4.16b, v4.16b, v26.16b\n"
+ "eor v5.16b, v5.16b, v26.16b\n"
+ "eor v6.16b, v6.16b, v26.16b\n"
+ "eor v7.16b, v7.16b, v26.16b\n"
+
+ "trn1 v16.4s, v4.4s, v5.4s\n"
+ "trn2 v17.4s, v4.4s, v5.4s\n"
+ "trn1 v18.4s, v6.4s, v7.4s\n"
+ "trn2 v19.4s, v6.4s, v7.4s\n"
+
+ "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "add w1, w1, #16\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "eor v8.16b, v8.16b, v26.16b\n"
+ "eor v9.16b, v9.16b, v26.16b\n"
+ "eor v10.16b, v10.16b, v26.16b\n"
+ "eor v11.16b, v11.16b, v26.16b\n"
+
+ "trn1 v16.4s, v8.4s, v9.4s\n"
+ "trn2 v17.4s, v8.4s, v9.4s\n"
+ "trn1 v18.4s, v10.4s, v11.4s\n"
+ "trn2 v19.4s, v10.4s, v11.4s\n"
+
+ "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "add w1, w1, #16\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "eor v12.16b, v12.16b, v26.16b\n"
+ "eor v13.16b, v13.16b, v26.16b\n"
+ "eor v14.16b, v14.16b, v26.16b\n"
+ "eor v15.16b, v15.16b, v26.16b\n"
+
+ "trn1 v16.4s, v12.4s, v13.4s\n"
+ "trn2 v17.4s, v12.4s, v13.4s\n"
+ "trn1 v18.4s, v14.4s, v15.4s\n"
+ "trn2 v19.4s, v14.4s, v15.4s\n"
+
+ "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "add w1, w1, #16\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "cmp w1, w2\n"
+ "bne 7b\n"
+
+ "8:\n"
+
+ "eor v0.16b, v0.16b, v26.16b\n"
+ "eor v1.16b, v1.16b, v26.16b\n"
+ "eor v2.16b, v2.16b, v26.16b\n"
+ "eor v3.16b, v3.16b, v26.16b\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "eor v4.16b, v4.16b, v26.16b\n"
+ "eor v5.16b, v5.16b, v26.16b\n"
+ "eor v6.16b, v6.16b, v26.16b\n"
+ "eor v7.16b, v7.16b, v26.16b\n"
+
+ "trn1 v16.4s, v4.4s, v5.4s\n"
+ "trn2 v17.4s, v4.4s, v5.4s\n"
+ "trn1 v18.4s, v6.4s, v7.4s\n"
+ "trn2 v19.4s, v6.4s, v7.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "eor v8.16b, v8.16b, v26.16b\n"
+ "eor v9.16b, v9.16b, v26.16b\n"
+ "eor v10.16b, v10.16b, v26.16b\n"
+ "eor v11.16b, v11.16b, v26.16b\n"
+
+ "trn1 v16.4s, v8.4s, v9.4s\n"
+ "trn2 v17.4s, v8.4s, v9.4s\n"
+ "trn1 v18.4s, v10.4s, v11.4s\n"
+ "trn2 v19.4s, v10.4s, v11.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "eor v12.16b, v12.16b, v26.16b\n"
+ "eor v13.16b, v13.16b, v26.16b\n"
+ "eor v14.16b, v14.16b, v26.16b\n"
+ "eor v15.16b, v15.16b, v26.16b\n"
+
+ "trn1 v16.4s, v12.4s, v13.4s\n"
+ "trn2 v17.4s, v12.4s, v13.4s\n"
+ "trn1 v18.4s, v14.4s, v15.4s\n"
+ "trn2 v19.4s, v14.4s, v15.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "9:\n"
+#endif // #if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING)
+ "and w2, %w[rows], #-16\n"
+ "cmp w1, w2\n"
+ "beq 3f\n"
+
+ "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "add w1, w1, #16\n"
+ "cmp w1, w2\n"
+ "beq 2f\n"
+
+ "1:\n"
+
+ "eor v0.16b, v0.16b, v26.16b\n"
+ "eor v1.16b, v1.16b, v26.16b\n"
+ "eor v2.16b, v2.16b, v26.16b\n"
+ "eor v3.16b, v3.16b, v26.16b\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
+ "add w1, w1, #16\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "cmp w1, w2\n"
+ "bne 1b\n"
+
+ "2:\n"
+
+ "eor v0.16b, v0.16b, v26.16b\n"
+ "eor v1.16b, v1.16b, v26.16b\n"
+ "eor v2.16b, v2.16b, v26.16b\n"
+ "eor v3.16b, v3.16b, v26.16b\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "3:\n"
+
+ "ands w2, %w[rows], #15\n"
+ "beq 4f\n"
+ "dup v0.16b, %w[src_zero_point]\n"
+ "dup v1.16b, %w[src_zero_point]\n"
+ "dup v2.16b, %w[src_zero_point]\n"
+ "dup v3.16b, %w[src_zero_point]\n"
+#define RUY_LOAD_ONE_ROW(R) \
+ "cmp w2, #" #R "\n" \
+ "beq 5f\n" \
+ "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
+ "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
+ "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
+ "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
+
+ RUY_LOAD_ONE_ROW(0)
+ RUY_LOAD_ONE_ROW(1)
+ RUY_LOAD_ONE_ROW(2)
+ RUY_LOAD_ONE_ROW(3)
+ RUY_LOAD_ONE_ROW(4)
+ RUY_LOAD_ONE_ROW(5)
+ RUY_LOAD_ONE_ROW(6)
+ RUY_LOAD_ONE_ROW(7)
+ RUY_LOAD_ONE_ROW(8)
+ RUY_LOAD_ONE_ROW(9)
+ RUY_LOAD_ONE_ROW(10)
+ RUY_LOAD_ONE_ROW(11)
+ RUY_LOAD_ONE_ROW(12)
+ RUY_LOAD_ONE_ROW(13)
+ RUY_LOAD_ONE_ROW(14)
+ RUY_LOAD_ONE_ROW(15)
+#undef RUY_LOAD_ONE_ROW
+ "5:\n"
+
+ "eor v0.16b, v0.16b, v26.16b\n"
+ "eor v1.16b, v1.16b, v26.16b\n"
+ "eor v2.16b, v2.16b, v26.16b\n"
+ "eor v3.16b, v3.16b, v26.16b\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
+ "str q20, [%[packed_ptr], #0]\n"
+ "cmp w2, #4\n"
+ "ble 4f\n"
+ ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "cmp w2, #8\n"
+ "ble 4f\n"
+ ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "cmp w2, #12\n"
+ "ble 4f\n"
+ ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "4:\n"
+
+ "add v28.4s, v28.4s, v29.4s\n"
+ "add v30.4s, v30.4s, v31.4s\n"
+ "add v28.4s, v28.4s, v30.4s\n"
+
+ "cmp %[sums_ptr], #0\n"
+ "beq 6f\n"
+ "st1 {v28.4s}, [%[sums_ptr]], #16\n"
+ "6:\n"
+ // clang-format on
+
+ : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
+ [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
+ [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr)
+ : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)),
+ [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
+ [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)),
+ [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
+ [ rows ] "r"(src_rows),
+ [ src_zero_point ] "r"(static_cast<int>(src_zero_point)),
+ [ input_xor ] "r"(input_xor)
+ : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
+ "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
+ "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
+ "v27", "v28", "v29", "v30", "v31");
+}
+
+#endif // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1,
+ const float* src_ptr2, const float* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows, int src_zero_point,
+ float* packed_ptr, int start_col, int end_col) {
+ profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)");
+ asm volatile(
+ // clang-format off
+ "mov w1, #0\n"
+
+ "and w2, %w[rows], #-4\n"
+ "cmp w1, w2\n"
+ "beq 3f\n"
+ "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n"
+ "add w1, w1, #4\n"
+ "cmp w1, w2\n"
+
+ "beq 2f\n"
+
+ "1:\n"
+ "add w1, w1, #4\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+ "cmp w1, w2\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "bne 1b\n"
+
+ "2:\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "3:\n"
+
+ "ands w2, %w[rows], #3\n"
+ "beq 4f\n"
+ "dup v0.16b, wzr\n"
+ "dup v1.16b, wzr\n"
+ "dup v2.16b, wzr\n"
+ "dup v3.16b, wzr\n"
+#define RUY_LOAD_ONE_ROW(R) \
+ "cmp w2, #" #R "\n" \
+ "beq 5f\n" \
+ "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \
+ "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \
+ "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \
+ "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n"
+
+ RUY_LOAD_ONE_ROW(0)
+ RUY_LOAD_ONE_ROW(1)
+ RUY_LOAD_ONE_ROW(2)
+ RUY_LOAD_ONE_ROW(3)
+#undef RUY_LOAD_ONE_ROW
+ "5:\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ "mov x1, #32\n"
+
+#define RUY_STORE_ONE_ROW(ROW, REGISTER) \
+ "cmp w2, #" #ROW "\n" \
+ "beq 4f\n" \
+ "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n"
+
+ RUY_STORE_ONE_ROW(0, v20)
+ RUY_STORE_ONE_ROW(1, v21)
+ RUY_STORE_ONE_ROW(2, v22)
+ RUY_STORE_ONE_ROW(3, v23)
+
+#undef RUY_STORE_ONE_ROW
+
+ "4:\n"
+
+ // clang-format on
+
+ : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
+ [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
+ [ packed_ptr ] "+r"(packed_ptr)
+ : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)),
+ [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
+ [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)),
+ [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
+ [ rows ] "r"(src_rows)
+ : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1",
+ "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
+ "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
+}
+#endif
+
+#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1,
+ const float* src_ptr2, const float* src_ptr3,
+ int src_inc, int src_rows, int src_zero_point,
+ float* packed_ptr, int start_col, int end_col,
+ int output_stride) {
+ profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)");
+ asm volatile(
+ // clang-format off
+ "mov r1, #0\n"
+ "and r2, %[rows], #-4\n"
+ "cmp r1, r2\n"
+ "beq 3f\n"
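+    // %[src_inc] is a bitmask: bit i tells whether src_ptr<i> should advance
+    // by 16 bytes after its load. Each "and"/"add ... lsl" pair in the macro
+    // below turns bit i into either 0 or 16 and adds it to the corresponding
+    // pointer.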
+#define RUY_LOAD_FOUR_BY_FOUR() \
+ /* Load q0 */ \
+ "vld1.32 {d0, d1}, [%[src_ptr0]]\n" \
+ /* if src_inc0 != 0, add 16 to src_ptr0 */ \
+ "and r3, %[src_inc], #1\n" \
+ "add %[src_ptr0], %[src_ptr0], r3, lsl #4\n"\
+ /* Load q1 */ \
+ "vld1.32 {d2, d3}, [%[src_ptr1]]\n" \
+  /* if src_inc1 != 0, add 16 to src_ptr1 */  \
+ "and r3, %[src_inc], #2\n" \
+ "add %[src_ptr1], %[src_ptr1], r3, lsl #3\n"\
+ /* Load q2 */ \
+ "vld1.32 {d4, d5}, [%[src_ptr2]]\n" \
+  /* if src_inc2 != 0, add 16 to src_ptr2 */  \
+ "and r3, %[src_inc], #4\n" \
+ "add %[src_ptr2], %[src_ptr2], r3, lsl #2\n"\
+ /* Load q3 */ \
+ "vld1.32 {d6, d7}, [%[src_ptr3]]\n" \
+  /* if src_inc3 != 0, add 16 to src_ptr3 */  \
+ "and r3, %[src_inc], #8\n" \
+ "add %[src_ptr3], %[src_ptr3], r3, lsl #1\n"\
+
+ RUY_LOAD_FOUR_BY_FOUR()
+ "add r1, r1, #4\n"
+ "cmp r1, r2\n"
+
+ "beq 2f\n"
+
+ "1:\n"
+ "add r1, r1, #4\n"
+
+ // Transpose 4x4 matrix.
+ "vzip.32 q0, q1\n"
+ "vzip.32 q2, q3\n"
+
+ "vtrn.32 q0, q2\n"
+ "vtrn.32 q1, q3\n"
+
+ "vzip.32 q0, q2\n"
+ "vzip.32 q1, q3\n"
+
+ "vmov q8, q0\n"
+ "vmov q9, q1\n"
+ "vmov q10, q2\n"
+ "vmov q11, q3\n"
+
+ RUY_LOAD_FOUR_BY_FOUR()
+#undef RUY_LOAD_FOUR_BY_FOUR
+
+#define RUY_STORE_FOUR_BY_FOUR() \
+ /* Store q8, q10, q9, q11 */ \
+ /* q8 = d16, d17 */ \
+ "vst1.32 {d16, d17}, [%[packed_ptr]]\n" \
+ /* q10 = d20, d21 */ \
+ "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
+ "vst1.32 {d20, d21}, [%[packed_ptr]]\n" \
+ /* q9 = d18, d19 */ \
+ "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
+ "vst1.32 {d18, d19}, [%[packed_ptr]]\n" \
+ /* q11 = d22, d23 */ \
+ "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
+ "vst1.32 {d22, d23}, [%[packed_ptr]]\n" \
+ "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
+
+ RUY_STORE_FOUR_BY_FOUR()
+ "cmp r1, r2\n"
+
+ "bne 1b\n"
+
+ "2:\n"
+
+ // Transpose 4x4 matrix.
+ "vzip.32 q0, q1\n"
+ "vzip.32 q2, q3\n"
+
+ "vtrn.32 q0, q2\n"
+ "vtrn.32 q1, q3\n"
+
+ "vzip.32 q0, q2\n"
+ "vzip.32 q1, q3\n"
+
+ "vmov q8, q0\n"
+ "vmov q9, q1\n"
+ "vmov q10, q2\n"
+ "vmov q11, q3\n"
+
+ RUY_STORE_FOUR_BY_FOUR()
+#undef RUY_STORE_FOUR_BY_FOUR
+ "3:\n"
+
+ "ands r2, %[rows], #3\n"
+ "beq 4f\n"
+ "mov r0, #0\n"
+ // Zero out q0 - q3
+ "vdup.32 q0, r0\n"
+ "vdup.32 q1, r0\n"
+ "vdup.32 q2, r0\n"
+ "vdup.32 q3, r0\n"
+#define RUY_LOAD_ONE_ROW_FIRST_HALF(R, I) \
+ "cmp r2, #" #R "\n" \
+ "beq 5f\n" \
+ "vld1.32 { d0[" #I "] }, [%[src_ptr0]]!\n" \
+ "vld1.32 { d2[" #I "] }, [%[src_ptr1]]!\n" \
+ "vld1.32 { d4[" #I "] }, [%[src_ptr2]]!\n" \
+ "vld1.32 { d6[" #I "] }, [%[src_ptr3]]!\n"
+
+#define RUY_LOAD_ONE_ROW_SECOND_HALF(R, I) \
+ "vld1.32 { d1[" #I "] }, [%[src_ptr0]]!\n" \
+ "vld1.32 { d3[" #I "] }, [%[src_ptr1]]!\n" \
+ "vld1.32 { d5[" #I "] }, [%[src_ptr2]]!\n" \
+ "vld1.32 { d7[" #I "] }, [%[src_ptr3]]!\n"
+
+ RUY_LOAD_ONE_ROW_FIRST_HALF(0, 0)
+ RUY_LOAD_ONE_ROW_FIRST_HALF(1, 1)
+ RUY_LOAD_ONE_ROW_SECOND_HALF(2, 0)
+ RUY_LOAD_ONE_ROW_SECOND_HALF(3, 1)
+#undef RUY_LOAD_ONE_ROW_SECOND_HALF
+#undef RUY_LOAD_ONE_ROW_FIRST_HALF
+ "5:\n"
+
+ // Transpose 4x4 matrix.
+ "vzip.32 q0, q1\n"
+ "vzip.32 q2, q3\n"
+
+ "vtrn.32 q0, q2\n"
+ "vtrn.32 q1, q3\n"
+
+ "vzip.32 q0, q2\n"
+ "vzip.32 q1, q3\n"
+
+ "vmov q8, q0\n"
+ "vmov q9, q1\n"
+ "vmov q10, q2\n"
+ "vmov q11, q3\n"
+
+ "mov r1, #32\n"
+
+#define RUY_STORE_ONE_ROW(ROW, REGISTER) \
+ "cmp r2, #" #ROW "\n" \
+ "beq 4f\n" \
+ "vst1.32 {" #REGISTER "}, [%[packed_ptr]]\n" \
+ "add %[packed_ptr], %[packed_ptr], %[stride]\n"
+
+ // Store q8
+ RUY_STORE_ONE_ROW(0, q8)
+ // Store q10
+ RUY_STORE_ONE_ROW(1, q10)
+ // Store q9
+ RUY_STORE_ONE_ROW(2, q9)
+ // Store q11
+ RUY_STORE_ONE_ROW(3, q11)
+
+#undef RUY_STORE_ONE_ROW
+
+ "4:\n"
+
+ // clang-format on
+ : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
+ [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
+ [ packed_ptr ] "+r"(packed_ptr)
+ : [ src_inc ] "r"(static_cast<std::int64_t>(src_inc)),
+ [ rows ] "r"(src_rows), [ stride ] "r"(output_stride)
+ : "cc", "memory", "r0", "r1", "r2", "r3", "q0", "q1", "q2", "q3",
+ "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11");
+}
+
+#endif  // RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+void PackFloatNeonInOrder(const float* src_ptr0, const float* src_ptr1,
+ const float* src_ptr2, const float* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows, int src_zero_point,
+ float* packed_ptr, int start_col, int end_col) {
+ profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)");
+
+ asm volatile(
+ // clang-format off
+ "mov w1, #0\n"
+
+ "and w2, %w[rows], #-4\n"
+ "cmp w1, w2\n"
+ "beq 3f\n"
+ "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n"
+ "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n"
+ "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n"
+ "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n")
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n")
+ "add w1, w1, #4\n"
+ "cmp w1, w2\n"
+
+ "beq 2f\n"
+
+ "1:\n"
+ "add w1, w1, #4\n"
+
+ "ldr x10, [%[src_ptr0], #8]\n"
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n")
+ "ldr x11, [%[src_ptr1], #8]\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n")
+ "ldr x12, [%[src_ptr2], #8]\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n")
+ "ldr x13, [%[src_ptr3], #8]\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+ RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n")
+
+ "ld1 {v0.2s}, [%[src_ptr0]], %[src_inc0]\n"
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "ld1 {v1.2s}, [%[src_ptr1]], %[src_inc1]\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "ld1 {v2.2s}, [%[src_ptr2]], %[src_inc2]\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "ld1 {v3.2s}, [%[src_ptr3]], %[src_inc3]\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+ "cmp w1, w2\n"
+
+ "ins v0.d[1], x10\n"
+ "str q20, [%[packed_ptr], #0]\n"
+ "ins v1.d[1], x11\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "ins v2.d[1], x12\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "ins v3.d[1], x13\n"
+ "str q23, [%[packed_ptr], #96]\n"
+
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "bne 1b\n"
+
+ "2:\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ "str q20, [%[packed_ptr], #0]\n"
+ "str q21, [%[packed_ptr], #32]\n"
+ "str q22, [%[packed_ptr], #64]\n"
+ "str q23, [%[packed_ptr], #96]\n"
+ "add %[packed_ptr], %[packed_ptr], #128\n"
+
+ "3:\n"
+
+ "ands w2, %w[rows], #3\n"
+ "beq 4f\n"
+ "dup v0.16b, wzr\n"
+ "dup v1.16b, wzr\n"
+ "dup v2.16b, wzr\n"
+ "dup v3.16b, wzr\n"
+#define RUY_LOAD_ONE_ROW(R) \
+ "cmp w2, #" #R "\n" \
+ "beq 5f\n" \
+ "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \
+ "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \
+ "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \
+ "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n"
+
+ RUY_LOAD_ONE_ROW(0)
+ RUY_LOAD_ONE_ROW(1)
+ RUY_LOAD_ONE_ROW(2)
+ RUY_LOAD_ONE_ROW(3)
+#undef RUY_LOAD_ONE_ROW
+ "5:\n"
+
+ "trn1 v16.4s, v0.4s, v1.4s\n"
+ "trn2 v17.4s, v0.4s, v1.4s\n"
+ "trn1 v18.4s, v2.4s, v3.4s\n"
+ "trn2 v19.4s, v2.4s, v3.4s\n"
+
+ "trn1 v20.2d, v16.2d, v18.2d\n"
+ "trn2 v22.2d, v16.2d, v18.2d\n"
+ "trn1 v21.2d, v17.2d, v19.2d\n"
+ "trn2 v23.2d, v17.2d, v19.2d\n"
+
+ "mov x1, #32\n"
+
+#define RUY_STORE_ONE_ROW(ROW, REGISTER) \
+ "cmp w2, #" #ROW "\n" \
+ "beq 4f\n" \
+ "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n"
+
+ RUY_STORE_ONE_ROW(0, v20)
+ RUY_STORE_ONE_ROW(1, v21)
+ RUY_STORE_ONE_ROW(2, v22)
+ RUY_STORE_ONE_ROW(3, v23)
+
+#undef RUY_STORE_ONE_ROW
+
+ "4:\n"
+
+ // clang-format on
+
+      : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
+        [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
+        [ packed_ptr ] "+r"(packed_ptr)
+      : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)),
+        [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
+        [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)),
+        [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
+        [ rows ] "r"(src_rows)
+ : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+ "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
+ "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
+}
+#endif // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+} // namespace ruy
diff --git a/ruy/pack_arm.h b/ruy/pack_arm.h
new file mode 100644
index 0000000..8e7f619
--- /dev/null
+++ b/ruy/pack_arm.h
@@ -0,0 +1,497 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// # What is "packing"?
+//
+// Before feeding data to the gemm kernels (the parts of Ruy that do lots
+// of multiply-add operations), Ruy first performs a data transformation (which
+// we call "packing") on the input matrices. This transformation has two main
+// goals:
+// - rearrange data into blocks that are a convenient size/layout for the gemm
+// kernels to consume. This helps make the memory access pattern of the gemm
+// kernel simpler and more contiguous, and puts the data in a layout most
+// convenient for specific arithmetic instructions in the gemm kernel.
+// - compute row/column sums needed for handling quantization with non-symmetric
+// zero points.
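+//
+// As a rough sketch of the second goal (the notation below is only
+// illustrative, not ruy's actual variable names): with zero points zl and zr,
+// each accumulator expands as
+//   sum_k (lhs[k,r] - zl) * (rhs[k,c] - zr)
+//     = sum_k lhs[k,r]*rhs[k,c] - zr * sum_k lhs[k,r] - zl * sum_k rhs[k,c]
+//       + K * zl * zr
+// so precomputing sums along the depth dimension during packing lets the
+// kernel apply the zero-point corrections cheaply.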
+//
+// # Simplified algorithmic analysis of packing
+//
+// Packing is a relatively simple transformation which does a small constant
+// amount of work on each element of an input matrix, and hence for an NxM
+// matrix performs O(N*M) work. If N and M are of the same order, then this is
+// O(N^2) work.
+//
+// An NxKxM matrix multiplication requires N*K*M multiply-accumulate operations.
+// Note that if N, K, and M are all the same order, then the number of
+// multiply-accumulate operations is O(N^3).
+//
+// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the
+// case of all dimensions being roughly the same order.
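+//
+// For concreteness (illustrative numbers only): with N = K = M = 1000, packing
+// both inputs touches about N*K + K*M = 2e6 elements, while the kernels
+// perform N*K*M = 1e9 multiply-accumulates, so packing is on the order of 0.2%
+// of the arithmetic work.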
+//
+// # Packing cost can be significant
+//
+// When matrix * matrix multiplications begin to look more like matrix * vector
+// multiplications, packing cost can become significant. We sometimes call these
+// cases "gemv-like".
+//
+// Continuing the algorithmic analysis above, if we consider a case where an
+// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the
+// situation is different. In this case, the multiply-accumulate work is only
+// quadratic, so the quadratic cost of packing can become significant.
+//
+// Another way to say this is that the cost of packing an input matrix (either
+// the LHS or RHS) is amortized across the non-depth dimension of the opposite
+// input matrix. Thus, when the LHS has very few rows or the RHS has very few
+// columns, the cost of packing the opposite input matrix can become
+// significant.
+//
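+// For example, with N = 8 and K = M = 1024, packing the RHS touches roughly
+// K*M = 2^20 elements while the kernels perform N*K*M = 2^23 multiply-adds,
+// so packing is on the order of one eighth of the arithmetic work rather than
+// a negligible fraction of it.
+//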
+// As a rough rule of thumb, the cost of packing starts to become significant
+// when either N or M is below 32 (and other dimensions are hundreds), with very
+// significant packing costs at 8 or below. This varies by data type, Path, and
+// tuning, so these numbers are only rough guides.
+//
+// One practical use case that is affected by this is inference of
+// fully connected neural network layers with a low batch size. The weight
+// matrix (which is a constant for inference) is the one affected by significant
+// packing cost.
+//
+// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack
+// input matrices that are affected by significant packing costs.
+//
+// # Implementation notes
+//
+// Ruy's packing routines always operate on a range of columns and can be
+// applied to either the LHS or RHS. This is possible because Ruy internally
+// implements a TrMul, so the accumulation along depth is done along columns of
+// both the LHS and RHS (whereas for a normal Mul the accumulation along depth
+// for the LHS is along rows). As another example, we are always computing
+// column sums for quantization (and never row sums, since the LHS is
+// transposed).
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_ARM_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_ARM_H_
+
+#include <cstdint>
+#include <type_traits>
+
+#include "ruy/check_macros.h"
+#include "ruy/common.h"
+#include "ruy/internal_matrix.h"
+#include "ruy/matrix.h"
+#include "ruy/opt_set.h"
+#include "ruy/pack_common.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+void Pack8bitNeonOutOfOrder(const void* src_ptr0, const void* src_ptr1,
+ const void* src_ptr2, const void* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows, int src_zero_point,
+ std::int8_t* packed_ptr, int start_col, int end_col,
+ std::int32_t* sums_ptr, int input_xor);
+void Pack8bitNeonInOrder(const void* src_ptr0, const void* src_ptr1,
+ const void* src_ptr2, const void* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2, int src_inc3,
+ int src_rows, int src_zero_point,
+ std::int8_t* packed_ptr, int start_col, int end_col,
+ std::int32_t* sums_ptr, int input_xor);
+void Pack8bitNeonDotprodOutOfOrder(const void* src_ptr0, const void* src_ptr1,
+ const void* src_ptr2, const void* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows,
+ int src_zero_point, std::int8_t* packed_ptr,
+ int start_col, int end_col,
+ std::int32_t* sums_ptr, int input_xor);
+void Pack8bitNeonDotprodInOrder(const void* src_ptr0, const void* src_ptr1,
+ const void* src_ptr2, const void* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows, int src_zero_point,
+ std::int8_t* packed_ptr, int start_col,
+ int end_col, std::int32_t* sums_ptr,
+ int input_xor);
+
+#elif RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+void Pack8bitNeonOutOfOrder4Cols(const PackParams8bit& params);
+void Pack8bitNeonOutOfOrder2Cols(const PackParams8bit& params);
+#endif // (RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+#if (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && \
+ RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+template <typename Scalar>
+struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kColMajor, 16, 4>, Scalar,
+ std::int8_t, std::int32_t> {
+ static_assert(std::is_same<Scalar, std::int8_t>::value ||
+ std::is_same<Scalar, std::uint8_t>::value,
+ "");
+ static constexpr int kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+
+ static void Run(Tuning tuning, const Matrix<Scalar>& src_matrix,
+ PackedMatrix<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ(start_col % 4, 0);
+ std::int32_t* sums = packed_matrix->sums;
+ Scalar zerobuf[16];
+ memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
+ for (int block_col = start_col; block_col < end_col; block_col += 4) {
+ int src_stride = src_matrix.layout.stride;
+ const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
+ const Scalar* src_ptr1 = src_ptr0 + src_stride;
+ const Scalar* src_ptr2 = src_ptr1 + src_stride;
+ const Scalar* src_ptr3 = src_ptr2 + src_stride;
+ int src_inc0 = 16;
+ int src_inc1 = 16;
+ int src_inc2 = 16;
+ int src_inc3 = 16;
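+      // When the last block has fewer than 4 real columns, point the missing
+      // columns at zerobuf (filled with the zero_point) and stop advancing
+      // them (src_inc = 0), so that 4 valid column pointers are always passed
+      // to the packing routine.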
+ if (block_col >= src_matrix.layout.cols - 3) {
+ if (block_col >= src_matrix.layout.cols - 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ }
+ std::int8_t* packed_ptr =
+ packed_matrix->data + packed_matrix->layout.stride * block_col;
+ std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
+#if RUY_PLATFORM(NEON_64)
+ if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
+ Pack8bitNeonInOrder(
+ src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
+ src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
+ packed_ptr, start_col, end_col, sums_ptr, kInputXor);
+ } else {
+ Pack8bitNeonOutOfOrder(
+ src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
+ src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
+ packed_ptr, start_col, end_col, sums_ptr, kInputXor);
+ }
+#else
+ // We have a more limited set of general purpose registers in ARMv7, so
+ // we use the "params" struct technique from the kernel code to save
+ // registers.
+ PackParams8bit params;
+ MakePackParams8bit(src_ptr0, src_ptr1, src_ptr2, src_ptr3, sums_ptr,
+ packed_ptr, src_inc0, src_inc1, src_inc2, src_inc3,
+ src_matrix.layout.rows, src_matrix.zero_point,
+ kInputXor, &params);
+ Pack8bitNeonOutOfOrder4Cols(params);
+#endif // RUY_PLATFORM(NEON_64)
+ }
+ }
+};
+
+#endif // (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) &&
+ // RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+// The 8-bit kernel on 32-bit ARM is 4 rows X 2 columns, so we need an
+// additional partial specialization for the RHS, which has a FixedKernelLayout
+// with 2 columns.
+template <typename Scalar>
+struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kColMajor, 16, 2>, Scalar,
+ std::int8_t, std::int32_t> {
+ static_assert(std::is_same<Scalar, std::int8_t>::value ||
+ std::is_same<Scalar, std::uint8_t>::value,
+ "");
+ static constexpr int kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+ static void Run(Tuning tuning, const Matrix<Scalar>& src_matrix,
+ PackedMatrix<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ(start_col % 2, 0);
+ std::int32_t* sums = packed_matrix->sums;
+ Scalar zerobuf[16];
+ memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
+ for (int block_col = start_col; block_col < end_col; block_col += 2) {
+ int src_stride = src_matrix.layout.stride;
+ const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
+ const Scalar* src_ptr1 = src_ptr0 + src_stride;
+ int src_inc0 = 16;
+ int src_inc1 = 16;
+ if (block_col >= src_matrix.layout.cols - 2) {
+ if (block_col >= src_matrix.layout.cols - 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ }
+ std::int8_t* packed_ptr =
+ packed_matrix->data + packed_matrix->layout.stride * block_col;
+ std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
+ PackParams8bit params;
+ MakePackParams8bit(src_ptr0, src_ptr1, nullptr, nullptr, sums_ptr,
+ packed_ptr, src_inc0, src_inc1, -1, -1,
+ src_matrix.layout.rows, src_matrix.zero_point,
+ kInputXor, &params);
+ Pack8bitNeonOutOfOrder2Cols(params);
+ }
+ }
+};
+#endif // (RUY_PLATFORM(NEON_32)) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+template <typename Scalar>
+struct PackImpl<Path::kNeonDotprod, FixedKernelLayout<Order::kColMajor, 4, 8>,
+ Scalar, std::int8_t, std::int32_t> {
+ static_assert(std::is_same<Scalar, std::int8_t>::value ||
+ std::is_same<Scalar, std::uint8_t>::value,
+ "");
+ static constexpr int kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+
+ static void Run(Tuning tuning, const Matrix<Scalar>& src_matrix,
+ PackedMatrix<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ(start_col % 8, 0);
+ std::int32_t* sums = packed_matrix->sums;
+ Scalar zerobuf[16];
+ memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
+ for (int block_col = start_col; block_col < end_col; block_col += 4) {
+ int src_stride = src_matrix.layout.stride;
+ const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
+ const Scalar* src_ptr1 = src_ptr0 + src_stride;
+ const Scalar* src_ptr2 = src_ptr1 + src_stride;
+ const Scalar* src_ptr3 = src_ptr2 + src_stride;
+ std::int64_t src_inc0 = 16;
+ std::int64_t src_inc1 = 16;
+ std::int64_t src_inc2 = 16;
+ std::int64_t src_inc3 = 16;
+ if (block_col >= src_matrix.layout.cols - 3) {
+ if (block_col >= src_matrix.layout.cols - 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ }
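+      // The packed block is 8 columns wide while we pack 4 source columns per
+      // iteration: (block_col & ~7) selects the 8-column block, and
+      // ((block_col & 4) * 4) appears to be the 16-byte offset at which the
+      // second group of 4 columns is interleaved within that block.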
+ std::int8_t* packed_ptr =
+ packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & ~7) +
+ ((block_col & 4) * 4);
+ std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
+ if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
+ Pack8bitNeonDotprodInOrder(
+ src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
+ src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
+ packed_ptr, start_col, end_col, sums_ptr, kInputXor);
+ } else {
+ Pack8bitNeonDotprodOutOfOrder(
+ src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
+ src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
+ packed_ptr, start_col, end_col, sums_ptr, kInputXor);
+ }
+ }
+ }
+};
+#endif // (RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1,
+ const float* src_ptr2, const float* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows, int src_zero_point,
+ float* packed_ptr, int start_col, int end_col);
+void PackFloatNeonInOrder(const float* src_ptr0, const float* src_ptr1,
+ const float* src_ptr2, const float* src_ptr3,
+ int src_inc0, int src_inc1, int src_inc2,
+ int src_inc3, int src_rows, int src_zero_point,
+ float* packed_ptr, int start_col, int end_col);
+
+#elif RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1,
+ const float* src_ptr2, const float* src_ptr3,
+ int src_inc, int src_rows, int src_zero_point,
+ float* packed_ptr, int start_col, int end_col,
+ int stride);
+#endif // (RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+#if (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && \
+ RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+template <>
+struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
+ float, float> {
+ static void Run(Tuning tuning, const Matrix<float>& src_matrix,
+ PackedMatrix<float>* packed_matrix, int start_col,
+ int end_col) {
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ(start_col % 8, 0);
+ const float zerobuf[4] = {0};
+ for (int block_col = start_col; block_col < end_col; block_col += 4) {
+ int src_stride = src_matrix.layout.stride;
+ const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
+ const float* src_ptr1 = src_ptr0 + src_stride;
+ const float* src_ptr2 = src_ptr1 + src_stride;
+ const float* src_ptr3 = src_ptr2 + src_stride;
+ std::int64_t src_inc0 = 16;
+ std::int64_t src_inc1 = 16;
+ std::int64_t src_inc2 = 16;
+ std::int64_t src_inc3 = 16;
+ if (block_col >= src_matrix.layout.cols - 3) {
+ if (block_col >= src_matrix.layout.cols - 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ }
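+      // As in the 8-bit dotprod case above, the packed block spans 8 columns
+      // while 4 source columns are packed per iteration; (block_col & 4) is
+      // the float offset of the second group of 4 columns within the block.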
+ float* packed_ptr = packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & ~7) +
+ ((block_col & 4));
+#if RUY_PLATFORM(NEON_64)
+ if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
+ PackFloatNeonInOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0,
+ src_inc1, src_inc2, src_inc3,
+ src_matrix.layout.rows, src_matrix.zero_point,
+ packed_ptr, start_col, end_col);
+ } else {
+ PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3,
+ src_inc0, src_inc1, src_inc2, src_inc3,
+ src_matrix.layout.rows, src_matrix.zero_point,
+ packed_ptr, start_col, end_col);
+ }
+#else
+ // Encode each of src_inc0, ..., src_inc3 in lowest 4 bits of src_inc
+ // to save on registers (we have fewer general purpose registers in
+ // 32-bit ARM than in 64-bit ARM). For the 64-bit case, we pass four
+ // values that are each either 16 or 0 and use them directly. For the
+ // 32-bit case, bits 0, 1, 2, and 3 are used to determine if we should
+ // use the value 16 (bit is set) or 0 (bit is not set) for the
+ // respective increment value.
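+      // For example, if only the first two of the four columns remain,
+      // src_inc0 == src_inc1 == 16 and src_inc2 == src_inc3 == 0, so
+      // src_inc == 0b0011 == 3.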
+ std::int64_t src_inc = 0;
+ src_inc += src_inc0 == 16 ? 1 : 0;
+ src_inc += src_inc1 == 16 ? 2 : 0;
+ src_inc += src_inc2 == 16 ? 4 : 0;
+ src_inc += src_inc3 == 16 ? 8 : 0;
+ const int kOutputStride = 32;
+ PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc,
+ src_matrix.layout.rows, src_matrix.zero_point,
+ packed_ptr, start_col, end_col, kOutputStride);
+#endif // RUY_PLATFORM(NEON_64)
+ }
+ }
+};
+
+#if RUY_PLATFORM(NEON_32)
+// The 32-bit float kernel is 8 rows X 4 columns, so we need an additional
+// specialization for a FixedKernelLayout with 4 columns.
+template <>
+struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 4>, float,
+ float, float> {
+ static void Run(Tuning tuning, const Matrix<float>& src_matrix,
+ PackedMatrix<float>* packed_matrix, int start_col,
+ int end_col) {
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ(start_col % 4, 0);
+ const float zerobuf[4] = {0};
+ for (int block_col = start_col; block_col < end_col; block_col += 4) {
+ int src_stride = src_matrix.layout.stride;
+ const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
+ const float* src_ptr1 = src_ptr0 + src_stride;
+ const float* src_ptr2 = src_ptr1 + src_stride;
+ const float* src_ptr3 = src_ptr2 + src_stride;
+ std::int64_t src_inc0 = 16;
+ std::int64_t src_inc1 = 16;
+ std::int64_t src_inc2 = 16;
+ std::int64_t src_inc3 = 16;
+ if (block_col >= src_matrix.layout.cols - 3) {
+ if (block_col >= src_matrix.layout.cols - 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (block_col >= src_matrix.layout.cols - 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ }
+ float* packed_ptr =
+ packed_matrix->data + packed_matrix->layout.stride * (block_col);
+      // Encode each of src_inc0, ..., src_inc3 in the lowest 4 bits of
+      // src_inc to save registers.
+ std::int64_t src_inc = 0;
+ src_inc += src_inc0 == 16 ? 1 : 0;
+ src_inc += src_inc1 == 16 ? 2 : 0;
+ src_inc += src_inc2 == 16 ? 4 : 0;
+ src_inc += src_inc3 == 16 ? 8 : 0;
+ const int kOutputStride = 16;
+ PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc,
+ src_matrix.layout.rows, src_matrix.zero_point,
+ packed_ptr, start_col, end_col, kOutputStride);
+ }
+ }
+};
+#endif // (RUY_PLATFORM(NEON_32))
+#endif // (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \
+ // RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_ARM_H_
diff --git a/ruy/pack_avx2.cc b/ruy/pack_avx2.cc
new file mode 100644
index 0000000..013a8c0
--- /dev/null
+++ b/ruy/pack_avx2.cc
@@ -0,0 +1,816 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <cstring>
+
+#include "ruy/check_macros.h"
+#include "ruy/matrix.h"
+#include "ruy/opt_set.h"
+#include "ruy/pack.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS)
+#include <immintrin.h> // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows, std::int8_t* packed_ptr,
+ std::int32_t* sums_ptr) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows, float* packed_ptr) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+// The first int8_t template parameter is arbitrary: this routine is common to
+// all 8-bit source matrix types.
+using PackImpl8bitAvx2 =
+ PackImpl<Path::kAvx2, FixedKernelLayout<Order::kColMajor, 4, 8>,
+ std::int8_t, std::int8_t, std::int32_t>;
+
+using PackImplFloatAvx2 =
+ PackImpl<Path::kAvx2, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
+ float, float>;
+
+namespace {
+
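+// Loads up to 31 bytes starting at addr, padding the rest of the 256-bit
+// result with zero_point: when at least 16 rows remain, the low 128 bits come
+// from a full load and the high 128 bits are partially filled via memcpy;
+// otherwise only the low 128 bits are partially filled.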
+inline __m256i MaskLoadu(int available_src_rows, std::int8_t zero_point,
+ const std::int8_t* addr) {
+ RUY_DCHECK_LT(available_src_rows, 32);
+ __m256i padded_data;
+
+ if (available_src_rows >= 16) {
+ __m128i load_hi = _mm_set1_epi8(zero_point);
+ __m128i load_lo = _mm_loadu_si128(reinterpret_cast<const __m128i*>(addr));
+ memcpy(&load_hi, addr + 16, available_src_rows - 16);
+ padded_data = _mm256_set_m128i(load_hi, load_lo);
+ } else {
+ __m128i load_hi = _mm_set1_epi8(zero_point);
+ __m128i load_lo = load_hi;
+ memcpy(&load_lo, addr, available_src_rows);
+ padded_data = _mm256_set_m128i(load_hi, load_lo);
+ }
+ return padded_data;
+}
+
+inline void Pack8bitAvx2Packer(const std::int8_t* src_ptr,
+ std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr,
+ std::int8_t* trailing_buf) {
+ using Layout = PackImpl8bitAvx2::Layout;
+ RUY_DCHECK_EQ(Layout::kCols, 8);
+ RUY_DCHECK_EQ(Layout::kRows, 4);
+  // Each chunk of Layout::kRows (4) elements is contiguous in the input and
+  // stays contiguous in the packed output. We process 8 of these chunks at a
+  // time, padding short input chunks.
+ constexpr int kNumRowChunks = 8;
+ constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
+
+ const std::int8_t* src_ptr0 = src_ptr;
+ const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
+ const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
+ const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
+ const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
+ const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
+ const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
+ const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
+ std::int64_t src_inc0 = kNumChunkedSrcRows;
+ std::int64_t src_inc1 = kNumChunkedSrcRows;
+ std::int64_t src_inc2 = kNumChunkedSrcRows;
+ std::int64_t src_inc3 = kNumChunkedSrcRows;
+ std::int64_t src_inc4 = kNumChunkedSrcRows;
+ std::int64_t src_inc5 = kNumChunkedSrcRows;
+ std::int64_t src_inc6 = kNumChunkedSrcRows;
+ std::int64_t src_inc7 = kNumChunkedSrcRows;
+ // Handle cases where source does not have Layout::kCols (8) columns.
+ if (remaining_src_cols < 8) {
+ if (remaining_src_cols <= 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (remaining_src_cols <= 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (remaining_src_cols <= 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (remaining_src_cols <= 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ if (remaining_src_cols <= 4) {
+ src_ptr4 = zerobuf;
+ src_inc4 = 0;
+ }
+ if (remaining_src_cols <= 5) {
+ src_ptr5 = zerobuf;
+ src_inc5 = 0;
+ }
+ if (remaining_src_cols <= 6) {
+ src_ptr6 = zerobuf;
+ src_inc6 = 0;
+ }
+ src_ptr7 = zerobuf;
+ src_inc7 = 0;
+ }
+
+ const std::int8_t zero_point = zerobuf[0];
+
+ if (sums_ptr) {
+ // i: Layout::kCols.
+ for (int i = 0; i < 8; ++i) {
+ sums_ptr[i] = 0;
+ }
+ }
+ std::int32_t sums_adjustment = 0;
+ const __m256i ones_16bit = _mm256_set1_epi16(1);
+ __m256i sums_4x2_32bit_lo = _mm256_set1_epi32(0);
+ __m256i sums_4x2_32bit_hi = _mm256_set1_epi32(0);
+
+  // The overall packing effectively pads the source rows to
+  // (src_rows + 31) & ~31. When there is an incomplete destination block, it
+  // is stored into trailing_buf instead of packed_ptr.
+  for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) {
+    // Available source rows in this chunk (always positive here).
+    const int available_src_rows = src_rows - k;
+    // Effectively, the number of rows packed in this iteration is
+    //   std::min(kNumChunkedSrcRows, available_src_rows);
+    // but we treat the full-chunk and partial-chunk cases separately.
+ if (available_src_rows >= kNumChunkedSrcRows) {
+ if (sums_ptr) {
+ __m256i t0, t1, t2, t3, t4, t5, t6, t7;
+ __m256i r0, r1, r2, r3, r4, r5, r6, r7;
+ const __m256i input_xor_v = _mm256_set1_epi8(input_xor);
+
+ t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0));
+ t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4));
+ t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1));
+ t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5));
+ t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2));
+ t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6));
+ t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3));
+ t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7));
+
+ r0 = _mm256_unpacklo_epi32(t0, t1);
+ r4 = _mm256_unpacklo_epi32(t4, t5);
+ r2 = _mm256_unpackhi_epi32(t0, t1);
+ r6 = _mm256_unpackhi_epi32(t4, t5);
+ r1 = _mm256_unpacklo_epi32(t2, t3);
+ r5 = _mm256_unpacklo_epi32(t6, t7);
+ r3 = _mm256_unpackhi_epi32(t2, t3);
+ r7 = _mm256_unpackhi_epi32(t6, t7);
+
+ t0 = _mm256_unpacklo_epi64(r0, r1);
+ t4 = _mm256_unpacklo_epi64(r4, r5);
+ t2 = _mm256_unpackhi_epi64(r0, r1);
+ t6 = _mm256_unpackhi_epi64(r4, r5);
+ t1 = _mm256_unpacklo_epi64(r2, r3);
+ t5 = _mm256_unpacklo_epi64(r6, r7);
+ t3 = _mm256_unpackhi_epi64(r2, r3);
+ t7 = _mm256_unpackhi_epi64(r6, r7);
+
+ // The preceding sets of rearrangement operations interleaved by 4 bytes
+        // and then by 8 bytes *within* lanes. The following set interleaves by
+ // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0,
+ // t4) are interleaved to create (r0, r1). This complexity follows from
+ // the way that AVX is centered around MM 128-bit lanes.
+ r0 = _mm256_permute2x128_si256(t0, t4, 0x20);
+ r4 = _mm256_permute2x128_si256(t1, t5, 0x20);
+ r1 = _mm256_permute2x128_si256(t0, t4, 0x31);
+ r5 = _mm256_permute2x128_si256(t1, t5, 0x31);
+ r2 = _mm256_permute2x128_si256(t2, t6, 0x20);
+ r6 = _mm256_permute2x128_si256(t3, t7, 0x20);
+ r3 = _mm256_permute2x128_si256(t2, t6, 0x31);
+ r7 = _mm256_permute2x128_si256(t3, t7, 0x31);
+
+ r0 = _mm256_xor_si256(r0, input_xor_v);
+ r1 = _mm256_xor_si256(r1, input_xor_v);
+ r2 = _mm256_xor_si256(r2, input_xor_v);
+ r3 = _mm256_xor_si256(r3, input_xor_v);
+ r4 = _mm256_xor_si256(r4, input_xor_v);
+ r5 = _mm256_xor_si256(r5, input_xor_v);
+ r6 = _mm256_xor_si256(r6, input_xor_v);
+ r7 = _mm256_xor_si256(r7, input_xor_v);
+
+ __m256i sums_4x4_16bit_lo;
+ sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0));
+ sums_4x4_16bit_lo =
+ _mm256_add_epi16(sums_4x4_16bit_lo,
+ _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1)));
+ sums_4x4_16bit_lo =
+ _mm256_add_epi16(sums_4x4_16bit_lo,
+ _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2)));
+ sums_4x4_16bit_lo =
+ _mm256_add_epi16(sums_4x4_16bit_lo,
+ _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3)));
+ sums_4x4_16bit_lo =
+ _mm256_add_epi16(sums_4x4_16bit_lo,
+ _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4)));
+ sums_4x4_16bit_lo =
+ _mm256_add_epi16(sums_4x4_16bit_lo,
+ _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5)));
+ sums_4x4_16bit_lo =
+ _mm256_add_epi16(sums_4x4_16bit_lo,
+ _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6)));
+ sums_4x4_16bit_lo =
+ _mm256_add_epi16(sums_4x4_16bit_lo,
+ _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7)));
+
+ // The sums have been performed across columns, and now we have 4x16-bit
+ // sums packed together. We use madd for pairwise 32-bit sums.
+ const __m256i sums_4x2_32bit_lo_new =
+ _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit);
+ sums_4x2_32bit_lo =
+ _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new);
+
+ __m256i sums_4x4_16bit_hi;
+ sums_4x4_16bit_hi =
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1)));
+
+ const __m256i sums_4x2_32bit_hi_new =
+ _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit);
+ sums_4x2_32bit_hi =
+ _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new);
+
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4),
+ r0);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4),
+ r4);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4),
+ r1);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4),
+ r5);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4),
+ r2);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4),
+ r6);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4),
+ r3);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4),
+ r7);
+ } else {
+ __m256i t0, t1, t2, t3, t4, t5, t6, t7;
+ __m256i r0, r1, r2, r3, r4, r5, r6, r7;
+ const __m256i input_xor_v = _mm256_set1_epi8(input_xor);
+
+ t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0));
+ t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4));
+ t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1));
+ t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5));
+ t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2));
+ t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6));
+ t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3));
+ t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7));
+
+ r0 = _mm256_unpacklo_epi32(t0, t1);
+ r4 = _mm256_unpacklo_epi32(t4, t5);
+ r2 = _mm256_unpackhi_epi32(t0, t1);
+ r6 = _mm256_unpackhi_epi32(t4, t5);
+ r1 = _mm256_unpacklo_epi32(t2, t3);
+ r5 = _mm256_unpacklo_epi32(t6, t7);
+ r3 = _mm256_unpackhi_epi32(t2, t3);
+ r7 = _mm256_unpackhi_epi32(t6, t7);
+
+ t0 = _mm256_unpacklo_epi64(r0, r1);
+ t4 = _mm256_unpacklo_epi64(r4, r5);
+ t2 = _mm256_unpackhi_epi64(r0, r1);
+ t6 = _mm256_unpackhi_epi64(r4, r5);
+ t1 = _mm256_unpacklo_epi64(r2, r3);
+ t5 = _mm256_unpacklo_epi64(r6, r7);
+ t3 = _mm256_unpackhi_epi64(r2, r3);
+ t7 = _mm256_unpackhi_epi64(r6, r7);
+
+ // The preceding sets of rearrangement operations interleaved by 4 bytes
+        // and then by 8 bytes *within* lanes. The following set interleaves by
+ // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0,
+ // t4) are interleaved to create (r0, r1). This complexity follows from
+ // the way that AVX is centered around MM 128-bit lanes.
+ r0 = _mm256_permute2x128_si256(t0, t4, 0x20);
+ r4 = _mm256_permute2x128_si256(t1, t5, 0x20);
+ r1 = _mm256_permute2x128_si256(t0, t4, 0x31);
+ r5 = _mm256_permute2x128_si256(t1, t5, 0x31);
+ r2 = _mm256_permute2x128_si256(t2, t6, 0x20);
+ r6 = _mm256_permute2x128_si256(t3, t7, 0x20);
+ r3 = _mm256_permute2x128_si256(t2, t6, 0x31);
+ r7 = _mm256_permute2x128_si256(t3, t7, 0x31);
+
+ r0 = _mm256_xor_si256(r0, input_xor_v);
+ r1 = _mm256_xor_si256(r1, input_xor_v);
+ r2 = _mm256_xor_si256(r2, input_xor_v);
+ r3 = _mm256_xor_si256(r3, input_xor_v);
+ r4 = _mm256_xor_si256(r4, input_xor_v);
+ r5 = _mm256_xor_si256(r5, input_xor_v);
+ r6 = _mm256_xor_si256(r6, input_xor_v);
+ r7 = _mm256_xor_si256(r7, input_xor_v);
+
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4),
+ r0);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4),
+ r4);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4),
+ r1);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4),
+ r5);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4),
+ r2);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4),
+ r6);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4),
+ r3);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4),
+ r7);
+ }
+ } else if (available_src_rows > 0) {
+ RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows);
+ // We do not care what goes into the trailing buffer, but we want
+ // in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
+ //
+ // We compensate for padding-with-zero_point by initializing the
+ // summations with the compensating offset, effectively
+ // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) *
+ // 4 * (8 - ((available_src_rows + 3) >> 2)).
+ //
+ // Note that (zero_point ^ input_xor) is performed in 8-bits and then
+ // cast.
+ sums_adjustment +=
+ -(zero_point ^ input_xor) * 4 * (8 - ((available_src_rows + 3) >> 2));
+
+ __m256i t0, t1, t2, t3, t4, t5, t6, t7;
+ __m256i r0, r1, r2, r3, r4, r5, r6, r7;
+ const __m256i input_xor_v = _mm256_set1_epi8(input_xor);
+
+ t0 = MaskLoadu(available_src_rows, zero_point, src_ptr0);
+ t4 = MaskLoadu(available_src_rows, zero_point, src_ptr4);
+ t1 = MaskLoadu(available_src_rows, zero_point, src_ptr1);
+ t5 = MaskLoadu(available_src_rows, zero_point, src_ptr5);
+ t2 = MaskLoadu(available_src_rows, zero_point, src_ptr2);
+ t6 = MaskLoadu(available_src_rows, zero_point, src_ptr6);
+ t3 = MaskLoadu(available_src_rows, zero_point, src_ptr3);
+ t7 = MaskLoadu(available_src_rows, zero_point, src_ptr7);
+
+ r0 = _mm256_unpacklo_epi32(t0, t1);
+ r4 = _mm256_unpacklo_epi32(t4, t5);
+ r2 = _mm256_unpackhi_epi32(t0, t1);
+ r6 = _mm256_unpackhi_epi32(t4, t5);
+ r1 = _mm256_unpacklo_epi32(t2, t3);
+ r5 = _mm256_unpacklo_epi32(t6, t7);
+ r3 = _mm256_unpackhi_epi32(t2, t3);
+ r7 = _mm256_unpackhi_epi32(t6, t7);
+
+ t0 = _mm256_unpacklo_epi64(r0, r1);
+ t4 = _mm256_unpacklo_epi64(r4, r5);
+ t2 = _mm256_unpackhi_epi64(r0, r1);
+ t6 = _mm256_unpackhi_epi64(r4, r5);
+ t1 = _mm256_unpacklo_epi64(r2, r3);
+ t5 = _mm256_unpacklo_epi64(r6, r7);
+ t3 = _mm256_unpackhi_epi64(r2, r3);
+ t7 = _mm256_unpackhi_epi64(r6, r7);
+
+ // The preceding sets of rearrangement operations interleaved by 4 bytes
+      // and then by 8 bytes *within* lanes. The following set interleaves by
+ // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0,
+ // t4) are interleaved to create (r0, r1). This complexity follows from
+ // the way that AVX is centered around MM 128-bit lanes.
+ r0 = _mm256_permute2x128_si256(t0, t4, 0x20);
+ r4 = _mm256_permute2x128_si256(t1, t5, 0x20);
+ r1 = _mm256_permute2x128_si256(t0, t4, 0x31);
+ r5 = _mm256_permute2x128_si256(t1, t5, 0x31);
+ r2 = _mm256_permute2x128_si256(t2, t6, 0x20);
+ r6 = _mm256_permute2x128_si256(t3, t7, 0x20);
+ r3 = _mm256_permute2x128_si256(t2, t6, 0x31);
+ r7 = _mm256_permute2x128_si256(t3, t7, 0x31);
+
+ r0 = _mm256_xor_si256(r0, input_xor_v);
+ r1 = _mm256_xor_si256(r1, input_xor_v);
+ r2 = _mm256_xor_si256(r2, input_xor_v);
+ r3 = _mm256_xor_si256(r3, input_xor_v);
+ r4 = _mm256_xor_si256(r4, input_xor_v);
+ r5 = _mm256_xor_si256(r5, input_xor_v);
+ r6 = _mm256_xor_si256(r6, input_xor_v);
+ r7 = _mm256_xor_si256(r7, input_xor_v);
+
+ __m256i sums_4x4_16bit_lo;
+ sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0));
+ sums_4x4_16bit_lo = _mm256_add_epi16(
+ sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1)));
+ sums_4x4_16bit_lo = _mm256_add_epi16(
+ sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2)));
+ sums_4x4_16bit_lo = _mm256_add_epi16(
+ sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3)));
+ sums_4x4_16bit_lo = _mm256_add_epi16(
+ sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4)));
+ sums_4x4_16bit_lo = _mm256_add_epi16(
+ sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5)));
+ sums_4x4_16bit_lo = _mm256_add_epi16(
+ sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6)));
+ sums_4x4_16bit_lo = _mm256_add_epi16(
+ sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7)));
+
+ // The sums have been performed across columns, and now we have 4x16-bit
+ // sums packed together. We use madd for pairwise 32-bit sums.
+ const __m256i sums_4x2_32bit_lo_new =
+ _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit);
+ sums_4x2_32bit_lo =
+ _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new);
+
+ __m256i sums_4x4_16bit_hi;
+ sums_4x4_16bit_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1)));
+ sums_4x4_16bit_hi = _mm256_add_epi16(
+ sums_4x4_16bit_hi,
+ _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1)));
+
+ const __m256i sums_4x2_32bit_hi_new =
+ _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit);
+ sums_4x2_32bit_hi =
+ _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new);
+
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 0 * 8 * 4),
+ r0);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 2 * 8 * 4),
+ r4);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 4 * 8 * 4),
+ r1);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 6 * 8 * 4),
+ r5);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 1 * 8 * 4),
+ r2);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 3 * 8 * 4),
+ r6);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 5 * 8 * 4),
+ r3);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 7 * 8 * 4),
+ r7);
+ }
+
+ packed_ptr += 8 * kNumChunkedSrcRows;
+ src_ptr0 += src_inc0;
+ src_ptr1 += src_inc1;
+ src_ptr2 += src_inc2;
+ src_ptr3 += src_inc3;
+ src_ptr4 += src_inc4;
+ src_ptr5 += src_inc5;
+ src_ptr6 += src_inc6;
+ src_ptr7 += src_inc7;
+ }
+
+ if (sums_ptr) {
+ const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);
+
+ __m256i sums =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));
+ const __m256i idx = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
+
+    // We earlier used madd for pairwise 32-bit sums; now we de-interleave the
+    // neighbours, finishing up by adding them to the stored accumulated sums.
+ const __m256i sums_2x4_32bit_lo =
+ _mm256_permutevar8x32_epi32(sums_4x2_32bit_lo, idx);
+ const __m256i sums_2x4_32bit_hi =
+ _mm256_permutevar8x32_epi32(sums_4x2_32bit_hi, idx);
+ const __m256i sums_2x4_32bit_a =
+ _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x20);
+ const __m256i sums_2x4_32bit_b =
+ _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x31);
+ sums = _mm256_add_epi32(sums, sums_adjustment_v);
+ sums = _mm256_add_epi32(sums, sums_2x4_32bit_a);
+ sums = _mm256_add_epi32(sums, sums_2x4_32bit_b);
+
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums);
+ }
+}
+
+inline __m256 Mm256UnpackloPsx2(const __m256 a, const __m256 b) {
+ return _mm256_castpd_ps(
+ _mm256_unpacklo_pd(_mm256_castps_pd(a), _mm256_castps_pd(b)));
+}
+
+inline __m256 Mm256UnpackhiPsx2(const __m256 a, const __m256 b) {
+ return _mm256_castpd_ps(
+ _mm256_unpackhi_pd(_mm256_castps_pd(a), _mm256_castps_pd(b)));
+}
+
+inline void PackFloatAvx2Packer(const float* src_ptr, const float* zerobuf,
+ int src_stride, int remaining_src_cols,
+ int src_rows, float* packed_ptr,
+ float* trailing_buf) {
+ RUY_DCHECK_EQ(PackImplFloatAvx2::Layout::kCols, 8);
+ RUY_DCHECK_EQ(PackImplFloatAvx2::Layout::kRows, 1);
+
+ // This packing amounts to transposition of 8x8 blocks.
+ static constexpr int kPackCols = 8; // Source cols packed together.
+ static constexpr int kPackRows = 8; // Short input is padded.
+
+ const float* src_ptr0 = src_ptr;
+ const float* src_ptr1 = src_ptr0 + src_stride;
+ const float* src_ptr2 = src_ptr1 + src_stride;
+ const float* src_ptr3 = src_ptr2 + src_stride;
+ const float* src_ptr4 = src_ptr3 + src_stride;
+ const float* src_ptr5 = src_ptr4 + src_stride;
+ const float* src_ptr6 = src_ptr5 + src_stride;
+ const float* src_ptr7 = src_ptr6 + src_stride;
+ std::int64_t src_inc0 = 8;
+ std::int64_t src_inc1 = 8;
+ std::int64_t src_inc2 = 8;
+ std::int64_t src_inc3 = 8;
+ std::int64_t src_inc4 = 8;
+ std::int64_t src_inc5 = 8;
+ std::int64_t src_inc6 = 8;
+ std::int64_t src_inc7 = 8;
+  // Handle cases where source does not have kPackCols (8) columns.
+ if (remaining_src_cols < kPackCols) {
+ if (remaining_src_cols <= 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (remaining_src_cols <= 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (remaining_src_cols <= 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (remaining_src_cols <= 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ if (remaining_src_cols <= 4) {
+ src_ptr4 = zerobuf;
+ src_inc4 = 0;
+ }
+ if (remaining_src_cols <= 5) {
+ src_ptr5 = zerobuf;
+ src_inc5 = 0;
+ }
+ if (remaining_src_cols <= 6) {
+ src_ptr6 = zerobuf;
+ src_inc6 = 0;
+ }
+ src_ptr7 = zerobuf;
+ src_inc7 = 0;
+ }
+
+ for (int k = 0; k < src_rows; k += kPackRows) {
+ const int available_src_rows = src_rows - k;
+ // Effectively,
+    // available_src_rows = std::max(0, std::min(kPackRows, src_rows - k));
+ // but treat each case separately.
+ if (available_src_rows >= kPackRows) {
+ __m256 t0, t1, t2, t3, t4, t5, t6, t7;
+ __m256 r0, r1, r2, r3, r4, r5, r6, r7;
+
+ t0 = _mm256_loadu_ps(src_ptr0);
+ t4 = _mm256_loadu_ps(src_ptr4);
+ t1 = _mm256_loadu_ps(src_ptr1);
+ t5 = _mm256_loadu_ps(src_ptr5);
+ t2 = _mm256_loadu_ps(src_ptr2);
+ t6 = _mm256_loadu_ps(src_ptr6);
+ t3 = _mm256_loadu_ps(src_ptr3);
+ t7 = _mm256_loadu_ps(src_ptr7);
+
+ r0 = _mm256_unpacklo_ps(t0, t1);
+ r4 = _mm256_unpacklo_ps(t4, t5);
+ r2 = _mm256_unpackhi_ps(t0, t1);
+ r6 = _mm256_unpackhi_ps(t4, t5);
+ r1 = _mm256_unpacklo_ps(t2, t3);
+ r5 = _mm256_unpacklo_ps(t6, t7);
+ r3 = _mm256_unpackhi_ps(t2, t3);
+ r7 = _mm256_unpackhi_ps(t6, t7);
+
+ t0 = Mm256UnpackloPsx2(r0, r1);
+ t4 = Mm256UnpackloPsx2(r4, r5);
+ t2 = Mm256UnpackhiPsx2(r0, r1);
+ t6 = Mm256UnpackhiPsx2(r4, r5);
+ t1 = Mm256UnpackloPsx2(r2, r3);
+ t5 = Mm256UnpackloPsx2(r6, r7);
+ t3 = Mm256UnpackhiPsx2(r2, r3);
+ t7 = Mm256UnpackhiPsx2(r6, r7);
+
+ // The preceding sets of rearrangement operations interleaved by 4 bytes
+      // and then by 8 bytes *within* lanes. The following set interleaves by 16
+ // bytes (128-bit), operating *between* AVX lanes. For instance (t0, t4)
+ // are interleaved to create (r0, r1). This complexity follows from the
+ // way that AVX is centered around MM 128-bit lanes.
+ r0 = _mm256_permute2f128_ps(t0, t4, 0x20);
+ r4 = _mm256_permute2f128_ps(t1, t5, 0x20);
+ r1 = _mm256_permute2f128_ps(t0, t4, 0x31);
+ r5 = _mm256_permute2f128_ps(t1, t5, 0x31);
+ r2 = _mm256_permute2f128_ps(t2, t6, 0x20);
+ r6 = _mm256_permute2f128_ps(t3, t7, 0x20);
+ r3 = _mm256_permute2f128_ps(t2, t6, 0x31);
+ r7 = _mm256_permute2f128_ps(t3, t7, 0x31);
+
+ _mm256_storeu_ps(packed_ptr + 0 * 8, r0);
+ _mm256_storeu_ps(packed_ptr + 2 * 8, r4);
+ _mm256_storeu_ps(packed_ptr + 4 * 8, r1);
+ _mm256_storeu_ps(packed_ptr + 6 * 8, r5);
+ _mm256_storeu_ps(packed_ptr + 1 * 8, r2);
+ _mm256_storeu_ps(packed_ptr + 3 * 8, r6);
+ _mm256_storeu_ps(packed_ptr + 5 * 8, r3);
+ _mm256_storeu_ps(packed_ptr + 7 * 8, r7);
+ } else if (available_src_rows > 0) {
+ const __m256i series = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
+ const __m256i row_mask_v =
+ _mm256_cmpgt_epi32(_mm256_set1_epi32(available_src_rows), series);
+
+ __m256 t0, t1, t2, t3, t4, t5, t6, t7;
+ __m256 r0, r1, r2, r3, r4, r5, r6, r7;
+
+ t0 = _mm256_maskload_ps(src_ptr0, row_mask_v);
+ t4 = _mm256_maskload_ps(src_ptr4, row_mask_v);
+ t1 = _mm256_maskload_ps(src_ptr1, row_mask_v);
+ t5 = _mm256_maskload_ps(src_ptr5, row_mask_v);
+ t2 = _mm256_maskload_ps(src_ptr2, row_mask_v);
+ t6 = _mm256_maskload_ps(src_ptr6, row_mask_v);
+ t3 = _mm256_maskload_ps(src_ptr3, row_mask_v);
+ t7 = _mm256_maskload_ps(src_ptr7, row_mask_v);
+
+ r0 = _mm256_unpacklo_ps(t0, t1);
+ r4 = _mm256_unpacklo_ps(t4, t5);
+ r2 = _mm256_unpackhi_ps(t0, t1);
+ r6 = _mm256_unpackhi_ps(t4, t5);
+ r1 = _mm256_unpacklo_ps(t2, t3);
+ r5 = _mm256_unpacklo_ps(t6, t7);
+ r3 = _mm256_unpackhi_ps(t2, t3);
+ r7 = _mm256_unpackhi_ps(t6, t7);
+
+ t0 = Mm256UnpackloPsx2(r0, r1);
+ t4 = Mm256UnpackloPsx2(r4, r5);
+ t2 = Mm256UnpackhiPsx2(r0, r1);
+ t6 = Mm256UnpackhiPsx2(r4, r5);
+ t1 = Mm256UnpackloPsx2(r2, r3);
+ t5 = Mm256UnpackloPsx2(r6, r7);
+ t3 = Mm256UnpackhiPsx2(r2, r3);
+ t7 = Mm256UnpackhiPsx2(r6, r7);
+
+ // The preceding sets of rearrangement operations interleaved by 4 bytes
+      // and then by 8 bytes *within* lanes. The following set interleaves by 16
+ // bytes (128-bit), operating *between* AVX lanes. For instance (t0, t4)
+ // are interleaved to create (r0, r1). This complexity follows from the
+ // way that AVX is centered around MM 128-bit lanes.
+ r0 = _mm256_permute2f128_ps(t0, t4, 0x20);
+ r4 = _mm256_permute2f128_ps(t1, t5, 0x20);
+ r1 = _mm256_permute2f128_ps(t0, t4, 0x31);
+ r5 = _mm256_permute2f128_ps(t1, t5, 0x31);
+ r2 = _mm256_permute2f128_ps(t2, t6, 0x20);
+ r6 = _mm256_permute2f128_ps(t3, t7, 0x20);
+ r3 = _mm256_permute2f128_ps(t2, t6, 0x31);
+ // r7 no longer needed.
+
+ _mm256_storeu_ps(trailing_buf + 0 * 8, r0);
+ _mm256_storeu_ps(trailing_buf + 2 * 8, r4);
+ _mm256_storeu_ps(trailing_buf + 4 * 8, r1);
+ _mm256_storeu_ps(trailing_buf + 6 * 8, r5);
+ _mm256_storeu_ps(trailing_buf + 1 * 8, r2);
+ _mm256_storeu_ps(trailing_buf + 3 * 8, r6);
+ _mm256_storeu_ps(trailing_buf + 5 * 8, r3);
+ // No store to (trailing_buf + 7 * 8), space not allocated.
+ }
+
+ packed_ptr += kPackRows * kPackCols;
+ src_ptr0 += src_inc0;
+ src_ptr1 += src_inc1;
+ src_ptr2 += src_inc2;
+ src_ptr3 += src_inc3;
+ src_ptr4 += src_inc4;
+ src_ptr5 += src_inc5;
+ src_ptr6 += src_inc6;
+ src_ptr7 += src_inc7;
+ }
+}
+
+} // namespace.
+
+void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows, std::int8_t* packed_ptr,
+ std::int32_t* sums_ptr) {
+ profiler::ScopeLabel label("Pack kAvx2 8bit");
+
+ using Layout = PackImpl8bitAvx2::Layout;
+ RUY_DCHECK_EQ(Layout::kCols, 8);
+ RUY_DCHECK_EQ(Layout::kRows, 4);
+
+  // Each chunk of Layout::kRows (4) elements is contiguous in the input and
+  // stays contiguous in the packed output. We process 8 of these chunks at a
+  // time, padding short input chunks.
+ static constexpr int kNumRowChunks = 8; // Short input is padded.
+
+ // Each packed block is 4*8, and there are normally 8. The trailing block is
+ // only slightly shorter.
+ constexpr int kTrailingBufSize =
+ kNumRowChunks * Layout::kCols * Layout::kRows;
+ std::int8_t trailing_buf[kTrailingBufSize];
+ memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));
+
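+  // The packer always stores full row chunks; rows belonging to an incomplete
+  // final chunk land in trailing_buf, and only the valid prefix is copied back
+  // into packed_ptr below.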
+ Pack8bitAvx2Packer(src_ptr, input_xor, zerobuf, src_stride,
+ remaining_src_cols, src_rows, packed_ptr, sums_ptr,
+ trailing_buf);
+
+ constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
+ const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
+  // If the number of source rows is not a multiple of (kChunkedRowMask + 1),
+  // there will be data in the trailing buffer.
+  if (trailing_data) {
+ const int non_trailing_rows = src_rows & ~kChunkedRowMask;
+ // Destination "rows" are padded to next highest multiple of Layout::kRows.
+ const int dst_rows = (src_rows + 3) & ~3;
+ const int trailing_rows = dst_rows - non_trailing_rows;
+ memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
+ Layout::kCols * trailing_rows * sizeof(std::int8_t));
+ }
+}
+
+void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows, float* packed_ptr) {
+ profiler::ScopeLabel label("Pack kAvx2 float");
+ static constexpr int kPackCols = 8; // Source cols packed together.
+ static constexpr int kPackRows = 8; // Short input is padded.
+ float trailing_buf[(kPackRows - 1) * kPackCols];
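+  // The trailing buffer only ever receives a partial final block of fewer than
+  // kPackRows rows (a full final block is stored directly to packed_ptr), so
+  // (kPackRows - 1) * kPackCols floats suffice.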
+ if (remaining_src_cols < 8) {
+ memset(trailing_buf, 0, sizeof(trailing_buf));
+ }
+ PackFloatAvx2Packer(src_ptr, zerobuf, src_stride, remaining_src_cols,
+ src_rows, packed_ptr, trailing_buf);
+
+ const int trailing_rows = src_rows & (kPackRows - 1);
+ if (trailing_rows > 0) {
+ const int non_trailing_rows = src_rows & ~(kPackRows - 1);
+ memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf,
+ kPackCols * trailing_rows * sizeof(float));
+ }
+}
+
+#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+} // namespace ruy
diff --git a/ruy/pack_avx512.cc b/ruy/pack_avx512.cc
new file mode 100644
index 0000000..ecad3a2
--- /dev/null
+++ b/ruy/pack_avx512.cc
@@ -0,0 +1,693 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <cstring>
+
+#include "ruy/check_macros.h"
+#include "ruy/matrix.h"
+#include "ruy/opt_set.h"
+#include "ruy/pack.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS)
+#include <immintrin.h> // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows, float* packed_ptr) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+#else // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+// The first int8_t template parameter is arbitrary: this routine is common to
+// all 8-bit source matrix types.
+using PackImpl8bitAvx512 =
+ PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
+ std::int8_t, std::int8_t, std::int32_t>;
+
+namespace {
+
+inline void ZeroHalf8bitAvx512(int src_rows, std::int8_t packed_zero_point,
+ std::int8_t* packed_ptr) {
+ using Layout = PackImpl8bitAvx512::Layout;
+ static constexpr int kHalfLayoutCols =
+ PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a
+ // block.
+ RUY_DCHECK_EQ(kHalfLayoutCols, 8);
+ RUY_DCHECK_EQ(Layout::kCols, 16);
+ RUY_DCHECK_EQ(Layout::kRows, 4);
+
+ const int non_trailing_blocks = (src_rows & ~31) >> 2;
+ // This routine fills half blocks, and typically fills the second halves.
+ // Thus packed_ptr is already offset by 8 * 4.
+ for (int k = 0; k < non_trailing_blocks; ++k) {
+ for (int j = 0; j < (kHalfLayoutCols * Layout::kRows); ++j) {
+ packed_ptr[Layout::kCols * Layout::kRows * k + j] = packed_zero_point;
+ }
+ }
+}
+
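+// Concatenates two unmasked 256-bit loads into a single 512-bit value, with
+// addr_lo filling the low 256 bits and addr_hi the high 256 bits.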
+inline __m512i LoaduTwo(const std::int8_t* addr_lo,
+ const std::int8_t* addr_hi) {
+ __m512i lower_filled = _mm512_castsi256_si512(_mm256_loadu_epi8(addr_lo));
+ return _mm512_inserti32x8(lower_filled, _mm256_loadu_epi8(addr_hi), 1);
+}
+
+inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v,
+ const std::int8_t* addr_lo,
+ const std::int8_t* addr_hi) {
+ const __m512i lower_filled = _mm512_castsi256_si512(
+ _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_lo));
+ return _mm512_inserti32x8(
+ lower_filled, _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_hi),
+ 1);
+}
+
+inline void HalfPack8bitAvx512(const std::int8_t* src_ptr,
+ std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr,
+ std::int8_t* trailing_buf) {
+ using Layout = PackImpl8bitAvx512::Layout;
+ RUY_DCHECK_EQ(Layout::kCols, 16);
+ RUY_DCHECK_EQ(Layout::kRows, 4);
+  // Each chunk of Layout::kRows (4) elements is contiguous in the input and
+  // stays contiguous in the packed output. We process 8 of these chunks at a
+  // time, padding short input chunks.
+ constexpr int kNumRowChunks = 8;
+ constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
+
+ const std::int8_t* src_ptr0 = src_ptr;
+ const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
+ const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
+ const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
+ const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
+ const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
+ const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
+ const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
+ std::int64_t src_inc0 = kNumChunkedSrcRows;
+ std::int64_t src_inc1 = kNumChunkedSrcRows;
+ std::int64_t src_inc2 = kNumChunkedSrcRows;
+ std::int64_t src_inc3 = kNumChunkedSrcRows;
+ std::int64_t src_inc4 = kNumChunkedSrcRows;
+ std::int64_t src_inc5 = kNumChunkedSrcRows;
+ std::int64_t src_inc6 = kNumChunkedSrcRows;
+ std::int64_t src_inc7 = kNumChunkedSrcRows;
+ // Handle cases where source does not have kHalfLayoutCols (8) columns.
+ if (remaining_src_cols < 8) {
+ if (remaining_src_cols <= 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (remaining_src_cols <= 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (remaining_src_cols <= 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (remaining_src_cols <= 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ if (remaining_src_cols <= 4) {
+ src_ptr4 = zerobuf;
+ src_inc4 = 0;
+ }
+ if (remaining_src_cols <= 5) {
+ src_ptr5 = zerobuf;
+ src_inc5 = 0;
+ }
+ if (remaining_src_cols <= 6) {
+ src_ptr6 = zerobuf;
+ src_inc6 = 0;
+ }
+ src_ptr7 = zerobuf;
+ src_inc7 = 0;
+ }
+
+ const std::int8_t zero_point = zerobuf[0];
+
+ if (sums_ptr) {
+ // i: kHalfLayoutCols.
+ for (int i = 0; i < 8; ++i) {
+ sums_ptr[i] = 0;
+ }
+ }
+ std::int32_t sums_adjustment = 0;
+ const __m512i ones_16bit = _mm512_set1_epi16(1);
+ __m512i sums_8x2_32bit = _mm512_set1_epi32(0);
+
+ // The overall packing effectively pads the source rows to
+ // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we
+ // only pack for (src_rows + 31) & ~31. When there is an incomplete
+ // destination block, this is stored into trailing_buf instead of packed_ptr.
+ for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) {
+ // m: {0, 1} for 2 chunks of rows.
+ for (int m = 0; m < 2; ++m) {
+ // Available source rows.
+ // If this is less than 0 (for m=1), we skip, having filled trailing
+ // buffer for m=0. Also, if source rows is zero on m=1, then we filled
+ // exactly to the end of the column in the packed buffer.
+ const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows;
+      // Effectively, the number of rows packed here is
+      //   std::max(0, std::min(kNumChunkedSrcRows, available_src_rows));
+      // treat each case separately.
+ if (available_src_rows >= kNumChunkedSrcRows) {
+ // i: chunks, s: Layout::Rows.
+ if (sums_ptr) {
+ __m512i t0, t1, t2, t3;
+ __m512i r0, r1, r2, r3;
+ const __m512i input_xor_v = _mm512_set1_epi8(input_xor);
+
+ t0 = LoaduTwo(src_ptr0, src_ptr4);
+ t1 = LoaduTwo(src_ptr1, src_ptr5);
+ t2 = LoaduTwo(src_ptr2, src_ptr6);
+ t3 = LoaduTwo(src_ptr3, src_ptr7);
+
+ r0 = _mm512_unpacklo_epi32(t0, t1);
+ r2 = _mm512_unpackhi_epi32(t0, t1);
+ r1 = _mm512_unpacklo_epi32(t2, t3);
+ r3 = _mm512_unpackhi_epi32(t2, t3);
+
+ t0 = _mm512_unpacklo_epi64(r0, r1);
+ t2 = _mm512_unpackhi_epi64(r0, r1);
+ t1 = _mm512_unpacklo_epi64(r2, r3);
+ t3 = _mm512_unpackhi_epi64(r2, r3);
+
+ r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
+ r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
+ r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
+ r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
+
+ r0 = _mm512_xor_si512(r0, input_xor_v);
+ r1 = _mm512_xor_si512(r1, input_xor_v);
+ r2 = _mm512_xor_si512(r2, input_xor_v);
+ r3 = _mm512_xor_si512(r3, input_xor_v);
+
+ const __m256i r0_0 = _mm512_castsi512_si256(r0);
+ const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
+ const __m256i r1_0 = _mm512_castsi512_si256(r1);
+ const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
+ const __m256i r2_0 = _mm512_castsi512_si256(r2);
+ const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
+ const __m256i r3_0 = _mm512_castsi512_si256(r3);
+ const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);
+
+ __m512i sums_8x4_16bit;
+ sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0);
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1));
+ // The sums have been performed across columns, and now we have
+ // 4x16-bit sums packed together. We use madd for pairwise 32-bit
+ // sums.
+ const __m512i sums_8x2_32bit_new =
+ _mm512_madd_epi16(sums_8x4_16bit, ones_16bit);
+ sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new);
+
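+          // Each 32-byte store fills one half (8 columns x 4 rows) of a
+          // 64-byte packed block; the remaining 8 columns of each block are
+          // filled separately at packed_ptr + kHalfBlockOffset (see
+          // Pack8bitAvx512).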
+ _mm256_storeu_epi8(packed_ptr + 0 * 16 * 4, r0_0);
+ _mm256_storeu_epi8(packed_ptr + 2 * 16 * 4, r0_1);
+ _mm256_storeu_epi8(packed_ptr + 4 * 16 * 4, r1_0);
+ _mm256_storeu_epi8(packed_ptr + 6 * 16 * 4, r1_1);
+ _mm256_storeu_epi8(packed_ptr + 1 * 16 * 4, r2_0);
+ _mm256_storeu_epi8(packed_ptr + 3 * 16 * 4, r2_1);
+ _mm256_storeu_epi8(packed_ptr + 5 * 16 * 4, r3_0);
+ _mm256_storeu_epi8(packed_ptr + 7 * 16 * 4, r3_1);
+ } else {
+ __m512i t0, t1, t2, t3;
+ __m512i r0, r1, r2, r3;
+ const __m512i input_xor_v = _mm512_set1_epi8(input_xor);
+
+ t0 = LoaduTwo(src_ptr0, src_ptr4);
+ t1 = LoaduTwo(src_ptr1, src_ptr5);
+ t2 = LoaduTwo(src_ptr2, src_ptr6);
+ t3 = LoaduTwo(src_ptr3, src_ptr7);
+
+ r0 = _mm512_unpacklo_epi32(t0, t1);
+ r2 = _mm512_unpackhi_epi32(t0, t1);
+ r1 = _mm512_unpacklo_epi32(t2, t3);
+ r3 = _mm512_unpackhi_epi32(t2, t3);
+
+ t0 = _mm512_unpacklo_epi64(r0, r1);
+ t2 = _mm512_unpackhi_epi64(r0, r1);
+ t1 = _mm512_unpacklo_epi64(r2, r3);
+ t3 = _mm512_unpackhi_epi64(r2, r3);
+
+ r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
+ r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
+ r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
+ r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
+
+ r0 = _mm512_xor_si512(r0, input_xor_v);
+ r1 = _mm512_xor_si512(r1, input_xor_v);
+ r2 = _mm512_xor_si512(r2, input_xor_v);
+ r3 = _mm512_xor_si512(r3, input_xor_v);
+
+ const __m256i r0_0 = _mm512_castsi512_si256(r0);
+ const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
+ const __m256i r1_0 = _mm512_castsi512_si256(r1);
+ const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
+ const __m256i r2_0 = _mm512_castsi512_si256(r2);
+ const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
+ const __m256i r3_0 = _mm512_castsi512_si256(r3);
+ const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);
+ _mm256_storeu_epi8(packed_ptr + 0 * 16 * 4, r0_0);
+ _mm256_storeu_epi8(packed_ptr + 2 * 16 * 4, r0_1);
+ _mm256_storeu_epi8(packed_ptr + 4 * 16 * 4, r1_0);
+ _mm256_storeu_epi8(packed_ptr + 6 * 16 * 4, r1_1);
+ _mm256_storeu_epi8(packed_ptr + 1 * 16 * 4, r2_0);
+ _mm256_storeu_epi8(packed_ptr + 3 * 16 * 4, r2_1);
+ _mm256_storeu_epi8(packed_ptr + 5 * 16 * 4, r3_0);
+ _mm256_storeu_epi8(packed_ptr + 7 * 16 * 4, r3_1);
+ }
+ } else if (available_src_rows > 0) {
+ RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows);
+ const __mmask32 row_mask =
+ (static_cast<std::uint64_t>(1) << available_src_rows) - 1;
+
+ // We do not care what goes into the trailing buffer, but we want
+ // in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
+ //
+ // We compensate for padding-with-zero_point by initializing the
+ // summations with the compensating offset, effectively
+ // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) *
+ // 4 * (8 - ((available_src_rows + 3) >> 2)).
+ //
+ // Note that (zero_point ^ input_xor) is performed in 8-bits and then
+ // cast.
+ sums_adjustment += -(zero_point ^ input_xor) * 4 *
+ (8 - ((available_src_rows + 3) >> 2));
+
+ __m512i t0, t1, t2, t3;
+ __m512i r0, r1, r2, r3;
+ const __m512i input_xor_v = _mm512_set1_epi8(input_xor);
+ const __m256i zero_point_v = _mm256_set1_epi8(zero_point);
+
+ t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4);
+ t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5);
+ t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6);
+ t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7);
+
+ r0 = _mm512_unpacklo_epi32(t0, t1);
+ r2 = _mm512_unpackhi_epi32(t0, t1);
+ r1 = _mm512_unpacklo_epi32(t2, t3);
+ r3 = _mm512_unpackhi_epi32(t2, t3);
+
+ t0 = _mm512_unpacklo_epi64(r0, r1);
+ t2 = _mm512_unpackhi_epi64(r0, r1);
+ t1 = _mm512_unpacklo_epi64(r2, r3);
+ t3 = _mm512_unpackhi_epi64(r2, r3);
+
+ r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
+ r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
+ r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
+ r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
+
+ r0 = _mm512_xor_si512(r0, input_xor_v);
+ r1 = _mm512_xor_si512(r1, input_xor_v);
+ r2 = _mm512_xor_si512(r2, input_xor_v);
+ r3 = _mm512_xor_si512(r3, input_xor_v);
+
+ const __m256i r0_0 = _mm512_castsi512_si256(r0);
+ const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
+ const __m256i r1_0 = _mm512_castsi512_si256(r1);
+ const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
+ const __m256i r2_0 = _mm512_castsi512_si256(r2);
+ const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
+ const __m256i r3_0 = _mm512_castsi512_si256(r3);
+ const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);
+
+ __m512i sums_8x4_16bit;
+ sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0);
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0));
+ sums_8x4_16bit =
+ _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1));
+ // The sums have been performed across columns, and now we have
+ // 4x16-bit sums packed together. We use madd for pairwise 32-bit
+ // sums.
+ const __m512i sums_8x2_32bit_new =
+ _mm512_madd_epi16(sums_8x4_16bit, ones_16bit);
+ sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new);
+
+ _mm256_storeu_epi8(trailing_buf + 0 * 16 * 4, r0_0);
+ _mm256_storeu_epi8(trailing_buf + 2 * 16 * 4, r0_1);
+ _mm256_storeu_epi8(trailing_buf + 4 * 16 * 4, r1_0);
+ _mm256_storeu_epi8(trailing_buf + 6 * 16 * 4, r1_1);
+ _mm256_storeu_epi8(trailing_buf + 1 * 16 * 4, r2_0);
+ _mm256_storeu_epi8(trailing_buf + 3 * 16 * 4, r2_1);
+ _mm256_storeu_epi8(trailing_buf + 5 * 16 * 4, r3_0);
+ _mm256_storeu_epi8(trailing_buf + 7 * 16 * 4, r3_1);
+ }
+
+ packed_ptr += 16 * kNumChunkedSrcRows;
+ src_ptr0 += src_inc0;
+ src_ptr1 += src_inc1;
+ src_ptr2 += src_inc2;
+ src_ptr3 += src_inc3;
+ src_ptr4 += src_inc4;
+ src_ptr5 += src_inc5;
+ src_ptr6 += src_inc6;
+ src_ptr7 += src_inc7;
+ }
+ }
+
+ if (sums_ptr) {
+ const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);
+
+ __m256i sums = _mm256_loadu_epi32(sums_ptr);
+ const __m512i idx =
+ _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
+
+ // We earlier used madd for pairwise 32-bit sums, and now we deinterlace the
+    // neighbours, finishing up by adding them to the stored accumulated sums.
+ const __m512i sums_2x8_32bit =
+ _mm512_permutexvar_epi32(idx, sums_8x2_32bit);
+ sums = _mm256_add_epi32(sums, sums_adjustment_v);
+ sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit));
+ sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1));
+
+ _mm256_storeu_epi32(sums_ptr, sums);
+ }
+}
+
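+// Combine two 8-float loads into a single 512-bit register: addr_lo fills the
+// lower 256 bits and addr_hi the upper 256 bits. The masked variant
+// zero-fills lanes not selected by row_mask.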
+inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) {
+ const __m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo));
+ return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1);
+}
+
+inline __m512 MaskLoaduTwo(__mmask8 row_mask, const float* addr_lo,
+ const float* addr_hi) {
+ const __m512 lower_filled =
+ _mm512_castps256_ps512(_mm256_maskz_loadu_ps(row_mask, addr_lo));
+ return _mm512_insertf32x8(lower_filled,
+ _mm256_maskz_loadu_ps(row_mask, addr_hi), 1);
+}
+
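+// Interleave the 64-bit (float-pair) lanes of a and b: the pd unpack
+// operations reinterpreted for float data.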
+inline __m512 Mm512UnpackloPsx2(const __m512 a, const __m512 b) {
+ return _mm512_castpd_ps(
+ _mm512_unpacklo_pd(_mm512_castps_pd(a), _mm512_castps_pd(b)));
+}
+
+inline __m512 Mm512UnpackhiPsx2(const __m512 a, const __m512 b) {
+ return _mm512_castpd_ps(
+ _mm512_unpackhi_pd(_mm512_castps_pd(a), _mm512_castps_pd(b)));
+}
+
+inline void HalfPackFloatAvx512(const float* src_ptr, const float* zerobuf,
+ int src_stride, int remaining_src_cols,
+ int src_rows, float* packed_ptr,
+ float* trailing_buf) {
+ const float* src_ptr0 = src_ptr;
+ const float* src_ptr1 = src_ptr0 + src_stride;
+ const float* src_ptr2 = src_ptr1 + src_stride;
+ const float* src_ptr3 = src_ptr2 + src_stride;
+ const float* src_ptr4 = src_ptr3 + src_stride;
+ const float* src_ptr5 = src_ptr4 + src_stride;
+ const float* src_ptr6 = src_ptr5 + src_stride;
+ const float* src_ptr7 = src_ptr6 + src_stride;
+ std::int64_t src_inc0 = 8;
+ std::int64_t src_inc1 = 8;
+ std::int64_t src_inc2 = 8;
+ std::int64_t src_inc3 = 8;
+ std::int64_t src_inc4 = 8;
+ std::int64_t src_inc5 = 8;
+ std::int64_t src_inc6 = 8;
+ std::int64_t src_inc7 = 8;
+ if (remaining_src_cols < 8) {
+ if (remaining_src_cols <= 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (remaining_src_cols <= 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (remaining_src_cols <= 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (remaining_src_cols <= 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ if (remaining_src_cols <= 4) {
+ src_ptr4 = zerobuf;
+ src_inc4 = 0;
+ }
+ if (remaining_src_cols <= 5) {
+ src_ptr5 = zerobuf;
+ src_inc5 = 0;
+ }
+ if (remaining_src_cols <= 6) {
+ src_ptr6 = zerobuf;
+ src_inc6 = 0;
+ }
+ src_ptr7 = zerobuf;
+ src_inc7 = 0;
+ }
+
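+  // Each iteration over k packs two chunks (m = 0, 1) of up to 8 rows x 8
+  // columns, transposing them into the 16-wide packed layout. A trailing
+  // partial chunk goes to trailing_buf and is copied into place by the
+  // caller.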
+ for (int k = 0; k < src_rows; k += 16) {
+ for (int m = 0; m < 2; ++m) {
+ const int available_src_rows = src_rows - k - 8 * m;
+ // Effectively,
+ // available_src_rows = std::max(0, std::min(8, src_rows - k - 8 * m));
+ // but treat each case separately.
+ if (available_src_rows > 7) {
+ __m512 t0, t1, t2, t3;
+ __m512 r0, r1, r2, r3;
+
+ t0 = LoaduTwo(src_ptr0, src_ptr4);
+ t1 = LoaduTwo(src_ptr1, src_ptr5);
+ t2 = LoaduTwo(src_ptr2, src_ptr6);
+ t3 = LoaduTwo(src_ptr3, src_ptr7);
+
+ r0 = _mm512_unpacklo_ps(t0, t1);
+ r2 = _mm512_unpackhi_ps(t0, t1);
+ r1 = _mm512_unpacklo_ps(t2, t3);
+ r3 = _mm512_unpackhi_ps(t2, t3);
+
+ t0 = Mm512UnpackloPsx2(r0, r1);
+ t2 = Mm512UnpackhiPsx2(r0, r1);
+ t1 = Mm512UnpackloPsx2(r2, r3);
+ t3 = Mm512UnpackhiPsx2(r2, r3);
+
+ r0 = _mm512_shuffle_f32x4(t0, t1, 0x88);
+ r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd);
+ r2 = _mm512_shuffle_f32x4(t2, t3, 0x88);
+ r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd);
+
+ _mm256_storeu_ps(packed_ptr + 0 * 16, _mm512_castps512_ps256(r0));
+ _mm256_storeu_ps(packed_ptr + 2 * 16, _mm512_extractf32x8_ps(r0, 1));
+ _mm256_storeu_ps(packed_ptr + 4 * 16, _mm512_castps512_ps256(r1));
+ _mm256_storeu_ps(packed_ptr + 6 * 16, _mm512_extractf32x8_ps(r1, 1));
+ _mm256_storeu_ps(packed_ptr + 1 * 16, _mm512_castps512_ps256(r2));
+ _mm256_storeu_ps(packed_ptr + 3 * 16, _mm512_extractf32x8_ps(r2, 1));
+ _mm256_storeu_ps(packed_ptr + 5 * 16, _mm512_castps512_ps256(r3));
+ _mm256_storeu_ps(packed_ptr + 7 * 16, _mm512_extractf32x8_ps(r3, 1));
+ } else if (available_src_rows > 0) {
+ const __mmask8 row_mask =
+ (static_cast<std::uint32_t>(1) << available_src_rows) - 1;
+
+ __m512 t0, t1, t2, t3;
+ __m512 r0, r1, r2, r3;
+
+ t0 = MaskLoaduTwo(row_mask, src_ptr0, src_ptr4);
+ t1 = MaskLoaduTwo(row_mask, src_ptr1, src_ptr5);
+ t2 = MaskLoaduTwo(row_mask, src_ptr2, src_ptr6);
+ t3 = MaskLoaduTwo(row_mask, src_ptr3, src_ptr7);
+
+ r0 = _mm512_unpacklo_ps(t0, t1);
+ r2 = _mm512_unpackhi_ps(t0, t1);
+ r1 = _mm512_unpacklo_ps(t2, t3);
+ r3 = _mm512_unpackhi_ps(t2, t3);
+
+ t0 = Mm512UnpackloPsx2(r0, r1);
+ t2 = Mm512UnpackhiPsx2(r0, r1);
+ t1 = Mm512UnpackloPsx2(r2, r3);
+ t3 = Mm512UnpackhiPsx2(r2, r3);
+
+ r0 = _mm512_shuffle_f32x4(t0, t1, 0x88);
+ r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd);
+ r2 = _mm512_shuffle_f32x4(t2, t3, 0x88);
+ r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd);
+
+ _mm256_storeu_ps(trailing_buf + 0 * 16, _mm512_castps512_ps256(r0));
+ _mm256_storeu_ps(trailing_buf + 2 * 16, _mm512_extractf32x8_ps(r0, 1));
+ _mm256_storeu_ps(trailing_buf + 4 * 16, _mm512_castps512_ps256(r1));
+ _mm256_storeu_ps(trailing_buf + 6 * 16, _mm512_extractf32x8_ps(r1, 1));
+ _mm256_storeu_ps(trailing_buf + 1 * 16, _mm512_castps512_ps256(r2));
+ _mm256_storeu_ps(trailing_buf + 3 * 16, _mm512_extractf32x8_ps(r2, 1));
+ _mm256_storeu_ps(trailing_buf + 5 * 16, _mm512_castps512_ps256(r3));
+ // Do not store _mm512_extractf32x8_ps(r3, 1).
+ }
+
+ packed_ptr += 16 * 8;
+ src_ptr0 += src_inc0;
+ src_ptr1 += src_inc1;
+ src_ptr2 += src_inc2;
+ src_ptr3 += src_inc3;
+ src_ptr4 += src_inc4;
+ src_ptr5 += src_inc5;
+ src_ptr6 += src_inc6;
+ src_ptr7 += src_inc7;
+ }
+ }
+}
+
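+// Zeroes one 8-column half of each 16-wide packed row, for the first
+// (src_rows & ~7) rows. The caller passes packed_ptr already offset to the
+// half-block being cleared.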
+inline void ZeroHalfFloatAvx512(int src_rows, float* packed_ptr) {
+ const int non_trailing_rows = src_rows & ~7;
+ for (int k = 0; k < non_trailing_rows; ++k) {
+ for (int j = 0; j < 8; ++j) {
+ packed_ptr[j] = 0.0f;
+ }
+ packed_ptr += 16;
+ }
+}
+
+} // namespace.
+
+void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
+ profiler::ScopeLabel label("Pack kAvx512 8bit");
+
+ using Layout = PackImpl8bitAvx512::Layout;
+ constexpr int kHalfBlockOffset = 32;
+ RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols);
+ static constexpr int kHalfLayoutCols =
+ PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a
+ // block.
+ RUY_DCHECK_EQ(kHalfLayoutCols, 8);
+ RUY_DCHECK_EQ(Layout::kCols, 16);
+ RUY_DCHECK_EQ(Layout::kRows, 4);
+
+  // Each chunk of Layout::kRows (4) elements is contiguous in both the input
+  // and the packed buffer. We process 8 of these chunks at a time, padding
+  // short input chunks.
+ constexpr int kNumRowChunks = 8;
+
+ // Each packed block is 4*16, and there are normally 8. The trailing block is
+ // only slightly shorter.
+ constexpr int kTrailingBufSize =
+ kNumRowChunks * Layout::kCols * Layout::kRows;
+ std::int8_t trailing_buf[kTrailingBufSize];
+ memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));
+
+ std::int32_t* second_sums_ptr =
+ sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr;
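+  // The 16-column packed block is produced as two interleaved half-blocks of
+  // 8 columns each; kHalfBlockOffset selects the second half within each
+  // 64-byte block.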
+ if (remaining_src_cols > kHalfLayoutCols) {
+ HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
+ remaining_src_cols, src_rows, packed_ptr, sums_ptr,
+ trailing_buf);
+ HalfPack8bitAvx512(src_ptr + src_stride * kHalfLayoutCols, input_xor,
+ zerobuf, src_stride,
+ remaining_src_cols - kHalfLayoutCols, src_rows,
+ packed_ptr + kHalfBlockOffset, second_sums_ptr,
+ trailing_buf + kHalfBlockOffset);
+ } else {
+ HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
+ remaining_src_cols, src_rows, packed_ptr, sums_ptr,
+ trailing_buf);
+ ZeroHalf8bitAvx512(src_rows, zerobuf[0] ^ input_xor,
+ packed_ptr + kHalfBlockOffset);
+ // The kernel may not need the second half-blocks sums to be set.
+ if (second_sums_ptr) {
+ for (int i = 0; i < kHalfLayoutCols; ++i) {
+ second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3);
+ }
+ }
+ }
+ constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
+ const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
+  // If the number of source rows is not a multiple of kNumRowChunks *
+  // Layout::kRows (32), there is data in the trailing buffer that still needs
+  // to be copied into the packed destination.
+ if (trailing_data > 0) {
+ const int non_trailing_rows = src_rows & ~kChunkedRowMask;
+ // Destination "rows" are padded to next highest multiple of Layout::kRows.
+ const int dst_rows = (src_rows + 3) & ~3;
+ const int trailing_rows = dst_rows - non_trailing_rows;
+ memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
+ Layout::kCols * trailing_rows * sizeof(std::int8_t));
+ }
+}
+
+void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows, float* packed_ptr) {
+ profiler::ScopeLabel label("Pack kAvx512 float");
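+  // At most 7 trailing rows (x 16 packed columns) need buffering: complete
+  // 8-row chunks are written directly to packed_ptr.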
+ float trailing_buf[7 * 16];
+ if (remaining_src_cols > 8) {
+ HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
+ src_rows, packed_ptr, trailing_buf);
+ HalfPackFloatAvx512(src_ptr + src_stride * 8, zerobuf, src_stride,
+ remaining_src_cols - 8, src_rows, packed_ptr + 8,
+ trailing_buf + 8);
+ } else {
+ memset(trailing_buf, 0, sizeof(trailing_buf));
+ HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
+ src_rows, packed_ptr, trailing_buf);
+ ZeroHalfFloatAvx512(src_rows, packed_ptr + 8);
+ }
+ const int trailing_rows = src_rows & 7;
+ if (trailing_rows > 0) {
+ const int non_trailing_rows = src_rows & ~7;
+ memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf,
+ 16 * trailing_rows * sizeof(float));
+ }
+}
+
+#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+} // namespace ruy
diff --git a/ruy/pack_avxvnni.cc b/ruy/pack_avxvnni.cc
new file mode 100644
index 0000000..bb9a730
--- /dev/null
+++ b/ruy/pack_avxvnni.cc
@@ -0,0 +1,478 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <cstring>
+
+#include "ruy/check_macros.h"
+#include "ruy/matrix.h"
+#include "ruy/opt_set.h"
+#include "ruy/pack.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS)
+#include <immintrin.h> // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+void Pack8bitAvxVnni(const std::int8_t* src_ptr, std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf,
+ int src_stride, int remaining_src_cols, int src_rows,
+ float* packed_ptr) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+#else // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+// The first int8_t template parameter is arbitrary: this routine is common to
+// all 8-bit source matrix types.
+using PackImpl8bitAvxVnni =
+ PackImpl<Path::kAvxVnni, FixedKernelLayout<Order::kColMajor, 4, 16>,
+ std::int8_t, std::int8_t, std::int32_t>;
+
+namespace {
+
+inline void ZeroHalf8bitAvxVnni(int src_rows, std::int8_t packed_zero_point,
+ std::int8_t* packed_ptr) {
+ const int non_trailing_blocks = (src_rows & ~31) >> 2;
+ // This routine fills half blocks, and typically fills the second halves. Thus
+ // packed_ptr is already offset by 8*4.
+ for (int k = 0; k < non_trailing_blocks; ++k) {
+ for (int j = 0; j < (8 * 4); ++j) {
+ packed_ptr[16 * 4 * k + j] = packed_zero_point;
+ }
+ }
+}
+
+inline void HalfPack8bitAvxVnni(const std::int8_t* src_ptr,
+ std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr,
+ std::int8_t* trailing_buf) {
+ std::int8_t in_data[8][8][4];
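+  // in_data is indexed as [column][row-chunk][row-within-chunk]; the loops
+  // below transpose it into the column-major 4x16 packed block layout.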
+
+ const std::int8_t* src_ptr0 = src_ptr;
+ const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
+ const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
+ const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
+ const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
+ const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
+ const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
+ const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
+ std::int64_t src_inc0 = 8 * 4;
+ std::int64_t src_inc1 = 8 * 4;
+ std::int64_t src_inc2 = 8 * 4;
+ std::int64_t src_inc3 = 8 * 4;
+ std::int64_t src_inc4 = 8 * 4;
+ std::int64_t src_inc5 = 8 * 4;
+ std::int64_t src_inc6 = 8 * 4;
+ std::int64_t src_inc7 = 8 * 4;
+ if (remaining_src_cols < 8) {
+ if (remaining_src_cols <= 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (remaining_src_cols <= 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (remaining_src_cols <= 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (remaining_src_cols <= 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ if (remaining_src_cols <= 4) {
+ src_ptr4 = zerobuf;
+ src_inc4 = 0;
+ }
+ if (remaining_src_cols <= 5) {
+ src_ptr5 = zerobuf;
+ src_inc5 = 0;
+ }
+ if (remaining_src_cols <= 6) {
+ src_ptr6 = zerobuf;
+ src_inc6 = 0;
+ }
+ src_ptr7 = zerobuf;
+ src_inc7 = 0;
+ }
+
+ const std::int8_t zero_point = zerobuf[0];
+
+ if (sums_ptr) {
+ for (int i = 0; i < 8; ++i) {
+ sums_ptr[i] = 0;
+ }
+ }
+
+ // The overall packing effectively pads the source rows to
+ // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we
+ // only pack for (src_rows + 31) & ~31. When there is an incomplete
+ // destination block, this is stored into trailing_buf instead of packed_ptr.
+ for (int k = 0; k < src_rows; k += 16 * 4) {
+ for (int m = 0; m < 2; ++m) {
+ // Available source rows.
+ // If this is less than 0 (for m=1), we skip, having filled trailing
+ // buffer for m=0. Also, if source rows is zero on m=1, then we filled
+ // exactly to the end of the column in the packed buffer.
+ const int packed_rows = src_rows - k - 8 * m * 4;
+      // Effectively,
+      // packed_rows = std::max(0, std::min(8 * 4, src_rows - k - 8 * m * 4));
+      // but treat each case separately.
+ if (packed_rows >= (8 * 4)) {
+ for (int i = 0; i < 8; ++i) {
+ for (int s = 0; s < 4; ++s) {
+ in_data[0][i][s] = src_ptr0[i * 4 + s];
+ in_data[1][i][s] = src_ptr1[i * 4 + s];
+ in_data[2][i][s] = src_ptr2[i * 4 + s];
+ in_data[3][i][s] = src_ptr3[i * 4 + s];
+ in_data[4][i][s] = src_ptr4[i * 4 + s];
+ in_data[5][i][s] = src_ptr5[i * 4 + s];
+ in_data[6][i][s] = src_ptr6[i * 4 + s];
+ in_data[7][i][s] = src_ptr7[i * 4 + s];
+ }
+ }
+ for (int i = 0; i < 8; ++i) {
+ for (int j = 0; j < 8; ++j) {
+ for (int s = 0; s < 4; ++s) {
+ packed_ptr[(16 * i + j) * 4 + s] =
+ static_cast<std::int8_t>(in_data[j][i][s] ^ input_xor);
+ }
+ if (sums_ptr) {
+ for (int s = 0; s < 4; ++s) {
+ sums_ptr[j] += in_data[j][i][s] ^ input_xor;
+ }
+ }
+ }
+ }
+ } else if (packed_rows > 0) {
+ RUY_DCHECK_LT(packed_rows >> 2, 8);
+ int i = 0;
+ for (; i < (packed_rows >> 2); ++i) {
+ for (int s = 0; s < 4; ++s) {
+ in_data[0][i][s] = src_ptr0[i * 4 + s];
+ in_data[1][i][s] = src_ptr1[i * 4 + s];
+ in_data[2][i][s] = src_ptr2[i * 4 + s];
+ in_data[3][i][s] = src_ptr3[i * 4 + s];
+ in_data[4][i][s] = src_ptr4[i * 4 + s];
+ in_data[5][i][s] = src_ptr5[i * 4 + s];
+ in_data[6][i][s] = src_ptr6[i * 4 + s];
+ in_data[7][i][s] = src_ptr7[i * 4 + s];
+ }
+ }
+ if (i < ((packed_rows + 3) >> 2)) {
+ int s = 0;
+ for (; s < (packed_rows & 3); ++s) {
+ in_data[0][i][s] = src_ptr0[i * 4 + s];
+ in_data[1][i][s] = src_ptr1[i * 4 + s];
+ in_data[2][i][s] = src_ptr2[i * 4 + s];
+ in_data[3][i][s] = src_ptr3[i * 4 + s];
+ in_data[4][i][s] = src_ptr4[i * 4 + s];
+ in_data[5][i][s] = src_ptr5[i * 4 + s];
+ in_data[6][i][s] = src_ptr6[i * 4 + s];
+ in_data[7][i][s] = src_ptr7[i * 4 + s];
+ }
+ RUY_DCHECK_LE(s, 4);
+ for (; s < 4; ++s) {
+ for (int j = 0; j < 8; ++j) {
+ in_data[j][i][s] = zero_point;
+ }
+ }
+ ++i;
+ }
+ // We do not care what goes into the trailing buffer, but we want
+ // in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
+ //
+ // It might prove better in optimized code to pad uniformly with
+ // zero_point, and compensate by initializing the summations with the
+ // compensating offset, effectively
+ // ((input_xor - zero_point) ^ input_xor) *
+ // 4 * (8 - ((packed_rows + 3) >> 2)).
+ for (; i < 8; ++i) {
+ for (int s = 0; s < 4; ++s) {
+ for (int j = 0; j < 8; ++j) {
+ in_data[j][i][s] = input_xor;
+ }
+ }
+ }
+ // We loop through [0, 8) rather than [0, (packed_rows + 3) >> 2), since
+ // that emulates what we might do in fully-optimized code.
+ if (sums_ptr) {
+ for (int i = 0; i < 8; ++i) {
+ for (int j = 0; j < 8; ++j) {
+ for (int s = 0; s < 4; ++s) {
+ trailing_buf[(16 * i + j) * 4 + s] =
+ static_cast<std::int8_t>(in_data[j][i][s] ^ input_xor);
+ sums_ptr[j] += in_data[j][i][s] ^ input_xor;
+ }
+ }
+ }
+ } else {
+ for (int i = 0; i < 8; ++i) {
+ for (int j = 0; j < 8; ++j) {
+ for (int s = 0; s < 4; ++s) {
+ trailing_buf[(16 * i + j) * 4 + s] =
+ static_cast<std::int8_t>(in_data[j][i][s] ^ input_xor);
+ }
+ }
+ }
+ }
+ }
+
+ packed_ptr += 16 * 8 * 4;
+ src_ptr0 += src_inc0;
+ src_ptr1 += src_inc1;
+ src_ptr2 += src_inc2;
+ src_ptr3 += src_inc3;
+ src_ptr4 += src_inc4;
+ src_ptr5 += src_inc5;
+ src_ptr6 += src_inc6;
+ src_ptr7 += src_inc7;
+ }
+ }
+}
+
+inline void HalfPackFloatAvxVnni(const float* src_ptr, const float* zerobuf,
+ int src_stride, int remaining_src_cols,
+ int src_rows, float* packed_ptr,
+ float* trailing_buf) {
+ float in_data[8][8];
+
+ const float* src_ptr0 = src_ptr;
+ const float* src_ptr1 = src_ptr0 + src_stride;
+ const float* src_ptr2 = src_ptr1 + src_stride;
+ const float* src_ptr3 = src_ptr2 + src_stride;
+ const float* src_ptr4 = src_ptr3 + src_stride;
+ const float* src_ptr5 = src_ptr4 + src_stride;
+ const float* src_ptr6 = src_ptr5 + src_stride;
+ const float* src_ptr7 = src_ptr6 + src_stride;
+ std::int64_t src_inc0 = 8;
+ std::int64_t src_inc1 = 8;
+ std::int64_t src_inc2 = 8;
+ std::int64_t src_inc3 = 8;
+ std::int64_t src_inc4 = 8;
+ std::int64_t src_inc5 = 8;
+ std::int64_t src_inc6 = 8;
+ std::int64_t src_inc7 = 8;
+ if (remaining_src_cols < 8) {
+ if (remaining_src_cols <= 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (remaining_src_cols <= 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (remaining_src_cols <= 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (remaining_src_cols <= 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ if (remaining_src_cols <= 4) {
+ src_ptr4 = zerobuf;
+ src_inc4 = 0;
+ }
+ if (remaining_src_cols <= 5) {
+ src_ptr5 = zerobuf;
+ src_inc5 = 0;
+ }
+ if (remaining_src_cols <= 6) {
+ src_ptr6 = zerobuf;
+ src_inc6 = 0;
+ }
+ src_ptr7 = zerobuf;
+ src_inc7 = 0;
+ }
+
+ for (int k = 0; k < src_rows; k += 16) {
+ for (int m = 0; m < 2; ++m) {
+ const int packed_rows = src_rows - k - 8 * m;
+ // Effectively,
+ // packed_rows = std::max(0, std::min(8, src_rows - k - 8 * m));
+ // but treat each case separately.
+ if (packed_rows > 7) {
+ for (int i = 0; i < 8; ++i) {
+ in_data[0][i] = src_ptr0[i];
+ in_data[1][i] = src_ptr1[i];
+ in_data[2][i] = src_ptr2[i];
+ in_data[3][i] = src_ptr3[i];
+ in_data[4][i] = src_ptr4[i];
+ in_data[5][i] = src_ptr5[i];
+ in_data[6][i] = src_ptr6[i];
+ in_data[7][i] = src_ptr7[i];
+ }
+ for (int i = 0; i < 8; ++i) {
+ for (int j = 0; j < 8; ++j) {
+ packed_ptr[16 * i + j] = in_data[j][i];
+ }
+ }
+ } else if (packed_rows > 0) {
+ for (int i = 0; i < packed_rows; ++i) {
+ in_data[0][i] = src_ptr0[i];
+ in_data[1][i] = src_ptr1[i];
+ in_data[2][i] = src_ptr2[i];
+ in_data[3][i] = src_ptr3[i];
+ in_data[4][i] = src_ptr4[i];
+ in_data[5][i] = src_ptr5[i];
+ in_data[6][i] = src_ptr6[i];
+ in_data[7][i] = src_ptr7[i];
+ }
+ for (int i = packed_rows; i < 8; ++i) {
+ in_data[0][i] = 0.0f;
+ in_data[1][i] = 0.0f;
+ in_data[2][i] = 0.0f;
+ in_data[3][i] = 0.0f;
+ in_data[4][i] = 0.0f;
+ in_data[5][i] = 0.0f;
+ in_data[6][i] = 0.0f;
+ in_data[7][i] = 0.0f;
+ }
+ // We loop through [0, 7) rather than [0, packed_rows), since that
+ // emulates what we might do in fully-optimized code.
+ for (int i = 0; i < 7; ++i) {
+ for (int j = 0; j < 8; ++j) {
+ trailing_buf[16 * i + j] = in_data[j][i];
+ }
+ }
+ }
+
+ packed_ptr += 16 * 8;
+ src_ptr0 += src_inc0;
+ src_ptr1 += src_inc1;
+ src_ptr2 += src_inc2;
+ src_ptr3 += src_inc3;
+ src_ptr4 += src_inc4;
+ src_ptr5 += src_inc5;
+ src_ptr6 += src_inc6;
+ src_ptr7 += src_inc7;
+ }
+ }
+}
+
+inline void ZeroHalfFloatAvxVnni(int src_rows, float* packed_ptr) {
+ const int non_trailing_rows = src_rows & ~7;
+ for (int k = 0; k < non_trailing_rows; ++k) {
+ for (int j = 0; j < 8; ++j) {
+ packed_ptr[j] = 0.0f;
+ }
+ packed_ptr += 16;
+ }
+}
+
+} // namespace.
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+// When removing this comment, update profiling label below.
+void Pack8bitAvxVnni(const std::int8_t* src_ptr, std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
+ profiler::ScopeLabel label("Pack kAvxVnni 8bit (UNFINISHED)");
+
+ // Each packed block is 4*16, and there are normally 8. The trailing block is
+ // only slightly shorter.
+ std::int8_t trailing_buf[8 * 16 * 4];
+ memset(trailing_buf, 0, 8 * 16 * 4 * sizeof(std::int8_t));
+
+ std::int32_t* second_sums_ptr = sums_ptr ? sums_ptr + 8 : nullptr;
+ if (remaining_src_cols > 8) {
+ HalfPack8bitAvxVnni(src_ptr, input_xor, zerobuf, src_stride,
+ remaining_src_cols, src_rows, packed_ptr, sums_ptr,
+ trailing_buf);
+ HalfPack8bitAvxVnni(src_ptr + src_stride * 8, input_xor, zerobuf,
+ src_stride, remaining_src_cols - 8, src_rows,
+ packed_ptr + 8 * 4, second_sums_ptr,
+ trailing_buf + 8 * 4);
+ } else {
+ HalfPack8bitAvxVnni(src_ptr, input_xor, zerobuf, src_stride,
+ remaining_src_cols, src_rows, packed_ptr, sums_ptr,
+ trailing_buf);
+ ZeroHalf8bitAvxVnni(src_rows, zerobuf[0] ^ input_xor, packed_ptr + 8 * 4);
+ // The kernel may not need the second half-blocks sums to be set.
+ if (second_sums_ptr) {
+ for (int i = 0; i < 8; ++i) {
+ second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3);
+ }
+ }
+ }
+ const bool trailing_data = (src_rows & 31) > 0;
+  // If the number of source rows is not a multiple of 32, there is data in
+  // the trailing buffer that still needs to be copied into the packed
+  // destination.
+ if (trailing_data > 0) {
+ const int non_trailing_rows = src_rows & ~31;
+ // Destination "rows" are padded to next highest multiple of 4.
+ const int dst_rows = (src_rows + 3) & ~3;
+ const int trailing_rows = dst_rows - non_trailing_rows;
+ memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf,
+ 16 * trailing_rows * sizeof(std::int8_t));
+ }
+}
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+// When removing this comment, update profiling label below.
+void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf,
+ int src_stride, int remaining_src_cols, int src_rows,
+ float* packed_ptr) {
+ profiler::ScopeLabel label("Pack kAvxVnni float (UNFINISHED)");
+ float trailing_buf[7 * 16];
+ if (remaining_src_cols > 8) {
+ HalfPackFloatAvxVnni(src_ptr, zerobuf, src_stride, remaining_src_cols,
+ src_rows, packed_ptr, trailing_buf);
+ HalfPackFloatAvxVnni(src_ptr + src_stride * 8, zerobuf, src_stride,
+ remaining_src_cols - 8, src_rows, packed_ptr + 8,
+ trailing_buf + 8);
+ } else {
+ memset(trailing_buf, 0, sizeof(trailing_buf));
+ HalfPackFloatAvxVnni(src_ptr, zerobuf, src_stride, remaining_src_cols,
+ src_rows, packed_ptr, trailing_buf);
+ ZeroHalfFloatAvxVnni(src_rows, packed_ptr + 8);
+ }
+ const int trailing_rows = src_rows & 7;
+ if (trailing_rows > 0) {
+ const int non_trailing_rows = src_rows & ~7;
+ memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf,
+ 16 * trailing_rows * sizeof(float));
+ }
+}
+
+#endif // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+} // namespace ruy
diff --git a/ruy/pack_common.h b/ruy/pack_common.h
new file mode 100644
index 0000000..5c03afd
--- /dev/null
+++ b/ruy/pack_common.h
@@ -0,0 +1,246 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// # What is "packing"?
+//
+// Before feeding data to the gemm kernels (the parts of Ruy that do lots
+// of multiply-add operations), Ruy first performs a data transformation (which
+// we call "packing") on the input matrices. This transformation has two main
+// goals:
+// - rearrange data into blocks that are a convenient size/layout for the gemm
+// kernels to consume. This helps make the memory access pattern of the gemm
+// kernel simpler and more contiguous, and puts the data in a layout most
+// convenient for specific arithmetic instructions in the gemm kernel.
+// - compute row/column sums needed for handling quantization with non-symmetric
+// zero points.
+//
+// # Simplified algorithmic analysis of packing
+//
+// Packing is a relatively simple transformation which does a small constant
+// amount of work on each element of an input matrix, and hence for an NxM
+// matrix performs O(N*M) work. If N and M are of the same order, then this is
+// O(N^2) work.
+//
+// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations.
+// Note that if N, K, and M are all the same order, then the number of
+// multiply-accumulate operations is O(N^3).
+//
+// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the
+// case of all dimensions being roughly the same order.
+//
+// # Packing cost can be significant
+//
+// When matrix * matrix multiplications begin to look more like matrix * vector
+// multiplications, packing cost can become significant. We sometimes call these
+// cases "gemv-like".
+//
+// Continuing the algorithmic analysis above, if we consider a case where an
+// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the
+// situation is different. In this case, the multiply-accumulate work is only
+// quadratic, so the quadratic cost of packing can become significant.
+//
+// Another way to say this is that the cost of packing an input matrix (either
+// the LHS or RHS) is amortized across the non-depth dimension of the opposite
+// input matrix. Thus, when the LHS has very few rows or the RHS has very few
+// columns, the cost of packing the opposite input matrix can become
+// significant.
+//
+// As a rough rule of thumb, the cost of packing starts to become significant
+// when either N or M is below 32 (and other dimensions are hundreds), with very
+// significant packing costs at 8 or below. This varies by data type, Path, and
+// tuning, so these numbers are only rough guides.
+//
+// One practical use case that is affected by this is inference of
+// fully connected neural network layers with a low batch size. The weight
+// matrix (which is a constant for inference) is the one affected by significant
+// packing cost.
+//
+// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack
+// input matrices that are affected by significant packing costs.
+//
+// # Implementation notes
+//
+// Ruy's packing routines always operate on a range of columns and can be
+// applied to either the LHS or RHS. This is possible because Ruy internally
+// implements a TrMul, so the accumulation along depth is done along columns of
+// both the LHS and RHS (whereas for a normal Mul the accumulation along depth
+// for the LHS is along rows). As another example, we are always computing
+// column sums for quantization (and never row sums, since the LHS is
+// transposed).
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_COMMON_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_COMMON_H_
+
+#include <cstdint>
+
+#include "ruy/check_macros.h"
+#include "ruy/common.h"
+#include "ruy/internal_matrix.h"
+#include "ruy/matrix.h"
+#include "ruy/opt_set.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
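+// PackedTypeImpl selects the storage type used in the packed buffer for a
+// given Path and source Scalar type. By default values keep their source
+// type; specializations below store std::uint8_t sources as std::int8_t.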
+template <Path ThePath, typename Scalar>
+struct PackedTypeImpl {
+ using Type = Scalar;
+};
+
+#if RUY_PLATFORM(NEON_32)
+struct PackParams8bit {
+ const void* src_ptr0;
+ const void* src_ptr1;
+ const void* src_ptr2;
+ const void* src_ptr3;
+ const std::int32_t* sums_ptr;
+ const std::int8_t* packed_ptr;
+ int src_inc0;
+ int src_inc1;
+ int src_inc2;
+ int src_inc3;
+ int src_rows;
+ int src_zero_point;
+ int input_xor;
+};
+
+inline void MakePackParams8bit(const void* src_ptr0, const void* src_ptr1,
+ const void* src_ptr2, const void* src_ptr3,
+ const std::int32_t* sums_ptr,
+ const std::int8_t* packed_ptr, int src_inc0,
+ int src_inc1, int src_inc2, int src_inc3,
+ int src_rows, int src_zero_point, int input_xor,
+ PackParams8bit* params) {
+ params->src_ptr0 = src_ptr0;
+ params->src_ptr1 = src_ptr1;
+ params->src_ptr2 = src_ptr2;
+ params->src_ptr3 = src_ptr3;
+ params->sums_ptr = sums_ptr;
+ params->packed_ptr = packed_ptr;
+ params->src_inc0 = src_inc0;
+ params->src_inc1 = src_inc1;
+ params->src_inc2 = src_inc2;
+ params->src_inc3 = src_inc3;
+ params->src_rows = src_rows;
+ params->src_zero_point = src_zero_point;
+ params->input_xor = input_xor;
+}
+#endif
+
+#if RUY_PLATFORM(NEON)
+template <>
+struct PackedTypeImpl<Path::kNeon, std::uint8_t> {
+ using Type = std::int8_t;
+};
+template <>
+struct PackedTypeImpl<Path::kNeonDotprod, std::uint8_t> {
+ using Type = std::int8_t;
+};
+#elif RUY_PLATFORM(X86)
+template <>
+struct PackedTypeImpl<Path::kSse42, std::uint8_t> {
+ using Type = std::int8_t;
+};
+template <>
+struct PackedTypeImpl<Path::kAvx2, std::uint8_t> {
+ using Type = std::int8_t;
+};
+template <>
+struct PackedTypeImpl<Path::kAvx512, std::uint8_t> {
+ using Type = std::int8_t;
+};
+template <>
+struct PackedTypeImpl<Path::kAvxVnni, std::uint8_t> {
+ using Type = std::int8_t;
+};
+#endif
+
+template <Path ThePath, typename Scalar>
+using PackedType = typename PackedTypeImpl<ThePath, Scalar>::Type;
+
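+// Maps a source value into the packed storage type by shifting between the
+// two types' symmetric zero points. For example, packing std::uint8_t source
+// data as std::int8_t amounts to subtracting 128.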
+template <typename PackedScalar, typename Scalar>
+PackedScalar Pack(Scalar x) {
+ return x - SymmetricZeroPoint<Scalar>() + SymmetricZeroPoint<PackedScalar>();
+}
+
+template <Path ThePath, typename FixedKernelLayout, typename Scalar,
+ typename PackedScalar, typename SumsType>
+struct PackImpl {};
+
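+// Makes the PackImpl specialization for CHILD inherit PARENT's
+// implementation, unless a more specific specialization is defined elsewhere.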
+#define RUY_INHERIT_PACK(PARENT, CHILD) \
+ template <typename FixedKernelLayout, typename Scalar, \
+ typename PackedScalar, typename SumsType> \
+ struct PackImpl<CHILD, FixedKernelLayout, Scalar, PackedScalar, SumsType> \
+ : PackImpl<PARENT, FixedKernelLayout, Scalar, PackedScalar, SumsType> { \
+ };
+
+template <typename FixedKernelLayout, typename Scalar, typename PackedScalar,
+ typename SumsType>
+struct PackImpl<Path::kStandardCpp, FixedKernelLayout, Scalar, PackedScalar,
+ SumsType> {
+ static void Run(Tuning, const Matrix<Scalar>& src_matrix,
+ PackedMatrix<PackedScalar>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (generic)");
+ RUY_DCHECK_EQ((end_col - start_col) % FixedKernelLayout::kCols, 0);
+ SumsType* sums = packed_matrix->sums;
+ for (int col = start_col; col < end_col; col++) {
+ SumsType accum = 0;
+ for (int row = 0; row < packed_matrix->layout.rows; row++) {
+ PackedScalar packed_val;
+ if (col < src_matrix.layout.cols && row < src_matrix.layout.rows) {
+ packed_val = Pack<PackedScalar>(Element(src_matrix, row, col));
+ } else {
+ packed_val = packed_matrix->zero_point;
+ }
+ accum += packed_val;
+ *ElementPtr(packed_matrix, row, col) = packed_val;
+ }
+ if (sums) {
+ sums[col] = accum;
+ }
+ }
+ }
+};
+
+#if RUY_PLATFORM(NEON)
+RUY_INHERIT_PACK(Path::kStandardCpp, Path::kNeon)
+RUY_INHERIT_PACK(Path::kNeon, Path::kNeonDotprod)
+#elif RUY_PLATFORM(X86)
+RUY_INHERIT_PACK(Path::kStandardCpp, Path::kSse42)
+RUY_INHERIT_PACK(Path::kSse42, Path::kAvx2)
+RUY_INHERIT_PACK(Path::kAvx2, Path::kAvx512)
+RUY_INHERIT_PACK(Path::kAvx512, Path::kAvxVnni)
+#endif
+
+// Main entry point for packing.
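+// Recovers typed matrices from the type-erased DMatrix/PMatrix arguments and
+// dispatches to the PackImpl specialization selected by ThePath.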
+template <Path ThePath, typename FixedKernelLayout, typename Scalar,
+ typename PackedScalar>
+void RunPack(Tuning tuning, const DMatrix& src_matrix, PMatrix* packed_matrix,
+ int start_col, int end_col) {
+ using SumsType = typename PackedMatrix<PackedScalar>::SumsType;
+ Matrix<Scalar> src = ToMatrix<Scalar>(src_matrix);
+ PackedMatrix<PackedScalar> packed =
+ ToPackedMatrix<PackedScalar>(*packed_matrix);
+ PackImpl<ThePath, FixedKernelLayout, Scalar, PackedScalar, SumsType>::Run(
+ tuning, src, &packed, start_col, end_col);
+}
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_COMMON_H_
diff --git a/ruy/pack_sse42.cc b/ruy/pack_sse42.cc
new file mode 100644
index 0000000..90c7250
--- /dev/null
+++ b/ruy/pack_sse42.cc
@@ -0,0 +1,471 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <cstring>
+
+#include "ruy/check_macros.h"
+#include "ruy/matrix.h"
+#include "ruy/opt_set.h"
+#include "ruy/pack.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS)
+#include <immintrin.h> // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows, float* packed_ptr) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+// The first int8_t template parameter is arbitrary: this routine is common to
+// all 8-bit source matrix types.
+using PackImpl8bitSse42 =
+ PackImpl<Path::kSse42, FixedKernelLayout<Order::kColMajor, 4, 8>,
+ std::int8_t, std::int8_t, std::int32_t>;
+
+using PackImplFloatSse42 =
+ PackImpl<Path::kSse42, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
+ float, float>;
+
+namespace {
+
+inline void Pack8bitSse42Packer(const std::int8_t* src_ptr,
+ std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr,
+ std::int8_t* trailing_buf) {
+ using Layout = PackImpl8bitSse42::Layout;
+ RUY_DCHECK_EQ(Layout::kCols, 8);
+ RUY_DCHECK_EQ(Layout::kRows, 4);
+  // Each chunk of Layout::kRows (4) elements is contiguous in both the input
+  // and the packed buffer. We process 8 of these chunks at a time, padding
+  // short input chunks.
+ constexpr int kNumRowChunks = 8;
+ constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
+
+ std::int8_t in_data[Layout::kCols][kNumRowChunks][Layout::kRows];
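+  // in_data is indexed as [column][row-chunk][row-within-chunk]; the loops
+  // below transpose it into the column-major 4x8 packed block layout.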
+
+ const std::int8_t* src_ptr0 = src_ptr;
+ const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
+ const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
+ const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
+ const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
+ const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
+ const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
+ const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
+ std::int64_t src_inc0 = kNumChunkedSrcRows;
+ std::int64_t src_inc1 = kNumChunkedSrcRows;
+ std::int64_t src_inc2 = kNumChunkedSrcRows;
+ std::int64_t src_inc3 = kNumChunkedSrcRows;
+ std::int64_t src_inc4 = kNumChunkedSrcRows;
+ std::int64_t src_inc5 = kNumChunkedSrcRows;
+ std::int64_t src_inc6 = kNumChunkedSrcRows;
+ std::int64_t src_inc7 = kNumChunkedSrcRows;
+ // Handle cases where source does not have Layout::kCols (8) columns.
+ if (remaining_src_cols < 8) {
+ if (remaining_src_cols <= 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (remaining_src_cols <= 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (remaining_src_cols <= 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (remaining_src_cols <= 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ if (remaining_src_cols <= 4) {
+ src_ptr4 = zerobuf;
+ src_inc4 = 0;
+ }
+ if (remaining_src_cols <= 5) {
+ src_ptr5 = zerobuf;
+ src_inc5 = 0;
+ }
+ if (remaining_src_cols <= 6) {
+ src_ptr6 = zerobuf;
+ src_inc6 = 0;
+ }
+ src_ptr7 = zerobuf;
+ src_inc7 = 0;
+ }
+
+ const std::int8_t zero_point = zerobuf[0];
+
+ if (sums_ptr) {
+ // i: Layout::kCols.
+ for (int i = 0; i < 8; ++i) {
+ sums_ptr[i] = 0;
+ }
+ }
+
+  // The overall packing effectively pads the source rows to
+  // (src_rows + 31) & ~31. When there is an incomplete destination block, it
+  // is stored into trailing_buf instead of packed_ptr.
+ for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) {
+    // Available source rows.
+    const int available_src_rows = src_rows - k;
+    // Effectively,
+    // available_src_rows =
+    //     std::max(0, std::min(kNumChunkedSrcRows, src_rows - k));
+    // but treat each case separately.
+ if (available_src_rows >= kNumChunkedSrcRows) {
+ // i: chunks, s: Layout::Rows.
+ for (int i = 0; i < 8; ++i) {
+ for (int s = 0; s < 4; ++s) {
+ in_data[0][i][s] = src_ptr0[i * 4 + s];
+ in_data[1][i][s] = src_ptr1[i * 4 + s];
+ in_data[2][i][s] = src_ptr2[i * 4 + s];
+ in_data[3][i][s] = src_ptr3[i * 4 + s];
+ in_data[4][i][s] = src_ptr4[i * 4 + s];
+ in_data[5][i][s] = src_ptr5[i * 4 + s];
+ in_data[6][i][s] = src_ptr6[i * 4 + s];
+ in_data[7][i][s] = src_ptr7[i * 4 + s];
+ }
+ }
+ // i: chunks, j: Layout::kCols, s: Layout::Rows.
+ for (int i = 0; i < 8; ++i) {
+ for (int j = 0; j < 8; ++j) {
+ for (int s = 0; s < 4; ++s) {
+ // 8 * 4 * i is offset for each block, that is
+ // (Layout::kCols * Layout::kRows * i)
+ packed_ptr[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor;
+ }
+ if (sums_ptr) {
+ for (int s = 0; s < 4; ++s) {
+ sums_ptr[j] += in_data[j][i][s] ^ input_xor;
+ }
+ }
+ }
+ }
+ } else if (available_src_rows > 0) {
+ RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows);
+ int i = 0;
+ // Consume chunks of 4 rows that are complete.
+ for (; i < (available_src_rows >> 2); ++i) {
+ for (int s = 0; s < 4; ++s) {
+ in_data[0][i][s] = src_ptr0[i * 4 + s];
+ in_data[1][i][s] = src_ptr1[i * 4 + s];
+ in_data[2][i][s] = src_ptr2[i * 4 + s];
+ in_data[3][i][s] = src_ptr3[i * 4 + s];
+ in_data[4][i][s] = src_ptr4[i * 4 + s];
+ in_data[5][i][s] = src_ptr5[i * 4 + s];
+ in_data[6][i][s] = src_ptr6[i * 4 + s];
+ in_data[7][i][s] = src_ptr7[i * 4 + s];
+ }
+ }
+ // Consume any incomplete chunk.
+ if (i < ((available_src_rows + 3) >> 2)) {
+ int s = 0;
+ for (; s < (available_src_rows & 3); ++s) {
+ in_data[0][i][s] = src_ptr0[i * 4 + s];
+ in_data[1][i][s] = src_ptr1[i * 4 + s];
+ in_data[2][i][s] = src_ptr2[i * 4 + s];
+ in_data[3][i][s] = src_ptr3[i * 4 + s];
+ in_data[4][i][s] = src_ptr4[i * 4 + s];
+ in_data[5][i][s] = src_ptr5[i * 4 + s];
+ in_data[6][i][s] = src_ptr6[i * 4 + s];
+ in_data[7][i][s] = src_ptr7[i * 4 + s];
+ }
+ RUY_DCHECK_LE(s, 4);
+ for (; s < 4; ++s) {
+ // j: Layout::kCols.
+ for (int j = 0; j < 8; ++j) {
+ in_data[j][i][s] = zero_point;
+ }
+ }
+ ++i;
+ }
+ // We do not care what goes into the trailing buffer, but we want
+ // in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
+ //
+ // It might prove better in optimized code to pad uniformly with
+ // zero_point, and compensate by initializing the summations with the
+ // compensating offset, effectively
+ // ((input_xor - zero_point) ^ input_xor) *
+ // 4 * (8 - ((available_src_rows + 3) >> 2)).
+ for (; i < 8; ++i) {
+ for (int s = 0; s < 4; ++s) {
+ for (int j = 0; j < 8; ++j) {
+ in_data[j][i][s] = input_xor;
+ }
+ }
+ }
+ // We loop through [0, 8) rather than
+ // [0, (available_src_rows + 3) >> 2), since that emulates what we might
+ // do in fully-optimized code.
+ //
+ // i: chunks, j: Layout::kCols, s: Layout::Rows.
+ if (sums_ptr) {
+ for (int i = 0; i < 8; ++i) {
+ for (int j = 0; j < 8; ++j) {
+ for (int s = 0; s < 4; ++s) {
+ trailing_buf[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor;
+ sums_ptr[j] = sums_ptr[j] + (in_data[j][i][s] ^ input_xor);
+ }
+ }
+ }
+ } else {
+ for (int i = 0; i < 8; ++i) {
+ for (int j = 0; j < 8; ++j) {
+ for (int s = 0; s < 4; ++s) {
+ trailing_buf[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor;
+ }
+ }
+ }
+ }
+ }
+
+ packed_ptr += 8 * kNumChunkedSrcRows;
+ src_ptr0 += src_inc0;
+ src_ptr1 += src_inc1;
+ src_ptr2 += src_inc2;
+ src_ptr3 += src_inc3;
+ src_ptr4 += src_inc4;
+ src_ptr5 += src_inc5;
+ src_ptr6 += src_inc6;
+ src_ptr7 += src_inc7;
+ }
+}
+
+inline void PackFloatSse42Packer(const float* src_ptr, const float* zerobuf,
+ int src_stride, int remaining_src_cols,
+ int src_rows, float* packed_ptr,
+ float* trailing_buf) {
+ using Layout = PackImplFloatSse42::Layout;
+ RUY_DCHECK_EQ(Layout::kCols, 8);
+ RUY_DCHECK_EQ(Layout::kRows, 1);
+
+  // This packing amounts to transposition of 8x8 blocks.
+ static constexpr int kPackCols = 8; // Source cols packed together.
+ static constexpr int kPackRows = 8; // Short input is padded.
+
+ float in_data[kPackCols][kPackRows];
+
+ const float* src_ptr0 = src_ptr;
+ const float* src_ptr1 = src_ptr0 + src_stride;
+ const float* src_ptr2 = src_ptr1 + src_stride;
+ const float* src_ptr3 = src_ptr2 + src_stride;
+ const float* src_ptr4 = src_ptr3 + src_stride;
+ const float* src_ptr5 = src_ptr4 + src_stride;
+ const float* src_ptr6 = src_ptr5 + src_stride;
+ const float* src_ptr7 = src_ptr6 + src_stride;
+ std::int64_t src_inc0 = 8;
+ std::int64_t src_inc1 = 8;
+ std::int64_t src_inc2 = 8;
+ std::int64_t src_inc3 = 8;
+ std::int64_t src_inc4 = 8;
+ std::int64_t src_inc5 = 8;
+ std::int64_t src_inc6 = 8;
+ std::int64_t src_inc7 = 8;
+  // Handle cases where source does not have kPackCols (8) columns.
+ if (remaining_src_cols < kPackCols) {
+ if (remaining_src_cols <= 0) {
+ src_ptr0 = zerobuf;
+ src_inc0 = 0;
+ }
+ if (remaining_src_cols <= 1) {
+ src_ptr1 = zerobuf;
+ src_inc1 = 0;
+ }
+ if (remaining_src_cols <= 2) {
+ src_ptr2 = zerobuf;
+ src_inc2 = 0;
+ }
+ if (remaining_src_cols <= 3) {
+ src_ptr3 = zerobuf;
+ src_inc3 = 0;
+ }
+ if (remaining_src_cols <= 4) {
+ src_ptr4 = zerobuf;
+ src_inc4 = 0;
+ }
+ if (remaining_src_cols <= 5) {
+ src_ptr5 = zerobuf;
+ src_inc5 = 0;
+ }
+ if (remaining_src_cols <= 6) {
+ src_ptr6 = zerobuf;
+ src_inc6 = 0;
+ }
+ src_ptr7 = zerobuf;
+ src_inc7 = 0;
+ }
+
+ for (int k = 0; k < src_rows; k += kPackRows) {
+ const int available_src_rows = src_rows - k;
+ // Effectively,
+    // available_src_rows = std::max(0, std::min(kPackRows, src_rows - k));
+ // but treat each case separately.
+ if (available_src_rows >= kPackRows) {
+ for (int i = 0; i < 8; ++i) {
+ in_data[0][i] = src_ptr0[i];
+ in_data[1][i] = src_ptr1[i];
+ in_data[2][i] = src_ptr2[i];
+ in_data[3][i] = src_ptr3[i];
+ in_data[4][i] = src_ptr4[i];
+ in_data[5][i] = src_ptr5[i];
+ in_data[6][i] = src_ptr6[i];
+ in_data[7][i] = src_ptr7[i];
+ }
+ for (int i = 0; i < 8; ++i) {
+ for (int j = 0; j < 8; ++j) {
+ packed_ptr[8 * i + j] = in_data[j][i];
+ }
+ }
+ } else if (available_src_rows > 0) {
+ for (int i = 0; i < available_src_rows; ++i) {
+ in_data[0][i] = src_ptr0[i];
+ in_data[1][i] = src_ptr1[i];
+ in_data[2][i] = src_ptr2[i];
+ in_data[3][i] = src_ptr3[i];
+ in_data[4][i] = src_ptr4[i];
+ in_data[5][i] = src_ptr5[i];
+ in_data[6][i] = src_ptr6[i];
+ in_data[7][i] = src_ptr7[i];
+ }
+ for (int i = available_src_rows; i < kPackRows; ++i) {
+ in_data[0][i] = 0.0f;
+ in_data[1][i] = 0.0f;
+ in_data[2][i] = 0.0f;
+ in_data[3][i] = 0.0f;
+ in_data[4][i] = 0.0f;
+ in_data[5][i] = 0.0f;
+ in_data[6][i] = 0.0f;
+ in_data[7][i] = 0.0f;
+ }
+      // We loop through [0, 7) rather than [0, available_src_rows), since that
+ // emulates what we might do in fully-optimized code.
+ // i: (kPackRows - 1), j: kPackCols.
+ for (int i = 0; i < 7; ++i) {
+ for (int j = 0; j < 8; ++j) {
+ trailing_buf[kPackRows * i + j] = in_data[j][i];
+ }
+ }
+ }
+
+ packed_ptr += kPackRows * kPackCols;
+ src_ptr0 += src_inc0;
+ src_ptr1 += src_inc1;
+ src_ptr2 += src_inc2;
+ src_ptr3 += src_inc3;
+ src_ptr4 += src_inc4;
+ src_ptr5 += src_inc5;
+ src_ptr6 += src_inc6;
+ src_ptr7 += src_inc7;
+ }
+}
+
+} // namespace.
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+// When removing this comment, update profiling label below.
+void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
+ profiler::ScopeLabel label("Pack kSse42 8bit (UNFINISHED)");
+
+ using Layout = PackImpl8bitSse42::Layout;
+ RUY_DCHECK_EQ(Layout::kCols, 8);
+ RUY_DCHECK_EQ(Layout::kRows, 4);
+
+  // Each chunk of Layout::kRows (4) elements is contiguous in both the input
+  // and the packed buffer. We process 8 of these chunks at a time, padding
+  // short input chunks.
+ static constexpr int kNumRowChunks = 8; // Short input is padded.
+
+ // Each packed block is 4*8, and there are normally 8. The trailing block is
+ // only slightly shorter.
+ constexpr int kTrailingBufSize =
+ kNumRowChunks * Layout::kCols * Layout::kRows;
+ std::int8_t trailing_buf[kTrailingBufSize];
+ memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));
+
+ Pack8bitSse42Packer(src_ptr, input_xor, zerobuf, src_stride,
+ remaining_src_cols, src_rows, packed_ptr, sums_ptr,
+ trailing_buf);
+
+ constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
+ const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
+  // If the number of source rows is not a multiple of kNumRowChunks *
+  // Layout::kRows (32), there is data in the trailing buffer that still needs
+  // to be copied into the packed destination.
+ if (trailing_data > 0) {
+ const int non_trailing_rows = src_rows & ~kChunkedRowMask;
+ // Destination "rows" are padded to next highest multiple of Layout::kRows.
+ const int dst_rows = (src_rows + 3) & ~3;
+ const int trailing_rows = dst_rows - non_trailing_rows;
+ memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
+ Layout::kCols * trailing_rows * sizeof(std::int8_t));
+ }
+}
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+// When removing this comment, update profiling label below.
+void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows, float* packed_ptr) {
+ profiler::ScopeLabel label("Pack kSse42 float (UNFINISHED)");
+ static constexpr int kPackCols = 8; // Source cols packed together.
+ static constexpr int kPackRows = 8; // Short input is padded.
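+  // At most kPackRows - 1 (7) trailing rows need buffering: a complete chunk
+  // of 8 rows is written directly to packed_ptr.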
+ float trailing_buf[(kPackRows - 1) * kPackCols];
+ if (remaining_src_cols < 8) {
+ memset(trailing_buf, 0, sizeof(trailing_buf));
+ }
+ PackFloatSse42Packer(src_ptr, zerobuf, src_stride, remaining_src_cols,
+ src_rows, packed_ptr, trailing_buf);
+
+ const int trailing_rows = src_rows & (kPackRows - 1);
+ if (trailing_rows > 0) {
+ const int non_trailing_rows = src_rows & ~(kPackRows - 1);
+ memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf,
+ kPackCols * trailing_rows * sizeof(float));
+ }
+}
+
+#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+} // namespace ruy
diff --git a/ruy/pack_x86.h b/ruy/pack_x86.h
new file mode 100644
index 0000000..b777cc1
--- /dev/null
+++ b/ruy/pack_x86.h
@@ -0,0 +1,461 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// # What is "packing"?
+//
+// Before feeding data to the gemm kernels (the parts of Ruy that do lots
+// of multiply-add operations), Ruy first performs a data transformation (which
+// we call "packing") on the input matrices. This transformation has two main
+// goals:
+// - rearrange data into blocks that are a convenient size/layout for the gemm
+// kernels to consume. This helps make the memory access pattern of the gemm
+// kernel simpler and more contiguous, and puts the data in a layout most
+// convenient for specific arithmetic instructions in the gemm kernel.
+// - compute row/column sums needed for handling quantization with non-symmetric
+// zero points.
+//
+// # Simplified algorithmic analysis of packing
+//
+// Packing is a relatively simple transformation which does a small constant
+// amount of work on each element of an input matrix, and hence for an NxM
+// matrix performs O(N*M) work. If N and M are of the same order, then this is
+// O(N^2) work.
+//
+// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations.
+// Note that if N, K, and M are all the same order, then the number of
+// multiply-accumulate operations is O(N^3).
+//
+// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the
+// case of all dimensions being roughly the same order.
+//
+// # Packing cost can be significant
+//
+// When matrix * matrix multiplications begin to look more like matrix * vector
+// multiplications, packing cost can become significant. We sometimes call these
+// cases "gemv-like".
+//
+// Continuing the algorithmic analysis above, if we consider a case where an
+// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the
+// situation is different. In this case, the multiply-accumulate work is only
+// quadratic, so the quadratic cost of packing can become significant.
+//
+// Another way to say this is that the cost of packing an input matrix (either
+// the LHS or RHS) is amortized across the non-depth dimension of the opposite
+// input matrix. Thus, when the LHS has very few rows or the RHS has very few
+// columns, the cost of packing the opposite input matrix can become
+// significant.
+//
+// As a rough rule of thumb, the cost of packing starts to become significant
+// when either N or M is below 32 (and other dimensions are hundreds), with very
+// significant packing costs at 8 or below. This varies by data type, Path, and
+// tuning, so these numbers are only rough guides.
+//
+// One practical use case that is affected by this is inference of
+// fully connected neural network layers with a low batch size. The weight
+// matrix (which is a constant for inference) is the one affected by significant
+// packing cost.
+//
+// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack
+// input matrices that are affected by significant packing costs.
+//
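+// As a concrete, hypothetical illustration of the above: with N = 8 and
+// K = M = 512, the multiply-accumulate work is N*K*M = 8*512*512, roughly 2.1M
+// operations, while packing the 512x512 input touches K*M, roughly 262K
+// elements. That is over 10% of the arithmetic work, so packing is no longer a
+// negligible fraction of the total cost.
+//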
+// # Implementation notes
+//
+// Ruy's packing routines always operate on a range of columns and can be
+// applied to either the LHS or RHS. This is possible because Ruy internally
+// implements a TrMul, so the accumulation along depth is done along columns of
+// both the LHS and RHS (whereas for a normal Mul the accumulation along depth
+// for the LHS is along rows). As another example, we are always computing
+// column sums for quantization (and never row sums, since the LHS is
+// transposed).
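+//
+// For instance (hypothetical shapes): a Mul of a 3x4 LHS by a 4x2 RHS becomes,
+// internally, a TrMul whose LHS operand is the 4x3 transposed LHS, so the
+// depth-4 accumulation runs down the columns of both the 4x3 operand and the
+// 4x2 RHS.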
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_X86_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_X86_H_
+
+#include <cstdint>
+#include <cstring>
+#include <type_traits>
+
+#include "ruy/check_macros.h"
+#include "ruy/common.h"
+#include "ruy/internal_matrix.h"
+#include "ruy/matrix.h"
+#include "ruy/opt_set.h"
+#include "ruy/pack_common.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(X86)
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+// Note that source and zero buffers can be uint8 type, but in the packing
+// function are reinterpreted as int8, and are XOR-ed with input_xor.
+void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr);
+
+template <typename Scalar>
+struct PackImpl<Path::kSse42, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar,
+ std::int8_t, std::int32_t> {
+ static_assert(std::is_same<Scalar, std::int8_t>::value ||
+ std::is_same<Scalar, std::uint8_t>::value,
+ "");
+ using Layout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+ static constexpr std::int8_t kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+
+ static void Run(Tuning tuning, const Matrix<Scalar>& src_matrix,
+ PackedMatrix<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (SSE 4.2 8-bit)");
+
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
+ RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
+ std::int32_t* sums = packed_matrix->sums;
+ Scalar zerobuf[Layout::kCols * Layout::kRows];
+ memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
+ Layout::kCols * Layout::kRows * sizeof(Scalar));
+ for (int block_col = start_col; block_col < end_col;
+ block_col += Layout::kCols) {
+ std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
+ int src_stride = src_matrix.layout.stride;
+ const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
+ int remaining_src_cols = src_matrix.layout.cols - block_col;
+
+ static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits.
+ std::int8_t* packed_ptr =
+ packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & block_col_mask);
+ Pack8bitSse42(reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
+ reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
+ remaining_src_cols, src_matrix.layout.rows, packed_ptr,
+ sums_ptr);
+ }
+ }
+};
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows, float* packed_ptr);
+
+template <>
+struct PackImpl<Path::kSse42, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
+ float, float> {
+ using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ static void Run(Tuning, const Matrix<float>& src_matrix,
+ PackedMatrix<float>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (SSE 4.2 float)");
+
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
+ RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
+ const float zerobuf[Layout::kCols] = {
+ 0.0f}; // Remainder default inits to 0.0f.
+ for (int block_col = start_col; block_col < end_col;
+ block_col += Layout::kCols) {
+ int src_stride = src_matrix.layout.stride;
+ const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
+ int remaining_src_cols = src_matrix.layout.cols - block_col;
+
+ static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits.
+ float* packed_ptr =
+ packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & block_col_mask);
+ PackFloatSse42(src_ptr, zerobuf, src_stride, remaining_src_cols,
+ src_matrix.layout.rows, packed_ptr);
+ }
+ }
+};
+
+// Note that source and zero buffers can be uint8 type, but in the packing
+// function are reinterpreted as int8, and are XOR-ed with input_xor.
+void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows, std::int8_t* packed_ptr,
+ std::int32_t* sums_ptr);
+
+template <typename Scalar>
+struct PackImpl<Path::kAvx2, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar,
+ std::int8_t, std::int32_t> {
+ static_assert(std::is_same<Scalar, std::int8_t>::value ||
+ std::is_same<Scalar, std::uint8_t>::value,
+ "");
+ using Layout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+ static constexpr std::int8_t kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+
+ static void Run(Tuning tuning, const Matrix<Scalar>& src_matrix,
+ PackedMatrix<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (AVX2 8-bit)");
+
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
+ RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
+ std::int32_t* sums = packed_matrix->sums;
+ Scalar zerobuf[Layout::kCols * Layout::kRows];
+ memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
+ Layout::kCols * Layout::kRows * sizeof(Scalar));
+ for (int block_col = start_col; block_col < end_col;
+ block_col += Layout::kCols) {
+ std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
+ int src_stride = src_matrix.layout.stride;
+ const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
+ int remaining_src_cols = src_matrix.layout.cols - block_col;
+
+ static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits.
+ std::int8_t* packed_ptr =
+ packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & block_col_mask);
+ Pack8bitAvx2(reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
+ reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
+ remaining_src_cols, src_matrix.layout.rows, packed_ptr,
+ sums_ptr);
+ }
+ }
+};
+
+void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows, float* packed_ptr);
+
+template <>
+struct PackImpl<Path::kAvx2, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
+ float, float> {
+ using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+ static void Run(Tuning, const Matrix<float>& src_matrix,
+ PackedMatrix<float>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (AVX2 float)");
+
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
+ RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
+ const float zerobuf[Layout::kCols] = {
+ 0.0f}; // Remainder default inits to 0.0f.
+ for (int block_col = start_col; block_col < end_col;
+ block_col += Layout::kCols) {
+ int src_stride = src_matrix.layout.stride;
+ const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
+ int remaining_src_cols = src_matrix.layout.cols - block_col;
+
+ static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits.
+ float* packed_ptr =
+ packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & block_col_mask);
+ PackFloatAvx2(src_ptr, zerobuf, src_stride, remaining_src_cols,
+ src_matrix.layout.rows, packed_ptr);
+ }
+ }
+};
+
+// Note that source and zero buffers can be uint8 type, but in the packing
+// function are reinterpreted as int8, and are XOR-ed with input_xor.
+void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr);
+
+template <typename Scalar>
+struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
+ Scalar, std::int8_t, std::int32_t> {
+ static_assert(std::is_same<Scalar, std::int8_t>::value ||
+ std::is_same<Scalar, std::uint8_t>::value,
+ "");
+ using Layout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+ static constexpr int kHalfLayoutCols =
+ 8; // Half the number of cols in a block.
+ static constexpr std::int8_t kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+
+ static void Run(Tuning tuning, const Matrix<Scalar>& src_matrix,
+ PackedMatrix<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (AVX-512 8-bit)");
+
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
+ RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
+ RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols);
+ std::int32_t* sums = packed_matrix->sums;
+ Scalar zerobuf[kHalfLayoutCols * Layout::kRows];
+ memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
+ kHalfLayoutCols * Layout::kRows * sizeof(Scalar));
+ for (int block_col = start_col; block_col < end_col;
+ block_col += Layout::kCols) {
+ std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
+ int src_stride = src_matrix.layout.stride;
+ const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
+ int remaining_src_cols = src_matrix.layout.cols - block_col;
+
+ static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits.
+ std::int8_t* packed_ptr =
+ packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & block_col_mask);
+ Pack8bitAvx512(reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
+ reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
+ remaining_src_cols, src_matrix.layout.rows, packed_ptr,
+ sums_ptr);
+ }
+ }
+};
+
+void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows, float* packed_ptr);
+
+template <>
+struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kRowMajor, 1, 16>,
+ float, float, float> {
+ static void Run(Tuning, const Matrix<float>& src_matrix,
+ PackedMatrix<float>* packed_matrix, int start_col,
+ int end_col) {
+ profiler::ScopeLabel label("Pack (AVX-512 float)");
+ using Layout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
+ RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
+ const float zerobuf[Layout::kCols] = {
+ 0.0f}; // Remainder default inits to 0.0f.
+ for (int block_col = start_col; block_col < end_col;
+ block_col += Layout::kCols) {
+ int src_stride = src_matrix.layout.stride;
+ const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
+ int remaining_src_cols = src_matrix.layout.cols - block_col;
+
+ static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits.
+ float* packed_ptr =
+ packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & block_col_mask);
+ PackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
+ src_matrix.layout.rows, packed_ptr);
+ }
+ }
+};
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+// Note that source and zero buffers can be uint8 type, but in the packing
+// function are reinterpreted as int8, and are XOR-ed with input_xor.
+void Pack8bitAvxVnni(const std::int8_t* src_ptr, std::int8_t input_xor,
+ const std::int8_t* zerobuf, int src_stride,
+ int remaining_src_cols, int src_rows,
+ std::int8_t* packed_ptr, std::int32_t* sums_ptr);
+
+template <typename Scalar>
+struct PackImpl<Path::kAvxVnni, FixedKernelLayout<Order::kColMajor, 4, 16>,
+ Scalar, std::int8_t, std::int32_t> {
+ static_assert(std::is_same<Scalar, std::int8_t>::value ||
+ std::is_same<Scalar, std::uint8_t>::value,
+ "");
+ using Layout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+ static constexpr int kHalfLayoutCols =
+ 8; // Half the number of cols in a block.
+ static constexpr std::int8_t kInputXor =
+ std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
+
+ static void Run(Tuning tuning, const Matrix<Scalar>& src_matrix,
+ PackedMatrix<std::int8_t>* packed_matrix, int start_col,
+ int end_col) {
+    profiler::ScopeLabel label("Pack (AVX-VNNI 8-bit)");
+
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
+ RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
+ RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols);
+ std::int32_t* sums = packed_matrix->sums;
+ Scalar zerobuf[kHalfLayoutCols * Layout::kRows];
+ memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
+ kHalfLayoutCols * Layout::kRows * sizeof(Scalar));
+ for (int block_col = start_col; block_col < end_col;
+ block_col += Layout::kCols) {
+ std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
+ int src_stride = src_matrix.layout.stride;
+ const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
+ int remaining_src_cols = src_matrix.layout.cols - block_col;
+
+ static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits.
+ std::int8_t* packed_ptr =
+ packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & block_col_mask);
+ Pack8bitAvxVnni(reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
+ reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
+ remaining_src_cols, src_matrix.layout.rows, packed_ptr,
+ sums_ptr);
+ }
+ }
+};
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf,
+ int src_stride, int remaining_src_cols, int src_rows,
+ float* packed_ptr);
+
+template <>
+struct PackImpl<Path::kAvxVnni, FixedKernelLayout<Order::kRowMajor, 1, 16>,
+ float, float, float> {
+ static void Run(Tuning, const Matrix<float>& src_matrix,
+ PackedMatrix<float>* packed_matrix, int start_col,
+ int end_col) {
+    profiler::ScopeLabel label("Pack (AVX-VNNI float)");
+
+ using Layout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
+ RUY_DCHECK(IsColMajor(src_matrix.layout));
+ RUY_DCHECK(IsColMajor(packed_matrix->layout));
+ RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
+ RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
+ const float zerobuf[Layout::kCols] = {
+ 0.0f}; // Remainder default inits to 0.0f.
+ for (int block_col = start_col; block_col < end_col;
+ block_col += Layout::kCols) {
+ int src_stride = src_matrix.layout.stride;
+ const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
+ int remaining_src_cols = src_matrix.layout.cols - block_col;
+
+ static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits.
+ float* packed_ptr =
+ packed_matrix->data +
+ packed_matrix->layout.stride * (block_col & block_col_mask);
+ PackFloatAvxVnni(src_ptr, zerobuf, src_stride, remaining_src_cols,
+ src_matrix.layout.rows, packed_ptr);
+ }
+ }
+};
+#endif // RUY_PLATFORM(X86)
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_X86_H_
diff --git a/ruy/path.h b/ruy/path.h
new file mode 100644
index 0000000..7141b16
--- /dev/null
+++ b/ruy/path.h
@@ -0,0 +1,162 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PATH_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PATH_H_
+
+#include <cstdint>
+
+#include "ruy/platform.h"
+#include "ruy/size_util.h"
+
+namespace ruy {
+
+// A Path is a choice of implementation path, e.g. between reference code
+// and optimized code, or between different optimized code paths using different
+// instruction sets.
+//
+// It's important that any symbol that depends on such implementation details
+// is templatized on such a Path, so that different Path values yield different
+// symbols and we never have the situation where a symbol has multiple
+// inequivalent definitions based on which code paths are compiled.
+// That would be a violation of the ODR (One Definition Rule) which is Undefined
+// Behavior, and one of the most serious issues plaguing both Eigen and
+// gemmlowp.
+//
+// This enum is actually a bit-field: aside from kNone, all other values are
+// powers of two, thus are one bit each. We define bit-wise operators below
+// for this enum. Some places in Ruy accept a Path bit-field where multiple
+// Paths may be selected, while some other places require a single Path (i.e.
+// just one of the enum values here). Typically, user-facing parts of Ruy
+// accept arbitrary bit-fields, allowing the user to compile support for
+// multiple paths and to inform Ruy of all the paths that are to be enabled
+// at runtime; then, typically in dispatch.h, we internally pick one
+// specific path and from there on, internal Ruy code deals with only one
+// path.
+//
+// When a user selects a set of compiled paths, Ruy internally dispatches to the
+// "best" one, which typically means the newest optimized instructions for a
+// given base architecture (such as ARM). Higher values of this enum correspond
+// to "better" code paths within a given base architecture for which Ruy has
+// optimized code paths.
+//
+// Values are reused across architectures.
+// Rationale: to scale better to N architectures, it is good to have small
+// values, both for the compile-time logic that selects paths and when manually
+// spelling out Path values, such as when invoking a test or benchmark.
+enum class Path : std::uint8_t {
+ // This is a special null value, representing the absence of any path.
+ kNone = 0,
+ // Reference multiplication code.
+ // The main purpose of this path is to have a very simple standalone Mul
+ // implementation to check against.
+ // This path bypasses almost all of Ruy's internal implementation details.
+ //
+ // This is intended for testing/development.
+ kReference = 0x1,
+ // Standard C++ implementation of Ruy's architecture-specific parts.
+ // Unlike Path::kReference, this path exercises most of Ruy's internal logic.
+ //
+ // This is intended for testing/development.
+ kStandardCpp = 0x2,
+
+#if RUY_PLATFORM(ARM)
+ // ARM architectures.
+ //
+ // Optimized path using a widely available subset of ARM NEON instructions.
+ kNeon = 0x4,
+ // Optimized path making use of ARM NEON dot product instructions that are
+ // available on newer ARM cores.
+ kNeonDotprod = 0x8,
+#endif // RUY_PLATFORM(ARM)
+
+#if RUY_PLATFORM(X86)
+ // x86 architectures.
+ //
+ // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete /
+ // placeholder.
+ // Optimization is not finished. In particular the dimensions of the kernel
+ // blocks can be changed as desired.
+ //
+ // Optimized for SSE 4.2.
+ kSse42 = 0x4,
+ // Optimized for AVX2.
+ kAvx2 = 0x8,
+ // Optimized for AVX-512.
+ kAvx512 = 0x10,
+ // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete /
+ // placeholder.
+ // Optimization is not finished. In particular the dimensions of the kernel
+ // blocks can be changed as desired.
+ //
+ // Optimized for AVX-VNNI.
+ kAvxVnni = 0x20,
+#endif // RUY_PLATFORM(X86)
+};
+
+inline constexpr Path operator|(Path p, Path q) {
+ return static_cast<Path>(static_cast<std::uint32_t>(p) |
+ static_cast<std::uint32_t>(q));
+}
+
+inline constexpr Path operator&(Path p, Path q) {
+ return static_cast<Path>(static_cast<std::uint32_t>(p) &
+ static_cast<std::uint32_t>(q));
+}
+
+inline constexpr Path operator^(Path p, Path q) {
+ return static_cast<Path>(static_cast<std::uint32_t>(p) ^
+ static_cast<std::uint32_t>(q));
+}
+
+inline constexpr Path operator~(Path p) {
+ return static_cast<Path>(~static_cast<std::uint32_t>(p));
+}
+
+inline Path GetMostSignificantPath(Path path_mask) {
+ return static_cast<Path>(round_down_pot(static_cast<int>(path_mask)));
+}
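+
+// Illustrative sketch (not part of the API; variable names are hypothetical)
+// of how the operators above and GetMostSignificantPath compose:
+//
+//   constexpr Path compiled = Path::kReference | Path::kStandardCpp;
+//   constexpr bool has_std = (compiled & Path::kStandardCpp) != Path::kNone;
+//   Path best = GetMostSignificantPath(compiled);  // Path::kStandardCpp.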
+
+// ruy::kAllPaths represents all Paths that make sense on a given
+// base architecture.
+#ifdef __linux__
+#if RUY_PLATFORM(NEON_64)
+constexpr Path kAllPaths =
+ Path::kReference | Path::kStandardCpp | Path::kNeon | Path::kNeonDotprod;
+#elif RUY_PLATFORM(NEON_32)
+constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | Path::kNeon;
+#elif RUY_PLATFORM(X86)
+constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp |
+ Path::kSse42 | Path::kAvx2 | Path::kAvx512 |
+ Path::kAvxVnni;
+#else
+constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp;
+#endif
+#else // __linux__
+// We don't know how to do runtime dotprod detection outside of linux for now.
+#if RUY_PLATFORM(NEON)
+constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | Path::kNeon;
+#elif RUY_PLATFORM(X86)
+constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp |
+ Path::kSse42 | Path::kAvx2 | Path::kAvx512 |
+ Path::kAvxVnni;
+#else
+constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp;
+#endif
+#endif // __linux__
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PATH_H_
diff --git a/ruy/platform.h b/ruy/platform.h
new file mode 100644
index 0000000..d6e86e6
--- /dev/null
+++ b/ruy/platform.h
@@ -0,0 +1,156 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PLATFORM_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PLATFORM_H_
+
+#ifdef __ANDROID_NDK__
+#include <android/ndk-version.h>
+#endif
+
+#define RUY_PLATFORM(X) ((RUY_DONOTUSEDIRECTLY_##X) != 0)
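+// Usage sketch: X names one of the RUY_DONOTUSEDIRECTLY_* detections defined
+// below, e.g.
+//   #if RUY_PLATFORM(X86)
+//   // x86-specific code
+//   #endif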
+
+// Architecture-level platform detection.
+//
+// Ruy requires these to be mutually exclusive.
+
+// Detect x86.
+#if defined(__x86_64__) || defined(__i386__) || defined(__i386) || \
+ defined(__x86__) || defined(__X86__) || defined(_X86_) || \
+ defined(_M_IX86) || defined(_M_X64)
+#define RUY_DONOTUSEDIRECTLY_X86 1
+#else
+#define RUY_DONOTUSEDIRECTLY_X86 0
+#endif
+
+// Detect ARM 32-bit.
+#ifdef __arm__
+#define RUY_DONOTUSEDIRECTLY_ARM_32 1
+#else
+#define RUY_DONOTUSEDIRECTLY_ARM_32 0
+#endif
+
+// Detect ARM 64-bit.
+#ifdef __aarch64__
+#define RUY_DONOTUSEDIRECTLY_ARM_64 1
+#else
+#define RUY_DONOTUSEDIRECTLY_ARM_64 0
+#endif
+
+// Combined ARM.
+#define RUY_DONOTUSEDIRECTLY_ARM \
+ (RUY_DONOTUSEDIRECTLY_ARM_64 || RUY_DONOTUSEDIRECTLY_ARM_32)
+
+// Feature and capability platform detection.
+//
+// These are mostly sub-selections of architectures.
+
+// Detect NEON. Explicitly avoid emulation, or anything like it, on x86.
+#if (defined(__ARM_NEON) || defined(__ARM_NEON__)) && !RUY_PLATFORM(X86)
+#define RUY_DONOTUSEDIRECTLY_NEON 1
+#else
+#define RUY_DONOTUSEDIRECTLY_NEON 0
+#endif
+
+// Define ARM 32-bit NEON.
+#define RUY_DONOTUSEDIRECTLY_NEON_32 \
+ (RUY_DONOTUSEDIRECTLY_NEON && RUY_DONOTUSEDIRECTLY_ARM_32)
+
+// Define ARM 64-bit NEON.
+// Note: NEON is implied by ARM64, so this define is redundant.
+// It still allows some conveyance of intent.
+#define RUY_DONOTUSEDIRECTLY_NEON_64 \
+ (RUY_DONOTUSEDIRECTLY_NEON && RUY_DONOTUSEDIRECTLY_ARM_64)
+
+// Disable X86 enhancements on __APPLE__ because of b/138922878 (see comment
+// #8); we may only need to disable this on XCode <= 10.2.
+//
+// Disable when not using Clang-Linux, because too many user issues arise from
+// compilation variations.
+//
+// NOTE: Consider guarding by !defined(__APPLE__) when removing Linux-only
+// restriction.
+//
+// __EMSCRIPTEN__ is checked because the runtime Path resolution can use asm.
+//
+// The Android NDK logic excludes earlier and very broken versions of intrinsics
+// headers.
+#if defined(RUY_FORCE_ENABLE_X86_ENHANCEMENTS) || \
+ (defined(__clang__) && (__clang_major__ >= 8) && defined(__linux__) && \
+ !defined(__EMSCRIPTEN__) && \
+ (!defined(__ANDROID_NDK__) || \
+ (defined(__NDK_MAJOR__) && (__NDK_MAJOR__ >= 20))))
+#define RUY_DONOTUSEDIRECTLY_X86_ENHANCEMENTS 1
+#else
+#define RUY_DONOTUSEDIRECTLY_X86_ENHANCEMENTS 0
+#endif
+
+// These CPU capabilities will all be true when Skylake, etc, are enabled during
+// compilation.
+#if RUY_PLATFORM(X86_ENHANCEMENTS) && RUY_PLATFORM(X86) && \
+ defined(__AVX512F__) && defined(__AVX512DQ__) && defined(__AVX512CD__) && \
+ defined(__AVX512BW__) && defined(__AVX512VL__)
+#define RUY_DONOTUSEDIRECTLY_AVX512 1
+#else
+#define RUY_DONOTUSEDIRECTLY_AVX512 0
+#endif
+
+#if RUY_PLATFORM(X86_ENHANCEMENTS) && RUY_PLATFORM(X86) && defined(__AVX2__)
+#define RUY_DONOTUSEDIRECTLY_AVX2 1
+#else
+#define RUY_DONOTUSEDIRECTLY_AVX2 0
+#endif
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+// Note does not check for LZCNT or POPCNT.
+#if defined(RUY_ENABLE_SSE_ENHANCEMENTS) && RUY_PLATFORM(X86_ENHANCEMENTS) && \
+ RUY_PLATFORM(X86) && defined(__SSE4_2__) && defined(__FMA__)
+#define RUY_DONOTUSEDIRECTLY_SSE42 1
+#else
+#define RUY_DONOTUSEDIRECTLY_SSE42 0
+#endif
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+// Note that defined(__AVX512VBMI2__) can be false for compilation with
+// -march=cascadelake.
+// TODO(b/146646451) Check if we should also gate on defined(__AVX512VBMI2__).
+#if defined(RUY_ENABLE_VNNI_ENHANCEMENTS) && RUY_PLATFORM(AVX512) && \
+ defined(__AVX512VNNI__)
+#define RUY_DONOTUSEDIRECTLY_AVX_VNNI 1
+#else
+#define RUY_DONOTUSEDIRECTLY_AVX_VNNI 0
+#endif
+
+// Detect APPLE.
+#ifdef __APPLE__
+#define RUY_DONOTUSEDIRECTLY_APPLE 1
+#else
+#define RUY_DONOTUSEDIRECTLY_APPLE 0
+#endif
+
+// Detect Emscripten, typically Wasm.
+#ifdef __EMSCRIPTEN__
+#define RUY_DONOTUSEDIRECTLY_EMSCRIPTEN 1
+#else
+#define RUY_DONOTUSEDIRECTLY_EMSCRIPTEN 0
+#endif
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PLATFORM_H_
diff --git a/ruy/pmu.cc b/ruy/pmu.cc
new file mode 100644
index 0000000..1d87b1f
--- /dev/null
+++ b/ruy/pmu.cc
@@ -0,0 +1,281 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/pmu.h"
+
+#include "ruy/check_macros.h"
+
+#ifdef __linux__
+#include <asm/unistd.h>
+#include <linux/perf_event.h>
+#include <sys/ioctl.h>
+#include <syscall.h>
+#include <unistd.h>
+
+#include <cstdio>
+#endif
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+
+namespace ruy {
+
+// Linux-specific. Not ARM-specific.
+#ifdef __linux__
+class PerfEvent {
+ public:
+ PerfEvent(std::uint32_t type, std::uint64_t config) {
+ perf_event_attr pe;
+ memset(&pe, 0, sizeof(pe));
+ pe.size = sizeof(pe);
+ pe.type = type;
+ pe.config = config;
+ pe.disabled = 1;
+ pe.exclude_kernel = 1;
+ pe.exclude_hv = 1;
+ pe.inherit = 1;
+ fd_ = syscall(__NR_perf_event_open, &pe, 0, -1, -1, 0);
+ if (fd_ == -1) {
+ fprintf(stderr, "perf_event_open failed for config 0x%lx\n",
+ static_cast<unsigned long>(config));
+ // abort();
+ }
+ }
+
+ ~PerfEvent() {
+ RUY_CHECK(!started_);
+ close(fd_);
+ }
+
+ void Start() {
+ RUY_CHECK(!started_);
+ started_ = true;
+ ioctl(fd_, PERF_EVENT_IOC_RESET, 0);
+ ioctl(fd_, PERF_EVENT_IOC_ENABLE, 0);
+ count_at_start_ = Read();
+ }
+
+ void Stop() {
+ RUY_CHECK(started_);
+ started_ = false;
+ ioctl(fd_, PERF_EVENT_IOC_DISABLE, 0);
+ count_at_stop_ = Read();
+ }
+
+ std::int64_t Count() const {
+ RUY_CHECK(!started_);
+ return count_at_stop_ - count_at_start_;
+ }
+
+ private:
+ std::int64_t Read() const {
+ std::int64_t count;
+ RUY_CHECK_NE(read(fd_, &count, sizeof(count)), -1);
+ return count;
+ }
+ std::int64_t count_at_start_ = -1;
+ std::int64_t count_at_stop_ = -1;
+ bool started_ = false;
+ int fd_ = -1;
+};
+#else
+// Placeholder implementation to at least compile outside of linux.
+#define PERF_TYPE_RAW 0
+class PerfEvent {
+ public:
+ PerfEvent(std::uint32_t, std::uint64_t) {}
+ ~PerfEvent() {}
+ void Start() {}
+ void Stop() {}
+ std::int64_t Count() const { return 0; }
+};
+#endif
+
+// ARM-specific. Query ARM PMU counters as Linux perf events using
+// PERF_TYPE_RAW.
+namespace arm_pmuv3 {
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-const-variable"
+
+// These event numbers are listed in the ARMv8 architecture reference manual.
+constexpr std::uint16_t L1I_CACHE_REFILL = 0x01;
+constexpr std::uint16_t L1I_TLB_REFILL = 0x02;
+constexpr std::uint16_t L1D_CACHE_REFILL = 0x03;
+constexpr std::uint16_t L1D_CACHE = 0x04;
+constexpr std::uint16_t L1D_TLB_REFILL = 0x05;
+constexpr std::uint16_t LD_RETIRED = 0x06;
+constexpr std::uint16_t ST_RETIRED = 0x07;
+constexpr std::uint16_t INST_RETIRED = 0x08;
+constexpr std::uint16_t EXC_TAKEN = 0x09;
+constexpr std::uint16_t EXC_RETURN = 0x0A;
+constexpr std::uint16_t CID_WRITE_RETIRED = 0x0B;
+constexpr std::uint16_t PC_WRITE_RETIRED = 0x0C;
+constexpr std::uint16_t BR_IMMED_RETIRED = 0x0D;
+constexpr std::uint16_t BR_RETURN_RETIRED = 0x0E;
+constexpr std::uint16_t UNALIGNED_LDST_RETIRED = 0x0F;
+constexpr std::uint16_t BR_MIS_PRED = 0x10;
+constexpr std::uint16_t CPU_CYCLES = 0x11;
+constexpr std::uint16_t BR_PRED = 0x12;
+constexpr std::uint16_t MEM_ACCESS = 0x13;
+constexpr std::uint16_t L1I_CACHE = 0x14;
+constexpr std::uint16_t L1D_CACHE_WB = 0x15;
+constexpr std::uint16_t L2D_CACHE = 0x16;
+constexpr std::uint16_t L2D_CACHE_REFILL = 0x17;
+constexpr std::uint16_t L2D_CACHE_WB = 0x18;
+constexpr std::uint16_t BUS_ACCESS = 0x19;
+constexpr std::uint16_t MEMORY_ERROR = 0x1A;
+constexpr std::uint16_t INST_SPEC = 0x1B;
+constexpr std::uint16_t TTBR_WRITE_RETIRED = 0x1C;
+constexpr std::uint16_t BUS_CYCLES = 0x1D;
+constexpr std::uint16_t CHAIN = 0x1E;
+constexpr std::uint16_t L1D_CACHE_ALLOCATE = 0x1F;
+constexpr std::uint16_t L2D_CACHE_ALLOCATE = 0x20;
+constexpr std::uint16_t BR_RETIRED = 0x21;
+constexpr std::uint16_t BR_MIS_PRED_RETIRED = 0x22;
+constexpr std::uint16_t STALL_FRONTEND = 0x23;
+constexpr std::uint16_t STALL_BACKEND = 0x24;
+constexpr std::uint16_t L1D_TLB = 0x25;
+constexpr std::uint16_t L1I_TLB = 0x26;
+constexpr std::uint16_t L2I_CACHE = 0x27;
+constexpr std::uint16_t L2I_CACHE_REFILL = 0x28;
+constexpr std::uint16_t L3D_CACHE_ALLOCATE = 0x29;
+constexpr std::uint16_t L3D_CACHE_REFILL = 0x2A;
+constexpr std::uint16_t L3D_CACHE = 0x2B;
+constexpr std::uint16_t L3D_CACHE_WB = 0x2C;
+constexpr std::uint16_t L2D_TLB_REFILL = 0x2D;
+constexpr std::uint16_t L2I_TLB_REFILL = 0x2E;
+constexpr std::uint16_t L2D_TLB = 0x2F;
+constexpr std::uint16_t L2I_TLB = 0x30;
+constexpr std::uint16_t LL_CACHE = 0x32;
+constexpr std::uint16_t LL_CACHE_MISS = 0x33;
+constexpr std::uint16_t DTLB_WALK = 0x34;
+constexpr std::uint16_t LL_CACHE_RD = 0x36;
+constexpr std::uint16_t LL_CACHE_MISS_RD = 0x37;
+
+// Additional implementation-defined events found by googling around.
+constexpr std::uint16_t L1D_CACHE_RD = 0x40;
+constexpr std::uint16_t L1D_CACHE_REFILL_RD = 0x42;
+constexpr std::uint16_t L1D_TLB_REFILL_RD = 0x4C;
+constexpr std::uint16_t L1D_TLB_RD = 0x4E;
+constexpr std::uint16_t L2D_CACHE_RD = 0x50;
+constexpr std::uint16_t L2D_CACHE_REFILL_RD = 0x52;
+constexpr std::uint16_t BUS_ACCESS_RD = 0x60;
+constexpr std::uint16_t MEM_ACCESS_RD = 0x66;
+constexpr std::uint16_t L3D_CACHE_RD = 0xA0;
+constexpr std::uint16_t L3D_CACHE_REFILL_RD = 0xA2;
+
+#pragma GCC diagnostic pop
+
+} // namespace arm_pmuv3
+
+class PmuEventsPrivate {
+ public:
+ PmuEventsPrivate()
+ : l1d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L1D_CACHE_REFILL),
+ l2d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L2D_CACHE_REFILL),
+ l3d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L3D_CACHE_REFILL),
+ ll_cache_miss(PERF_TYPE_RAW, arm_pmuv3::LL_CACHE_MISS),
+ l1d_tlb_refill(PERF_TYPE_RAW, arm_pmuv3::L1D_TLB_REFILL),
+ l2d_tlb_refill(PERF_TYPE_RAW, arm_pmuv3::L2D_TLB_REFILL),
+ stall_frontend(PERF_TYPE_RAW, arm_pmuv3::STALL_FRONTEND),
+ stall_backend(PERF_TYPE_RAW, arm_pmuv3::STALL_BACKEND),
+ br_mis_pred(PERF_TYPE_RAW, arm_pmuv3::BR_MIS_PRED) {}
+
+ private:
+ friend class PmuEvents;
+ PerfEvent l1d_cache_refill;
+ PerfEvent l2d_cache_refill;
+ PerfEvent l3d_cache_refill;
+ PerfEvent ll_cache_miss;
+ PerfEvent l1d_tlb_refill;
+ PerfEvent l2d_tlb_refill;
+ PerfEvent stall_frontend;
+ PerfEvent stall_backend;
+ PerfEvent br_mis_pred;
+};
+
+PmuEvents::PmuEvents() : priv(new PmuEventsPrivate) {}
+PmuEvents::~PmuEvents() { delete priv; }
+
+void PmuEvents::StartRecording() {
+ priv->l1d_cache_refill.Start();
+ priv->l2d_cache_refill.Start();
+ priv->l3d_cache_refill.Start();
+ priv->ll_cache_miss.Start();
+ priv->l1d_tlb_refill.Start();
+ priv->l2d_tlb_refill.Start();
+ priv->stall_frontend.Start();
+ priv->stall_backend.Start();
+ priv->br_mis_pred.Start();
+}
+
+void PmuEvents::StopRecording() {
+ priv->l1d_cache_refill.Stop();
+ priv->l2d_cache_refill.Stop();
+ priv->l3d_cache_refill.Stop();
+ priv->ll_cache_miss.Stop();
+ priv->l1d_tlb_refill.Stop();
+ priv->l2d_tlb_refill.Stop();
+ priv->stall_frontend.Stop();
+ priv->stall_backend.Stop();
+ priv->br_mis_pred.Stop();
+}
+
+float PmuEvents::BranchMispredictionCount() const {
+ return static_cast<float>(priv->br_mis_pred.Count());
+}
+
+float PmuEvents::FrontendStallCount() const {
+ return static_cast<float>(priv->stall_frontend.Count());
+}
+
+float PmuEvents::BackendStallCount() const {
+ return static_cast<float>(priv->stall_backend.Count());
+}
+
+float PmuEvents::L1RefillCount() const {
+ return static_cast<float>(priv->l1d_cache_refill.Count());
+}
+
+float PmuEvents::L2RefillCount() const {
+ return static_cast<float>(priv->l2d_cache_refill.Count());
+}
+
+float PmuEvents::L3RefillCount() const {
+  // Some CPUs implement LL_CACHE_MISS[_RD], some implement
+  // L3D_CACHE_REFILL[_RD]. It seems that either one of these two counters is
+  // zero, or they roughly agree with each other. Therefore, taking the max of
+  // them is a reasonable way to get something more portable across various
+  // CPUs.
+  //
+  // Note: this was observed in experiments that also tested the _RD variants
+  // of these counters, so it's possible that taking the max is simply not
+  // needed with the default (non _RD) counters used here.
+ return static_cast<float>(
+ std::max(priv->l3d_cache_refill.Count(), priv->ll_cache_miss.Count()));
+}
+
+float PmuEvents::L1TLBRefillCount() const {
+ return static_cast<float>(priv->l1d_tlb_refill.Count());
+}
+
+float PmuEvents::L2TLBRefillCount() const {
+ return static_cast<float>(priv->l2d_tlb_refill.Count());
+}
+
+} // namespace ruy
diff --git a/ruy/pmu.h b/ruy/pmu.h
new file mode 100644
index 0000000..721c1d5
--- /dev/null
+++ b/ruy/pmu.h
@@ -0,0 +1,44 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PMU_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PMU_H_
+
+namespace ruy {
+
+class PmuEventsPrivate;
+
+class PmuEvents {
+ public:
+ PmuEvents();
+ ~PmuEvents();
+ void StartRecording();
+ void StopRecording();
+ float L1RefillCount() const;
+ float L2RefillCount() const;
+ float L3RefillCount() const;
+ float BranchMispredictionCount() const;
+ float FrontendStallCount() const;
+ float BackendStallCount() const;
+ float L1TLBRefillCount() const;
+ float L2TLBRefillCount() const;
+
+ private:
+ PmuEventsPrivate* priv = nullptr;
+};
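+
+// Illustrative usage sketch (hypothetical caller; on non-Linux builds the
+// placeholder implementation returns zero counts):
+//
+//   PmuEvents pmu;
+//   pmu.StartRecording();
+//   // ... run the code being measured ...
+//   pmu.StopRecording();
+//   const float l1_refills = pmu.L1RefillCount();
+//   const float mispredicts = pmu.BranchMispredictionCount();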
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PMU_H_
diff --git a/ruy/prepack.h b/ruy/prepack.h
new file mode 100644
index 0000000..4bfc9ed
--- /dev/null
+++ b/ruy/prepack.h
@@ -0,0 +1,108 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Implementation of low-level pre-packing API.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_
+
+#include <cstddef>
+#include <functional>
+
+#include "ruy/check_macros.h"
+#include "ruy/context.h"
+#include "ruy/dispatch.h"
+#include "ruy/internal_matrix.h"
+#include "ruy/matrix.h"
+#include "ruy/path.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/side_pair.h"
+#include "ruy/spec.h"
+#include "ruy/trmul.h"
+#include "ruy/trmul_params.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
+ typename DstScalar, typename Spec>
+void PrePackForMulInternal(const Matrix<LhsScalar>& lhs,
+ const Matrix<RhsScalar>& rhs, const Spec& spec,
+ Context* context, Matrix<DstScalar>* dst,
+ SidePair<PrepackedMatrix*> prepacked,
+ std::function<void*(std::size_t)> alloc_fn) {
+ profiler::ScopeLabel label("PrePackForMul");
+ Path the_path = context->GetPathToTake<CompiledPaths>();
+ RUY_CHECK_NE(the_path, Path::kReference);
+ constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference;
+ Matrix<LhsScalar> transposed_lhs(lhs);
+ Transpose(&transposed_lhs);
+ TrMulParams params;
+ CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, spec, context, dst,
+ the_path, &params);
+
+ const SidePair<int> origin{0, 0};
+ const SidePair<int> rounded_dims{params.packed[Side::kLhs].layout.cols,
+ params.packed[Side::kRhs].layout.cols};
+
+ Tuning tuning = context->GetMainThreadTuning();
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ if (prepacked[side]) {
+ prepacked[side]->data_size = DataSize(params.packed[side]);
+ prepacked[side]->sums_size = SumsSize(params.packed[side]);
+ prepacked[side]->data = alloc_fn(prepacked[side]->data_size);
+ prepacked[side]->sums = alloc_fn(prepacked[side]->sums_size);
+ params.packed[side].data = prepacked[side]->data;
+ params.packed[side].sums = prepacked[side]->sums;
+ params.RunPack(side, tuning, origin[side], rounded_dims[side]);
+ }
+ }
+}
+
+template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
+ typename DstScalar, typename Spec>
+void MulWithPrepackedInternal(const Matrix<LhsScalar>& lhs,
+ const Matrix<RhsScalar>& rhs, const Spec& spec,
+ Context* context, Matrix<DstScalar>* dst,
+ SidePair<PrepackedMatrix*> prepacked) {
+ profiler::ScopeLabel label("MulWithPrepacked");
+
+ EnforceLayoutSupport<Spec>(lhs.layout, rhs.layout, dst->layout);
+ EnforceZeroPointSupport<Spec>(lhs.zero_point, rhs.zero_point,
+ dst->zero_point);
+
+ Path the_path = context->GetPathToTake<CompiledPaths>();
+ RUY_CHECK_NE(the_path, Path::kReference);
+ constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference;
+ Matrix<LhsScalar> transposed_lhs(lhs);
+ Transpose(&transposed_lhs);
+ TrMulParams params;
+ CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, spec, context, dst,
+ the_path, &params);
+
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ if (prepacked[side]) {
+ params.packed[side].data = prepacked[side]->data;
+ params.packed[side].sums = prepacked[side]->sums;
+ params.is_prepacked[side] = true;
+ }
+ }
+
+ TrMul(&params, context);
+}
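+
+// Illustrative calling pattern (hypothetical variable names; the user-facing
+// wrappers for these internal entry points live in the advanced pre-packing
+// API): pre-pack once, then reuse the packed buffers for each subsequent
+// multiplication involving the same constant matrix.
+//
+//   SidePair<PrepackedMatrix*> prepacked(&prepacked_lhs, nullptr);
+//   PrePackForMulInternal<CompiledPaths>(lhs, rhs, spec, context, &dst,
+//                                        prepacked, alloc_fn);
+//   MulWithPrepackedInternal<CompiledPaths>(lhs, rhs, spec, context, &dst,
+//                                           prepacked);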
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_
diff --git a/ruy/prepacked_cache.cc b/ruy/prepacked_cache.cc
new file mode 100644
index 0000000..020fdf7
--- /dev/null
+++ b/ruy/prepacked_cache.cc
@@ -0,0 +1,82 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/prepacked_cache.h"
+
+#include "ruy/matrix.h"
+#include "ruy/profiler/instrumentation.h"
+
+namespace ruy {
+
+using CacheIterator = PrepackedCache::CacheIterator;
+
+// Looks for an entry with `key`. If found, update its time stamp.
+CacheIterator PrepackedCache::FindAndUpdate(const CacheKey &key) {
+ auto itr = cache_.find(key);
+ // If found, update with new access time for this entry.
+ if (itr != cache_.end()) {
+ const TimePoint time = CacheNow();
+ itr->second.second = time;
+ }
+ return itr;
+}
+
+void PrepackedCache::Insert(const CacheKey &key,
+ const PrepackedMatrix &matrix) {
+ // Calculate size of this new item.
+ const size_t size_bytes = matrix.data_size + matrix.sums_size;
+
+ // While we are above the threshold of ejection, eject the LRU entry.
+ while (!cache_.empty() &&
+ ((TotalSize() + size_bytes) > ejection_threshold_)) {
+ EjectOne();
+ }
+ DoInsert(key, matrix);
+  cache_size_ += size_bytes;
+}
+
+void PrepackedCache::EjectOne() {
+ TimePoint oldest_time = CacheNow();
+ auto oldest = cache_.begin();
+ {
+    profiler::ScopeLabel label("PrepackedCacheEjection");
+ for (auto itr = cache_.begin(); itr != cache_.end(); ++itr) {
+ if (itr->second.second < oldest_time) {
+ oldest_time = itr->second.second;
+ oldest = itr;
+ }
+ }
+ }
+ PrepackedMatrix &pmatrix = oldest->second.first;
+ cache_size_ -= pmatrix.data_size;
+ cache_size_ -= pmatrix.sums_size;
+ allocator_.Free(pmatrix.data);
+ allocator_.Free(pmatrix.sums);
+ cache_.erase(oldest);
+}
+
+void PrepackedCache::AllocatePrepackedMatrix(PrepackedMatrix *pmatrix) {
+ pmatrix->data = allocator_.Alloc(pmatrix->data_size);
+ pmatrix->sums = allocator_.Alloc(pmatrix->sums_size);
+}
+
+void PrepackedCache::DoInsert(const CacheKey &key,
+ const PrepackedMatrix &matrix) {
+ const TimePoint t = CacheNow();
+ const MatrixWithTimeStamp mts({matrix, t});
+ cache_.insert({key, mts});
+}
+
+} // namespace ruy
diff --git a/ruy/prepacked_cache.h b/ruy/prepacked_cache.h
new file mode 100644
index 0000000..eedd7e4
--- /dev/null
+++ b/ruy/prepacked_cache.h
@@ -0,0 +1,130 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACKED_CACHE_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACKED_CACHE_H_
+
+#include <cstddef>
+#include <iostream>
+#include <map>
+#include <queue>
+#include <vector>
+
+#include "ruy/allocator.h"
+#include "ruy/matrix.h"
+#include "ruy/time.h"
+
+namespace ruy {
+
+namespace detail {
+
+// Tracks a set of blocks allocated from the underlying system allocator.
+class SystemBlockAllocator {
+ public:
+ void *Alloc(std::ptrdiff_t num_bytes) {
+ void *p = detail::SystemAlignedAlloc(num_bytes);
+ blocks_.push_back(p);
+ return p;
+ }
+
+ void Free(void *block) {
+ for (auto it = blocks_.begin(); it != blocks_.end(); ++it) {
+ if (*it == block) {
+ detail::SystemAlignedFree(block);
+ blocks_.erase(it);
+ return;
+ }
+ }
+ RUY_DCHECK(false); // Trying to free pointer we did not allocate.
+ }
+
+ ~SystemBlockAllocator() {
+ for (void *block : blocks_) {
+ detail::SystemAlignedFree(block);
+ }
+ }
+
+ private:
+ std::vector<void *> blocks_;
+};
+
+} // namespace detail
+
+enum CachePolicy { kNoCache, kCacheLHSOnNarrowMul };
+
+// "Low effort" Least Recently Used Cache for Prepacked Matrices
+// A cache mechanism for prepacked matrices that ejects oldest entries.
+// The implementation is "low effort" in the following ways:
+// - we just linearly search for the oldest entry when doing an ejection
+// - the ejection policy is very simple: if the new size would be above the
+//   threshold, we will eject entries until the size is below the threshold.
+// Current use cases (RNNs with GEMV operations) indicate that ejection is rare
+// and memory constraints are tight, so we devote no additional storage to the
+// LRU mechanism and accept O(n) search to eject oldest entry. In practice,
+// the number of total entries has not been shown to be large.
+// This class is not thread safe. In Ruy, memory allocation for packed matrices
+// is done in a single threaded context and the actual packing activity may
+// be done in a multi-threaded context.
+class PrepackedCache {
+ public:
+ static constexpr int kDefaultEjectionThresholdBytes = 1 << 28;
+
+ using CacheKey = std::pair<void *, void *>;
+
+ using MatrixWithTimeStamp = std::pair<PrepackedMatrix, TimePoint>;
+
+ using CacheIterator = std::map<CacheKey, MatrixWithTimeStamp>::const_iterator;
+
+ using AlignedAllocator = detail::AlignedAllocator;
+
+ explicit PrepackedCache(
+ int32_t ejection_threshold = kDefaultEjectionThresholdBytes)
+ : ejection_threshold_(ejection_threshold), cache_size_(0) {}
+
+ // Looks for an entry with `key`. If found, update its time stamp.
+ CacheIterator FindAndUpdate(const CacheKey &key);
+
+ // Returns end iterator for internal cache. The iterator type is appropriate
+ // to use with `FindAndUpdate`.
+ CacheIterator cend() const { return cache_.end(); }
+
+ // Returns the total size (in bytes) of data held in this cache.
+ int TotalSize() const { return cache_size_; }
+
+ // All calls to get current TimePoints go through here.
+ // TODO(b/145625614) Profile timestamps on relevant models to see if
+ // this level of granularity is sufficient. CoarseNow is cheap so
+ // it would be nice to keep it.
+ TimePoint CacheNow() const { return CoarseNow(); }
+
+ // Performs the memory allocation for the `data` and `sums` members of a
+ // PrepackedMatrix.
+ void AllocatePrepackedMatrix(PrepackedMatrix *pmatrix);
+
+ // Adds the PrepackedMatrix to the cache, possibly ejecting other values.
+ void Insert(const CacheKey &key, const PrepackedMatrix &matrix);
+
+ private:
+ void EjectOne();
+ void DoInsert(const CacheKey &key, const PrepackedMatrix &matrix);
+ detail::SystemBlockAllocator allocator_;
+ std::map<CacheKey, MatrixWithTimeStamp> cache_;
+ const int32_t ejection_threshold_;
+ size_t cache_size_;
+};
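+
+// Illustrative usage sketch (hypothetical sizes; mirrors the unit tests):
+//
+//   PrepackedCache cache(/*ejection_threshold=*/1 << 20);
+//   PrepackedMatrix pm;
+//   pm.data_size = 16;
+//   pm.sums_size = 8;
+//   cache.AllocatePrepackedMatrix(&pm);
+//   cache.Insert(std::make_pair(nullptr, pm.data), pm);
+//   const bool hit =
+//       cache.FindAndUpdate(std::make_pair(nullptr, pm.data)) != cache.cend();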
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACKED_CACHE_H_
diff --git a/ruy/prepacked_cache_test.cc b/ruy/prepacked_cache_test.cc
new file mode 100644
index 0000000..a65841e
--- /dev/null
+++ b/ruy/prepacked_cache_test.cc
@@ -0,0 +1,210 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/prepacked_cache.h"
+
+#include <thread> // NOLINT(build/c++11)
+
+#include "testing/base/public/gunit.h"
+#include "ruy/ruy.h"
+#include "ruy/time.h"
+
+namespace ruy {
+namespace {
+
+TEST(PrepackedCacheTest, TestCacheEjection) {
+ // Create the cache.
+ PrepackedCache prepacked_cache(32);
+ // Allocate the prepacked matrix.
+ PrepackedMatrix mat1;
+ mat1.data_size = 16;
+ mat1.sums_size = 8;
+ prepacked_cache.AllocatePrepackedMatrix(&mat1);
+ auto cache_key1 = std::make_pair(nullptr, mat1.data);
+ prepacked_cache.Insert(cache_key1, mat1);
+ std::this_thread::sleep_for(std::chrono::milliseconds(10));
+ // Get a time point after the insertion into the cache.
+ TimePoint current = CoarseNow();
+ std::this_thread::sleep_for(std::chrono::milliseconds(10));
+ PrepackedCache::CacheIterator itr = prepacked_cache.FindAndUpdate(cache_key1);
+ EXPECT_NE(itr, prepacked_cache.cend());
+ // By finding mat1, we updated its timestamp. Verify that `current` is older
+ // than the time stamp now associated with mat1.
+ EXPECT_LT(current, itr->second.second);
+ PrepackedMatrix mat2;
+ mat2.data_size = 8;
+ mat2.sums_size = 4;
+ prepacked_cache.AllocatePrepackedMatrix(&mat2);
+
+ auto cache_key2 = std::make_pair(nullptr, mat2.data);
+ prepacked_cache.Insert(cache_key2, mat2);
+ // The cache size was exceeded by inserting mat2. Ensure that mat1 was
+ // ejected.
+ EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
+}
+
+TEST(PrepackedCacheTest, TestCacheBasic) {
+ // Create the cache.
+ PrepackedCache prepacked_cache(48);
+ // Allocate the prepacked matrix.
+ PrepackedMatrix mat1;
+ mat1.data_size = 16;
+ mat1.sums_size = 8;
+ prepacked_cache.AllocatePrepackedMatrix(&mat1);
+
+ auto cache_key1 = std::make_pair(nullptr, mat1.data);
+ prepacked_cache.Insert(cache_key1, mat1);
+ std::this_thread::sleep_for(std::chrono::milliseconds(10));
+ EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
+
+ PrepackedMatrix mat2;
+ mat2.data_size = 8;
+ mat2.sums_size = 4;
+ prepacked_cache.AllocatePrepackedMatrix(&mat2);
+
+ auto cache_key2 = std::make_pair(nullptr, mat2.data);
+ std::this_thread::sleep_for(std::chrono::milliseconds(10));
+ prepacked_cache.Insert(cache_key2, mat2);
+ // The cache size was not exceeded by inserting mat2. Ensure that mat1 was not
+ // ejected.
+ EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
+}
+
+TEST(PrepackedCacheTest, TestCacheEjection2) {
+ // Create the cache.
+ PrepackedCache prepacked_cache(73);
+ // Allocate the prepacked matrix 1.
+ PrepackedMatrix mat1;
+ mat1.data_size = 16;
+ mat1.sums_size = 8;
+ prepacked_cache.AllocatePrepackedMatrix(&mat1);
+ auto cache_key1 = std::make_pair(nullptr, mat1.data);
+ prepacked_cache.Insert(cache_key1, mat1);
+ std::this_thread::sleep_for(std::chrono::milliseconds(10));
+
+ // Allocate the prepacked matrix 2.
+ PrepackedMatrix mat2;
+ mat2.data_size = 16;
+ mat2.sums_size = 8;
+ prepacked_cache.AllocatePrepackedMatrix(&mat2);
+ auto cache_key2 = std::make_pair(nullptr, mat2.data);
+ prepacked_cache.Insert(cache_key2, mat2);
+ std::this_thread::sleep_for(std::chrono::milliseconds(10));
+
+ // Allocate the prepacked matrix 3.
+ PrepackedMatrix mat31;
+ mat31.data_size = 16;
+ mat31.sums_size = 8;
+ prepacked_cache.AllocatePrepackedMatrix(&mat31);
+ auto cache_key3 = std::make_pair(nullptr, mat31.data);
+ prepacked_cache.Insert(cache_key3, mat31);
+ std::this_thread::sleep_for(std::chrono::milliseconds(10));
+
+ // The next insertion will cause the cache size to go over the ejection
+ // threshold. Touch matrix 1 and matrix 3 to make matrix 2 the oldest.
+ EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
+ EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend());
+ std::this_thread::sleep_for(std::chrono::milliseconds(10));
+
+ // Allocate the prepacked matrix 4.
+ PrepackedMatrix mat4;
+ mat4.data_size = 16;
+ mat4.sums_size = 8;
+ prepacked_cache.AllocatePrepackedMatrix(&mat4);
+ auto cache_key4 = std::make_pair(nullptr, mat4.data);
+ prepacked_cache.Insert(cache_key4, mat4);
+ std::this_thread::sleep_for(std::chrono::milliseconds(10));
+
+ // Ensure that mat2 was ejected, but mat1, mat3, and mat4 were not.
+ EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key2), prepacked_cache.cend());
+ EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend());
+ EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
+ EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key4), prepacked_cache.cend());
+}
+
+TEST(PrepackedCacheTest, TestCacheOnCacheable) {
+ // Create context and set the cache policy
+ ruy::Context context;
+ context.cache_policy = ruy::kCacheLHSOnNarrowMul;
+ PrepackedCache* cache = context.GetPrepackedCache();
+ EXPECT_EQ(cache->TotalSize(), 0);
+
+ const float lhs_data[] = {1, 2, 3, 4};
+ const float rhs_data[] = {1, 2};
+ float dst_data[4];
+
+ ruy::Matrix<float> lhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
+ lhs.data = lhs_data;
+ ruy::Matrix<float> rhs;
+ ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout);
+ rhs.data = rhs_data;
+ ruy::Matrix<float> dst;
+ ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout);
+ dst.data = dst_data;
+
+ ruy::BasicSpec<float, float> spec;
+ // Perform the multiplication and confirm no caching occurred.
+ ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);
+ EXPECT_EQ(cache->TotalSize(), 0);
+
+ // Set cacheable for the LHS, repeat the multiplication, and see
+ // that caching did occur.
+ lhs.cacheable = true;
+ ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);
+ EXPECT_NE(cache->TotalSize(), 0);
+}
+
+TEST(PrepackedCacheTest, TestClearCache) {
+ // Create context and set the cache policy
+ ruy::Context context;
+ context.cache_policy = ruy::kCacheLHSOnNarrowMul;
+ PrepackedCache* cache = context.GetPrepackedCache();
+ EXPECT_EQ(cache->TotalSize(), 0);
+
+ const float lhs_data[] = {1, 2, 3, 4};
+ const float rhs_data[] = {1, 2};
+ float dst_data[4];
+
+ ruy::Matrix<float> lhs;
+ ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
+ lhs.data = lhs_data;
+ ruy::Matrix<float> rhs;
+ ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout);
+ rhs.data = rhs_data;
+ ruy::Matrix<float> dst;
+ ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout);
+ dst.data = dst_data;
+
+ ruy::BasicSpec<float, float> spec;
+ // Set cacheable for the LHS and see that caching occurs.
+ lhs.cacheable = true;
+ ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);
+ EXPECT_NE(cache->TotalSize(), 0);
+
+ // Clear the cache via the Context.
+ context.ClearPrepackedCache();
+ // Verify that the cache is now empty.
+ cache = context.GetPrepackedCache();
+ EXPECT_EQ(cache->TotalSize(), 0);
+}
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/profiler/BUILD b/ruy/profiler/BUILD
new file mode 100644
index 0000000..b0af802
--- /dev/null
+++ b/ruy/profiler/BUILD
@@ -0,0 +1,52 @@
+# A minimalistic profiler sampling pseudo-stacks
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+config_setting(
+ name = "ruy_profiler",
+ define_values = {"ruy_profiler": "true"},
+)
+
+cc_library(
+ name = "instrumentation",
+ srcs = ["instrumentation.cc"],
+ hdrs = ["instrumentation.h"],
+ defines = select({
+ ":ruy_profiler": ["RUY_PROFILER"],
+ "//conditions:default": [],
+ }),
+)
+
+cc_library(
+ name = "profiler",
+ srcs = [
+ "profiler.cc",
+ "treeview.cc",
+ ],
+ hdrs = [
+ "profiler.h",
+ "treeview.h",
+ ],
+ deps = [":instrumentation"],
+)
+
+cc_library(
+ name = "test_instrumented_library",
+ testonly = True,
+ srcs = ["test_instrumented_library.cc"],
+ hdrs = ["test_instrumented_library.h"],
+ deps = [":instrumentation"],
+)
+
+cc_test(
+ name = "test",
+ srcs = ["test.cc"],
+ deps = [
+ ":profiler",
+ ":test_instrumented_library",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/ruy/profiler/README.md b/ruy/profiler/README.md
new file mode 100644
index 0000000..8d79025
--- /dev/null
+++ b/ruy/profiler/README.md
@@ -0,0 +1,149 @@
+# A minimalistic profiler sampling pseudo-stacks
+
+## Overview
+
+The present directory is the "ruy profiler". As a time profiler, it lets one
+measure where code is spending time.
+
+Unlike most typical profilers, what it samples is not real call stacks but
+"pseudo-stacks": simple data structures constructed from within the program
+being profiled. Using this profiler requires manually instrumenting code to
+construct such pseudo-stack information.
+
+Another unusual characteristic of this profiler is that it uses only the C++11
+standard library. It does not use any non-portable features; in particular, it
+does not rely on signal handlers. The sampling is performed by a thread, the
+"profiler thread".
+
+A discussion of pros/cons of this approach is appended below.
+
+## How to use this profiler
+
+### How to instrument code
+
+An example of instrumented code is given in `test_instrumented_library.cc`.
+
+Code is instrumented by constructing `ScopeLabel` objects. These are RAII
+helpers, ensuring that the thread pseudo-stack contains the label during their
+lifetime. In the most common use case, one would construct such an object at the
+start of a function, so that its scope is the function scope and it measures how
+much time is spent in that function.
+
+```c++
+#include "ruy/profiler/instrumentation.h"
+
+...
+
+void SomeFunction() {
+ ruy::profiler::ScopeLabel function_label("SomeFunction");
+ ... do something ...
+}
+```
+
+A `ScopeLabel` may however have any scope, for instance:
+
+```c++
+if (some_case) {
+ ruy::profiler::ScopeLabel extra_work_label("Some more work");
+ ... do some more work ...
+}
+```
+
+The string passed to the `ScopeLabel` constructor must be just a pointer to a
+literal string (a `const char*` pointer). The profiler will assume that these
+pointers stay valid until the profile is finalized.
+
+However, that literal string may be a `printf` format string, and labels may
+have up to 4 parameters, of type `int`. For example:
+
+```c++
+void SomeFunction(int size) {
+ ruy::profiling::ScopeLabel function_label("SomeFunction (size=%d)", size);
+
+```
+
+### How to run the profiler
+
+Profiling instrumentation is a no-op unless the preprocessor token
+`RUY_PROFILER` is defined, so defining it is the first step when actually
+profiling. When building with Bazel, the preferred way to enable that is to pass
+this flag on the Bazel command line:
+
+```
+--define=ruy_profiler=true
+```
+
+To actually profile a code scope, it is enough to construct a `ScopeProfile`
+object, also a RAII helper. It will start the profiler on construction, and on
+destruction it will terminate the profiler and report the profile treeview on
+standard output by default. Example:
+
+```c++
+void SomeProfiledBenchmark() {
+ ruy::profiler::ScopeProfile profile;
+
+ CallSomeInstrumentedCode();
+}
+```
+
+An example is provided by the `:test` target in the present directory. Run it
+with `--define=ruy_profiler=true` as explained above:
+
+```
+bazel run -c opt \
+ --define=ruy_profiler=true \
+ //tensorflow/lite/experimental/ruy/profiler:test
+```
+
+The default behavior of dumping the treeview on standard output may be
+overridden by passing a pointer to a `TreeView` object via
+`ScopeProfile::SetUserTreeView`. This causes the treeview to be stored in that
+`TreeView` object, where it may be accessed and manipulated using the functions
+declared in `treeview.h`. The aforementioned `:test` provides examples for
+doing so.
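+
+For instance, here is a minimal sketch (modeled on what `test.cc` does, and
+requiring a build with `RUY_PROFILER` defined as explained above):
+
+```c++
+#include "ruy/profiler/profiler.h"
+#include "ruy/profiler/treeview.h"
+
+void CallSomeInstrumentedCode();  // e.g. the instrumented code shown earlier.
+
+void ProfileIntoTreeView() {
+  ruy::profiler::TreeView treeview;
+  {
+    ruy::profiler::ScopeProfile profile;
+    profile.SetUserTreeView(&treeview);
+    CallSomeInstrumentedCode();
+  }
+  // The treeview is populated when `profile` goes out of scope.
+  ruy::profiler::Print(treeview);
+  // Query the aggregate weight of all nodes labeled "SomeFunction".
+  int weight = ruy::profiler::WeightBelowNodeMatchingUnformatted(
+      *treeview.thread_roots().begin()->second, "SomeFunction");
+  (void)weight;
+}
+```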
+
+## Advantages and disadvantages
+
+Compared to a traditional profiler, e.g. Linux's "perf", the present kind of
+profiler has the following disadvantages:
+
+* Requires manual instrumentation of code being profiled.
+* Substantial overhead, modifying the performance characteristics of the code
+ being measured.
+* Questionable accuracy.
+
+But also the following advantages:
+
+* Profiling can be driven from within a benchmark program, allowing the entire
+ profiling procedure to be a single command line.
+* Not relying on symbol information removes exposure to toolchain details and
+ means less hassle in some build environments, especially embedded/mobile
+ (single command line to run and profile, no symbol files required).
+* Fully portable (all of this is standard C++11).
+* Fully testable (see `:test`). Profiling becomes just another feature of the
+ code like any other.
+* Customized instrumentation can result in easier-to-read treeviews (only
+ relevant functions, and custom labels may be more readable than function
+ names).
+* Parametrized/formatted labels make it possible to do things that aren't
+ possible with call-stack-sampling profilers. For example, break down a profile
+ where much time is being spent in matrix multiplications by the various matrix
+ multiplication shapes involved (see the sketch after this list).
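+
+As a sketch of such a parametrized label (the `rows`, `depth` and `cols`
+variables here are hypothetical and would come from the surrounding code):
+
+```c++
+void Kernel(int rows, int depth, int cols) {
+  ruy::profiler::ScopeLabel label("Kernel (rows=%d, depth=%d, cols=%d)", rows,
+                                  depth, cols);
+  ... do the matrix multiplication work ...
+}
+```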
+
+The philosophy underlying this profiler is that software performance depends on
+software engineers profiling often, and a key factor limiting that in practice
+is how difficult and cumbersome profiling can be with more serious profilers
+such as Linux's "perf", especially in embedded/mobile development: multiple
+command lines are involved to copy symbol files to devices, retrieve profile
+data from the device, etc. In that context, it is useful to make profiling as
+easy as benchmarking, even on embedded targets, even if the price to pay is
+lower accuracy, higher overhead, and some intrusive instrumentation.
+
+Another key aspect determining which profiling approach is suitable for a given
+context is whether one already has a priori knowledge of where much of the time
+is likely being spent. When one has such a priori knowledge, it is feasible to
+instrument the known possibly-critical code as per the present approach. On the
+other hand, in situations where one doesn't have such a priori knowledge, a real
+profiler such as Linux's "perf" makes it possible to immediately get a profile
+of real call stacks, from just the symbol information generated by the
+toolchain.
diff --git a/ruy/profiler/instrumentation.cc b/ruy/profiler/instrumentation.cc
new file mode 100644
index 0000000..f03f667
--- /dev/null
+++ b/ruy/profiler/instrumentation.cc
@@ -0,0 +1,130 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/profiler/instrumentation.h"
+
+#ifdef RUY_PROFILER
+
+namespace ruy {
+namespace profiler {
+
+void Label::operator=(const Label& other) {
+ format_ = other.format_;
+ args_count_ = other.args_count_;
+ for (int i = 0; i < args_count_; i++) {
+ args_[i] = other.args_[i];
+ }
+}
+
+bool Label::operator==(const Label& other) const {
+ if (std::string(format_) != std::string(other.format_)) {
+ return false;
+ }
+ if (args_count_ != other.args_count_) {
+ return false;
+ }
+ for (int i = 0; i < args_count_; i++) {
+ if (args_[i] != other.args_[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
+std::string Label::Formatted() const {
+ static constexpr int kBufSize = 256;
+ char buf[kBufSize];
+ if (args_count_ == 0) {
+ return format_;
+ }
+ if (args_count_ == 1) {
+ snprintf(buf, kBufSize, format_, args_[0]);
+ } else if (args_count_ == 2) {
+ snprintf(buf, kBufSize, format_, args_[0], args_[1]);
+ } else if (args_count_ == 3) {
+ snprintf(buf, kBufSize, format_, args_[0], args_[1], args_[2]);
+ } else if (args_count_ == 4) {
+ snprintf(buf, kBufSize, format_, args_[0], args_[1], args_[2], args_[3]);
+ } else {
+ abort();
+ }
+ return buf;
+}
+
+namespace detail {
+
+std::mutex* GlobalsMutex() {
+ static std::mutex mutex;
+ return &mutex;
+}
+
+bool& GlobalIsProfilerRunning() {
+ static bool b;
+ return b;
+}
+
+std::vector<ThreadStack*>* GlobalAllThreadStacks() {
+ static std::vector<ThreadStack*> all_stacks;
+ return &all_stacks;
+}
+
+ThreadStack* ThreadLocalThreadStack() {
+ thread_local static ThreadStack thread_stack;
+ return &thread_stack;
+}
+
+ThreadStack::ThreadStack() {
+ std::lock_guard<std::mutex> lock(*GlobalsMutex());
+ static std::uint32_t global_next_thread_stack_id = 0;
+ stack_.id = global_next_thread_stack_id++;
+ GlobalAllThreadStacks()->push_back(this);
+}
+
+ThreadStack::~ThreadStack() {
+ std::lock_guard<std::mutex> lock(*GlobalsMutex());
+ std::vector<ThreadStack*>* all_stacks = GlobalAllThreadStacks();
+ for (auto it = all_stacks->begin(); it != all_stacks->end(); ++it) {
+ if (*it == this) {
+ all_stacks->erase(it);
+ return;
+ }
+ }
+}
+int GetBufferSize(const Stack& stack) {
+ return sizeof(stack.id) + sizeof(stack.size) +
+ stack.size * sizeof(stack.labels[0]);
+}
+
+void CopyToBuffer(const Stack& stack, char* dst) {
+ memcpy(dst, &stack.id, sizeof(stack.id));
+ dst += sizeof(stack.id);
+ memcpy(dst, &stack.size, sizeof(stack.size));
+ dst += sizeof(stack.size);
+ memcpy(dst, stack.labels, stack.size * sizeof(stack.labels[0]));
+}
+
+void ReadFromBuffer(const char* src, Stack* stack) {
+ memcpy(&stack->id, src, sizeof(stack->id));
+ src += sizeof(stack->id);
+ memcpy(&stack->size, src, sizeof(stack->size));
+ src += sizeof(stack->size);
+ memcpy(stack->labels, src, stack->size * sizeof(stack->labels[0]));
+}
+
+} // namespace detail
+} // namespace profiler
+} // namespace ruy
+
+#endif
diff --git a/ruy/profiler/instrumentation.h b/ruy/profiler/instrumentation.h
new file mode 100644
index 0000000..a9046d4
--- /dev/null
+++ b/ruy/profiler/instrumentation.h
@@ -0,0 +1,203 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_
+
+#ifdef RUY_PROFILER
+#include <cstdio>
+#include <mutex>
+#include <vector>
+#endif
+
+namespace ruy {
+namespace profiler {
+
+#ifdef RUY_PROFILER
+
+// A label is how a code scope is annotated to appear in profiles.
+// The stacks that are sampled by the profiler are stacks of such labels.
+// A label consists of a literal string, plus optional integer arguments.
+class Label {
+ public:
+ Label() {}
+ template <typename... Args>
+ explicit Label(Args... args) {
+ Set(args...);
+ }
+ void Set(const char* format) {
+ format_ = format;
+ args_count_ = 0;
+ }
+ template <typename... Args>
+ void Set(const char* format, Args... args) {
+ format_ = format;
+ args_count_ = sizeof...(args);
+ SetArgs(0, args...);
+ }
+
+ void operator=(const Label& other);
+
+ bool operator==(const Label& other) const;
+
+ std::string Formatted() const;
+ const char* format() const { return format_; }
+
+ private:
+ void SetArgs(int position, int arg0) { args_[position] = arg0; }
+
+ template <typename... Args>
+ void SetArgs(int position, int arg0, Args... args) {
+ SetArgs(position, arg0);
+ SetArgs(position + 1, args...);
+ }
+
+ static constexpr int kMaxArgs = 4;
+ const char* format_ = nullptr;
+ int args_count_ = 0;
+ int args_[kMaxArgs];
+};
+
+namespace detail {
+
+// Forward-declaration, see class ThreadStack below.
+class ThreadStack;
+
+bool& GlobalIsProfilerRunning();
+
+// Returns the global vector of pointers to all stacks, there being one stack
+// per thread executing instrumented code.
+std::vector<ThreadStack*>* GlobalAllThreadStacks();
+
+// Returns the mutex to be locked around any access to GlobalAllThreadStacks().
+std::mutex* GlobalsMutex();
+
+// Returns the thread-local stack, specific to the current thread.
+ThreadStack* ThreadLocalThreadStack();
+
+// This 'stack' is what may be more appropriately called a 'pseudostack':
+// It contains Label entries that are 'manually' entered by instrumentation
+// code. It's unrelated to real call stacks.
+struct Stack {
+ std::uint32_t id = 0;
+ static constexpr int kMaxSize = 64;
+ int size = 0;
+ Label labels[kMaxSize];
+};
+
+// Returns the buffer byte size required by CopyToBuffer.
+int GetBufferSize(const Stack& stack);
+
+// Copies this Stack into a byte buffer, called a 'sample'.
+void CopyToBuffer(const Stack& stack, char* dst);
+
+// Populates this Stack from an existing sample buffer, typically
+// produced by CopyToBuffer.
+void ReadFromBuffer(const char* src, Stack* stack);
+
+// ThreadStack is meant to be used as a thread-local singleton, assigning to
+// each thread a Stack object holding its pseudo-stack of profile labels,
+// plus a mutex used to synchronize accesses to this pseudo-stack between
+// this thread and a possible profiler thread sampling it.
+class ThreadStack {
+ public:
+ ThreadStack();
+ ~ThreadStack();
+
+ const Stack& stack() const { return stack_; }
+
+ // Returns the mutex to lock around any access to this stack. Each stack is
+ // accessed by potentially two threads: the thread that it belongs to
+ // (which calls Push and Pop) and the profiler thread during profiling
+ // (which calls CopyToBuffer when sampling).
+ std::mutex& Mutex() const { return mutex_; }
+
+ // Pushes a new label on the top of this Stack.
+ template <typename... Args>
+ void Push(Args... args) {
+ // This mutex locking is needed to guard against race conditions as both
+ // the current thread and the profiler thread may be concurrently accessing
+ // this stack. In addition to that, this mutex locking also serves the other
+ // purpose of acting as a barrier (of compiler code reordering, of runtime
+ // CPU instruction reordering, and of memory access reordering), which
+ // gives a measure of correctness to this profiler. The downside is some
+ // latency. As this lock will be uncontended most of the time, the cost
+ // should be roughly that of a sequentially-consistent atomic access,
+ // comparable to an access to the level of CPU data cache that is shared
+ // among all cores, typically 60 cycles on current ARM CPUs, plus side
+ // effects from barrier instructions.
+ std::lock_guard<std::mutex> lock(mutex_);
+ // Avoid overrunning the stack, even in 'release' builds. This profiling
+ // instrumentation code should not ship in release builds anyway; the
+ // overhead of this check is negligible, and overrunning a stack array would
+ // be bad.
+ if (stack_.size >= Stack::kMaxSize) {
+ abort();
+ }
+ stack_.labels[stack_.size++].Set(args...);
+ }
+
+ // Pops the top-most label from this Stack.
+ void Pop() {
+ // See the comment in Push about this lock. While it would be tempting to
+ // try to remove this lock and just atomically decrement size_ with a
+ // store-release, that would not necessarily be a substitute for all of the
+ // purposes that this lock serves, or if it was done carefully to serve all
+ // of the same purposes, then that wouldn't be faster than this (mostly
+ // uncontended) lock.
+ std::lock_guard<std::mutex> lock(mutex_);
+ stack_.size--;
+ }
+
+ private:
+ mutable std::mutex mutex_;
+ Stack stack_;
+};
+
+} // namespace detail
+
+// RAII user-facing way to construct Labels associated with their life scope
+// and get them pushed to / popped from the current thread stack.
+class ScopeLabel {
+ public:
+ template <typename... Args>
+ ScopeLabel(Args... args) : thread_stack_(detail::ThreadLocalThreadStack()) {
+ thread_stack_->Push(args...);
+ }
+
+ ~ScopeLabel() { thread_stack_->Pop(); }
+
+ private:
+ detail::ThreadStack* thread_stack_;
+};
+
+#else // no RUY_PROFILER
+
+class ScopeLabel {
+ public:
+ template <typename... Args>
+ explicit ScopeLabel(Args...) {}
+
+ // This destructor is needed to consistently silence clang's -Wunused-variable
+ // which seems to trigger semi-randomly.
+ ~ScopeLabel() {}
+};
+
+#endif
+
+} // namespace profiler
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_
diff --git a/ruy/profiler/profiler.cc b/ruy/profiler/profiler.cc
new file mode 100644
index 0000000..ae3a2e2
--- /dev/null
+++ b/ruy/profiler/profiler.cc
@@ -0,0 +1,109 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/profiler/profiler.h"
+
+#ifdef RUY_PROFILER
+#include <atomic>
+#include <chrono> // NOLINT
+#include <cstdio>
+#include <cstdlib>
+#include <thread> // NOLINT
+#include <vector>
+#endif
+
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/profiler/treeview.h"
+
+namespace ruy {
+namespace profiler {
+
+#ifdef RUY_PROFILER
+
+ScopeProfile::ScopeProfile() { Start(); }
+ScopeProfile::ScopeProfile(bool enable) {
+ if (enable) {
+ Start();
+ }
+}
+ScopeProfile::~ScopeProfile() {
+ if (!thread_) {
+ return;
+ }
+ finishing_.store(true);
+ thread_->join();
+ Finish();
+}
+
+void ScopeProfile::Start() {
+ {
+ std::lock_guard<std::mutex> lock(*detail::GlobalsMutex());
+ if (detail::GlobalIsProfilerRunning()) {
+ fprintf(stderr, "FATAL: profiler already running!\n");
+ abort();
+ }
+ detail::GlobalIsProfilerRunning() = true;
+ }
+ finishing_ = false;
+ thread_.reset(new std::thread(&ScopeProfile::ThreadFunc, this));
+}
+
+void ScopeProfile::ThreadFunc() {
+ while (!finishing_.load()) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ std::lock_guard<std::mutex> lock(*detail::GlobalsMutex());
+ auto* thread_stacks = detail::GlobalAllThreadStacks();
+ for (detail::ThreadStack* thread_stack : *thread_stacks) {
+ Sample(*thread_stack);
+ }
+ }
+}
+
+void ScopeProfile::Sample(const detail::ThreadStack& thread_stack) {
+ std::lock_guard<std::mutex> lock(thread_stack.Mutex());
+ // Drop empty stacks.
+ // This ensures that profiles aren't polluted by uninteresting threads.
+ if (thread_stack.stack().size == 0) {
+ return;
+ }
+ int sample_size = detail::GetBufferSize(thread_stack.stack());
+ int old_buf_size = samples_buf_.size();
+ samples_buf_.resize(old_buf_size + sample_size);
+ detail::CopyToBuffer(thread_stack.stack(),
+ samples_buf_.data() + old_buf_size);
+}
+
+void ScopeProfile::Finish() {
+ {
+ std::lock_guard<std::mutex> lock(*detail::GlobalsMutex());
+ if (!detail::GlobalIsProfilerRunning()) {
+ fprintf(stderr, "FATAL: profiler is not running!\n");
+ abort();
+ }
+ detail::GlobalIsProfilerRunning() = false;
+ }
+ if (user_treeview_) {
+ user_treeview_->Populate(samples_buf_);
+ } else {
+ TreeView treeview;
+ treeview.Populate(samples_buf_);
+ Print(treeview);
+ }
+}
+
+#endif // RUY_PROFILER
+
+} // namespace profiler
+} // namespace ruy
diff --git a/ruy/profiler/profiler.h b/ruy/profiler/profiler.h
new file mode 100644
index 0000000..b68ca90
--- /dev/null
+++ b/ruy/profiler/profiler.h
@@ -0,0 +1,106 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_
+
+#include <cstdio>
+
+#ifdef RUY_PROFILER
+#include <atomic>
+#include <chrono>
+#include <thread>
+#include <vector>
+#endif
+
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/profiler/treeview.h"
+
+namespace ruy {
+namespace profiler {
+
+#ifdef RUY_PROFILER
+
+// RAII user-facing way to create a profiler and let it profile a code scope,
+// and print out an ASCII/MarkDown treeview upon leaving the scope.
+class ScopeProfile {
+ public:
+ // Default constructor, unconditionally profiling.
+ ScopeProfile();
+
+ // Constructor that allows choosing at runtime whether to profile.
+ explicit ScopeProfile(bool enable);
+
+ // Destructor. It's where the profile is reported.
+ ~ScopeProfile();
+
+ // See treeview_ member.
+ void SetUserTreeView(TreeView* treeview) { user_treeview_ = treeview; }
+
+ private:
+ void Start();
+
+ // Thread entry point function for the profiler thread. This thread is
+ // created on construction.
+ void ThreadFunc();
+
+ // Record a stack as a sample.
+ void Sample(const detail::ThreadStack& stack);
+
+ // Finalize the profile. Called on destruction.
+ // If user_treeview_ is non-null, it will receive the treeview.
+ // Otherwise the treeview will just be printed.
+ void Finish();
+
+ // Buffer where samples are recorded during profiling.
+ std::vector<char> samples_buf_;
+
+ // Used to synchronize thread termination.
+ std::atomic<bool> finishing_;
+
+ // Underlying profiler thread, which will perform the sampling.
+ // This profiler approach relies on a thread rather than on signals.
+ std::unique_ptr<std::thread> thread_;
+
+ // TreeView to populate upon destruction. If left null (the default),
+ // a temporary treeview will be used and dumped on stdout. The user
+ // may override that by passing their own TreeView object for other
+ // output options or to directly inspect the TreeView.
+ TreeView* user_treeview_ = nullptr;
+};
+
+#else // no RUY_PROFILER
+
+struct ScopeProfile {
+ ScopeProfile() {
+#ifdef GEMMLOWP_PROFILING
+ fprintf(
+ stderr,
+ "\n\n\n**********\n\nWARNING:\n\nLooks like you defined "
+ "GEMMLOWP_PROFILING, but this code has been ported to the new ruy "
+ "profiler replacing the old gemmlowp profiler. You should now be "
+ "defining RUY_PROFILER and not GEMMLOWP_PROFILING. When building using "
+ "Bazel, just pass --define=ruy_profiler=true.\n\n**********\n\n\n");
+#endif
+ }
+ explicit ScopeProfile(bool) {}
+};
+
+#endif
+
+} // namespace profiler
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_
diff --git a/ruy/profiler/test.cc b/ruy/profiler/test.cc
new file mode 100644
index 0000000..e94840b
--- /dev/null
+++ b/ruy/profiler/test.cc
@@ -0,0 +1,167 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <chrono>
+#include <random>
+#include <thread>
+
+#include "testing/base/public/gunit.h"
+#include "ruy/profiler/profiler.h"
+#include "ruy/profiler/test_instrumented_library.h"
+#include "ruy/profiler/treeview.h"
+
+namespace ruy {
+namespace profiler {
+namespace {
+
+void DoSomeMergeSort(int size) {
+ std::vector<int> data(size);
+
+ std::default_random_engine engine;
+ for (auto& val : data) {
+ val = engine();
+ }
+
+ MergeSort(size, data.data());
+}
+
+// The purpose of this basic test is to cover the basic path that will be taken
+// by a majority of users, not inspecting treeviews but just implicitly printing
+// them on stdout, and to have this test enabled even when RUY_PROFILER is not
+// defined, so that we have coverage for the non-RUY_PROFILER case.
+TEST(ProfilerTest, MergeSortSingleThreadBasicTestEvenWithoutProfiler) {
+ {
+ ScopeProfile profile;
+ DoSomeMergeSort(1 << 20);
+ }
+}
+
+#ifdef RUY_PROFILER
+
+TEST(ProfilerTest, MergeSortSingleThread) {
+ TreeView treeview;
+ {
+ ScopeProfile profile;
+ profile.SetUserTreeView(&treeview);
+ DoSomeMergeSort(1 << 20);
+ }
+ Print(treeview);
+ EXPECT_EQ(treeview.thread_roots().size(), 1);
+ const auto& thread_root = *treeview.thread_roots().begin()->second;
+ EXPECT_EQ(DepthOfTreeBelow(thread_root), 22);
+ EXPECT_GE(
+ WeightBelowNodeMatchingUnformatted(thread_root, "Merging sorted halves"),
+ 0.1 * thread_root.weight);
+ EXPECT_GE(WeightBelowNodeMatchingFormatted(
+ thread_root, "MergeSortRecurse (level=20, size=1)"),
+ 0.01 * thread_root.weight);
+
+ TreeView treeview_collapsed;
+ CollapseNodesMatchingUnformatted(treeview, 5, "MergeSort (size=%d)",
+ &treeview_collapsed);
+ Print(treeview_collapsed);
+ const auto& collapsed_thread_root =
+ *treeview_collapsed.thread_roots().begin()->second;
+ EXPECT_EQ(DepthOfTreeBelow(collapsed_thread_root), 6);
+ EXPECT_EQ(
+ WeightBelowNodeMatchingUnformatted(thread_root, "MergeSort (size=%d)"),
+ WeightBelowNodeMatchingUnformatted(collapsed_thread_root,
+ "MergeSort (size=%d)"));
+}
+
+TEST(ProfilerTest, MemcpyFourThreads) {
+ TreeView treeview;
+ {
+ ScopeProfile profile;
+ profile.SetUserTreeView(&treeview);
+ std::vector<std::unique_ptr<std::thread>> threads;
+ for (int i = 0; i < 4; i++) {
+ threads.emplace_back(new std::thread([i]() {
+ ScopeLabel thread_label("worker thread #%d", i);
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+ ScopeLabel some_more_work_label("some more work");
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+ }));
+ }
+ for (int i = 0; i < 4; i++) {
+ threads[i]->join();
+ }
+ }
+ Print(treeview);
+ // Since the current thread hasn't created any ScopeLabel, its stack is empty
+ // and dropped during sampling, so only the 4 worker threads are recorded.
+ EXPECT_EQ(treeview.thread_roots().size(), 4);
+ for (const auto& thread_root : treeview.thread_roots()) {
+ const TreeView::Node& root_node = *thread_root.second;
+ // The root node may have 1 or 2 children depending on whether there is
+ // an "[other]" child.
+ EXPECT_GE(root_node.children.size(), 1);
+ EXPECT_LE(root_node.children.size(), 2);
+ const TreeView::Node& child_node = *root_node.children[0];
+ EXPECT_EQ(child_node.label.format(), "worker thread #%d");
+ // There must be 2 children, since roughly half the time will be in
+ // "some more work" leaving the other half in "[other]".
+ EXPECT_EQ(child_node.children.size(), 2);
+ const TreeView::Node& child_child_node = *child_node.children[0];
+ // Since we sample every millisecond and the threads run for >= 2000
+ // milliseconds, the worker-thread label should get roughly 2000 samples.
+ // Not very rigorous, as we're depending on the profiler thread getting
+ // scheduled, so to avoid this test being flaky, we use a much more
+ // conservative value of 500, one quarter of that normal value 2000.
+ EXPECT_GE(child_node.weight, 500);
+ // Likewise, allow up to four times more than the normal value 2000.
+ EXPECT_LE(child_node.weight, 8000);
+ // Roughly half of time should be spent under the "some more work" label.
+ float some_more_work_percentage =
+ 100.f * child_child_node.weight / child_node.weight;
+ EXPECT_GE(some_more_work_percentage, 40.0f);
+ EXPECT_LE(some_more_work_percentage, 60.0f);
+ }
+}
+
+TEST(ProfilerTest, OneThreadAfterAnother) {
+ TreeView treeview;
+ {
+ ScopeProfile profile;
+ profile.SetUserTreeView(&treeview);
+ {
+ std::thread thread([]() {
+ ScopeLabel thread_label("thread 0");
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+ });
+ thread.join();
+ }
+ {
+ std::thread thread([]() {
+ ScopeLabel thread_label("thread 1");
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+ });
+ thread.join();
+ }
+ }
+ Print(treeview);
+ EXPECT_EQ(treeview.thread_roots().size(), 2);
+}
+
+#endif // RUY_PROFILER
+
+} // namespace
+} // namespace profiler
+} // namespace ruy
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/profiler/test_instrumented_library.cc b/ruy/profiler/test_instrumented_library.cc
new file mode 100644
index 0000000..b017ea9
--- /dev/null
+++ b/ruy/profiler/test_instrumented_library.cc
@@ -0,0 +1,59 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "ruy/profiler/instrumentation.h"
+
+namespace {
+
+void MergeSortRecurse(int level, int size, int* data, int* workspace) {
+ ruy::profiler::ScopeLabel function_label(
+ "MergeSortRecurse (level=%d, size=%d)", level, size);
+ if (size <= 1) {
+ return;
+ }
+ int half_size = size / 2;
+ MergeSortRecurse(level + 1, half_size, data, workspace);
+ MergeSortRecurse(level + 1, size - half_size, data + half_size,
+ workspace + half_size);
+
+ ruy::profiler::ScopeLabel merging_sorted_halves_label(
+ "Merging sorted halves");
+ int dst_index = 0;
+ int left_index = 0;
+ int right_index = half_size;
+ while (dst_index < size) {
+ int val;
+ if (left_index < half_size &&
+ ((right_index >= size) || data[left_index] < data[right_index])) {
+ val = data[left_index++];
+ } else {
+ val = data[right_index++];
+ }
+ workspace[dst_index++] = val;
+ }
+ for (int i = 0; i < size; i++) {
+ data[i] = workspace[i];
+ }
+}
+
+} // namespace
+
+void MergeSort(int size, int* data) {
+ ruy::profiler::ScopeLabel function_label("MergeSort (size=%d)", size);
+ std::vector<int> workspace(size);
+ MergeSortRecurse(0, size, data, workspace.data());
+}
diff --git a/ruy/profiler/test_instrumented_library.h b/ruy/profiler/test_instrumented_library.h
new file mode 100644
index 0000000..53d204e
--- /dev/null
+++ b/ruy/profiler/test_instrumented_library.h
@@ -0,0 +1,23 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_
+
+#include "ruy/profiler/instrumentation.h"
+
+void MergeSort(int size, int* data);
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_
diff --git a/ruy/profiler/treeview.cc b/ruy/profiler/treeview.cc
new file mode 100644
index 0000000..48d922a
--- /dev/null
+++ b/ruy/profiler/treeview.cc
@@ -0,0 +1,248 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef RUY_PROFILER
+
+#include "ruy/profiler/treeview.h"
+
+#include <algorithm>
+#include <cstdio>
+#include <functional>
+#include <memory>
+#include <vector>
+
+namespace ruy {
+namespace profiler {
+
+namespace {
+
+void SortNode(TreeView::Node* node) {
+ using NodePtr = std::unique_ptr<TreeView::Node>;
+ std::sort(node->children.begin(), node->children.end(),
+ [](const NodePtr& n1, const NodePtr& n2) {
+ return n1->weight > n2->weight;
+ });
+ for (const auto& child : node->children) {
+ SortNode(child.get());
+ }
+}
+
+// Records a stack i.e. a sample in a treeview, by incrementing the weights
+// of matching existing nodes and/or by creating new nodes as needed,
+// recursively, below the given node.
+void AddStack(const detail::Stack& stack, TreeView::Node* node, int level) {
+ node->weight++;
+ if (stack.size == level) {
+ return;
+ }
+ TreeView::Node* child_to_add_to = nullptr;
+ for (const auto& child : node->children) {
+ if (child->label == stack.labels[level]) {
+ child_to_add_to = child.get();
+ break;
+ }
+ }
+ if (!child_to_add_to) {
+ child_to_add_to = node->children.emplace_back(new TreeView::Node).get();
+ child_to_add_to->label = stack.labels[level];
+ }
+ AddStack(stack, child_to_add_to, level + 1);
+}
+
+// Recursively populates the treeview below the given node with 'other'
+// entries documenting for each node the difference between its weight and the
+// sum of its children's weights.
+void AddOther(TreeView::Node* node) {
+ int top_level_children_weight = 0;
+ for (const auto& child : node->children) {
+ AddOther(child.get());
+ top_level_children_weight += child->weight;
+ }
+ if (top_level_children_weight != 0 &&
+ top_level_children_weight != node->weight) {
+ const auto& new_child = node->children.emplace_back(new TreeView::Node);
+ new_child->label = Label("[other]");
+ new_child->weight = node->weight - top_level_children_weight;
+ }
+}
+
+} // namespace
+
+void TreeView::Populate(const std::vector<char>& samples_buf_) {
+ thread_roots_.clear();
+ // Populate the treeview with regular nodes coming from samples.
+ const char* buf_ptr = samples_buf_.data();
+ const char* const buf_ptr_end = buf_ptr + samples_buf_.size();
+ while (buf_ptr < buf_ptr_end) {
+ detail::Stack stack;
+ detail::ReadFromBuffer(buf_ptr, &stack);
+ // Empty stacks should have been dropped during sampling.
+ assert(stack.size > 0);
+ buf_ptr += GetBufferSize(stack);
+ const int id = stack.id;
+ if (!thread_roots_[id]) {
+ thread_roots_[id].reset(new Node);
+ }
+ AddStack(stack, thread_roots_[id].get(), 0);
+ }
+ // Populate the treeview with additional 'other' nodes, sort, and set
+ // root labels.
+ for (const auto& thread_root : thread_roots_) {
+ std::uint32_t id = thread_root.first;
+ Node* root = thread_root.second.get();
+ AddOther(root);
+ SortNode(root);
+ root->label.Set("Thread %x (%d samples)", id, root->weight);
+ }
+}
+
+// Recursively prints the treeview below the given node. The 'root' node
+// argument is only needed to compute weight ratios, with the root weight
+// as the denominator.
+void PrintTreeBelow(const TreeView::Node& node, const TreeView::Node& root,
+ int level) {
+ if (&node == &root) {
+ printf("%s\n\n", node.label.Formatted().c_str());
+ } else {
+ for (int i = 1; i < level; i++) {
+ printf(" ");
+ }
+ printf("* %.2f%% %s\n", 100.0f * node.weight / root.weight,
+ node.label.Formatted().c_str());
+ }
+ for (const auto& child : node.children) {
+ PrintTreeBelow(*child, root, level + 1);
+ }
+}
+
+void Print(const TreeView& treeview) {
+ printf("\n");
+ printf("Profile (%d threads):\n\n",
+ static_cast<int>(treeview.thread_roots().size()));
+ for (const auto& thread_root : treeview.thread_roots()) {
+ const TreeView::Node& root = *thread_root.second;
+ PrintTreeBelow(root, root, 0);
+ printf("\n");
+ }
+}
+
+int DepthOfTreeBelow(const TreeView::Node& node) {
+ if (node.children.empty()) {
+ return 0;
+ } else {
+ int max_child_depth = 0;
+ for (const auto& child : node.children) {
+ max_child_depth = std::max(max_child_depth, DepthOfTreeBelow(*child));
+ }
+ return 1 + max_child_depth;
+ }
+}
+
+int WeightBelowNodeMatchingFunction(
+ const TreeView::Node& node,
+ const std::function<bool(const Label&)>& match) {
+ int weight = 0;
+ if (match(node.label)) {
+ weight += node.weight;
+ }
+ for (const auto& child : node.children) {
+ weight += WeightBelowNodeMatchingFunction(*child, match);
+ }
+ return weight;
+}
+
+int WeightBelowNodeMatchingUnformatted(const TreeView::Node& node,
+ const std::string& format) {
+ return WeightBelowNodeMatchingFunction(
+ node, [&format](const Label& label) { return label.format() == format; });
+}
+
+int WeightBelowNodeMatchingFormatted(const TreeView::Node& node,
+ const std::string& formatted) {
+ return WeightBelowNodeMatchingFunction(
+ node, [&formatted](const Label& label) {
+ return label.Formatted() == formatted;
+ });
+}
+
+void CollapseNode(const TreeView::Node& node_in, int depth,
+ TreeView::Node* node_out) {
+ node_out->label = node_in.label;
+ node_out->weight = node_in.weight;
+ node_out->children.clear();
+ if (depth > 0) {
+ for (const auto& child_in : node_in.children) {
+ auto* child_out = new TreeView::Node;
+ node_out->children.emplace_back(child_out);
+ CollapseNode(*child_in, depth - 1, child_out);
+ }
+ }
+}
+
+void CollapseSubnodesMatchingFunction(
+ const TreeView::Node& node_in, int depth,
+ const std::function<bool(const Label&)>& match, TreeView::Node* node_out) {
+ if (match(node_in.label)) {
+ CollapseNode(node_in, depth, node_out);
+ } else {
+ node_out->label = node_in.label;
+ node_out->weight = node_in.weight;
+ node_out->children.clear();
+
+ for (const auto& child_in : node_in.children) {
+ auto* child_out = new TreeView::Node;
+ node_out->children.emplace_back(child_out);
+ CollapseSubnodesMatchingFunction(*child_in, depth, match, child_out);
+ }
+ }
+}
+
+void CollapseNodesMatchingFunction(
+ const TreeView& treeview_in, int depth,
+ const std::function<bool(const Label&)>& match, TreeView* treeview_out) {
+ treeview_out->mutable_thread_roots()->clear();
+ for (const auto& thread_root_in : treeview_in.thread_roots()) {
+ std::uint32_t id = thread_root_in.first;
+ const auto& root_in = *thread_root_in.second;
+ auto* root_out = new TreeView::Node;
+ treeview_out->mutable_thread_roots()->emplace(id, root_out);
+ CollapseSubnodesMatchingFunction(root_in, depth, match, root_out);
+ }
+}
+
+void CollapseNodesMatchingUnformatted(const TreeView& treeview_in, int depth,
+ const std::string& format,
+ TreeView* treeview_out) {
+ CollapseNodesMatchingFunction(
+ treeview_in, depth,
+ [&format](const Label& label) { return label.format() == format; },
+ treeview_out);
+}
+
+void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth,
+ const std::string& formatted,
+ TreeView* treeview_out) {
+ CollapseNodesMatchingFunction(
+ treeview_in, depth,
+ [&formatted](const Label& label) {
+ return label.Formatted() == formatted;
+ },
+ treeview_out);
+}
+
+} // namespace profiler
+} // namespace ruy
+
+#endif // RUY_PROFILER
diff --git a/ruy/profiler/treeview.h b/ruy/profiler/treeview.h
new file mode 100644
index 0000000..e34b4f9
--- /dev/null
+++ b/ruy/profiler/treeview.h
@@ -0,0 +1,130 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_
+
+#ifdef RUY_PROFILER
+
+#include <functional>
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "ruy/profiler/instrumentation.h"
+
+namespace ruy {
+namespace profiler {
+
+// A tree view of a profile.
+class TreeView {
+ public:
+ struct Node {
+ std::vector<std::unique_ptr<Node>> children;
+ Label label;
+ int weight = 0;
+ };
+
+ void Populate(const std::vector<char>& samples_buf_);
+
+ // Intentionally an *ordered* map so that threads are enumerated
+ // in an order that's consistent and typically putting the 'main thread'
+ // first.
+ using ThreadRootsMap = std::map<std::uint32_t, std::unique_ptr<Node>>;
+
+ const ThreadRootsMap& thread_roots() const { return thread_roots_; }
+ ThreadRootsMap* mutable_thread_roots() { return &thread_roots_; }
+
+ private:
+ ThreadRootsMap thread_roots_;
+};
+
+/* Below are API functions for manipulating and printing treeviews. */
+
+// Prints the treeview to stdout.
+void Print(const TreeView& treeview);
+
+// Prints the treeview below the given node on stdout. The `root` node is used
+// as the denominator for the printed weight percentages; at the top level,
+// pass the root node itself and level 0.
+void PrintTreeBelow(const TreeView::Node& node, const TreeView::Node& root,
+ int level);
+
+// Returns the tree depth below the given node.
+int DepthOfTreeBelow(const TreeView::Node& node);
+
+// Returns the sum of weights of nodes below the given node and filtered by
+// the `match` predicate.
+int WeightBelowNodeMatchingFunction(
+ const TreeView::Node& node, const std::function<bool(const Label&)>& match);
+
+// Returns the sum of weights of nodes below the given node and whose
+// unformatted label (i.e. raw format string) matches the given `format` string.
+//
+// This allows to aggregate nodes whose labels differ only by parameter values.
+int WeightBelowNodeMatchingUnformatted(const TreeView::Node& node,
+ const std::string& format);
+
+// Returns the sum of weights of nodes below the given node and whose formatted
+// label matches the `formatted` string.
+//
+// In the case of nodes with parametrized labels, this allows to count only
+// nodes with specific parameter values. For that purpose, one may also instead
+// use WeightBelowNodeMatchingFunction directly, with a `match` predicate
+// comparing raw integer parameter values directly, instead of going through
+// formatted strings.
+int WeightBelowNodeMatchingFormatted(const TreeView::Node& node,
+ const std::string& formatted);
+
+// Produces a `node_out` that is a copy of `node_in` but with tree depth below
+// it clamped at `depth`, with further subtrees aggregated into single leaf
+// nodes.
+void CollapseNode(const TreeView::Node& node_in, int depth,
+ TreeView::Node* node_out);
+
+// Calls CollapseNode with the given `depth` on every subnode filtered by the
+// `match` predicate. Note that this does NOT limit the tree depth below
+// `node_out` to `depth`, since each collapsed node below `node_out` may be
+// arbitrarily far below it and `depth` is only used as the collapsing depth
+// at that point.
+void CollapseSubnodesMatchingFunction(
+ const TreeView::Node& node_in, int depth,
+ const std::function<bool(const Label&)>& match, TreeView::Node* node_out);
+
+// Calls CollapseNode with the given `depth` on every node filtered by the
+// `match` predicate. Note that this does NOT limit the tree depth below
+// `node_out` to `depth`, since each collapsed node below `node_out` may be
+// arbitrarily far below it and `depth` is only used as the collapsing depth
+// at that point.
+void CollapseNodesMatchingFunction(
+ const TreeView& treeview_in, int depth,
+ const std::function<bool(const Label&)>& match, TreeView* treeview_out);
+
+// Special case of CollapseNodesMatchingFunction matching unformatted labels,
+// i.e. raw format strings.
+// See the comment on WeightBelowNodeMatchingUnformatted.
+void CollapseNodesMatchingUnformatted(const TreeView& treeview_in, int depth,
+ const std::string& format,
+ TreeView* treeview_out);
+
+// Special case of CollapseNodesMatchingFunction matching formatted labels.
+// See the comment on WeightBelowNodeMatchingFormatted.
+void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth,
+ const std::string& formatted,
+ TreeView* treeview_out);
+
+} // namespace profiler
+} // namespace ruy
+
+#endif // RUY_PROFILER
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_
diff --git a/ruy/ruy.h b/ruy/ruy.h
new file mode 100644
index 0000000..9cafe14
--- /dev/null
+++ b/ruy/ruy.h
@@ -0,0 +1,42 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This is the only Ruy header that users should #include.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_H_
+
+#include "ruy/context.h"
+#include "ruy/dispatch.h"
+#include "ruy/matrix.h"
+#include "ruy/path.h"
+#include "ruy/spec.h"
+
+namespace ruy {
+
+// Performs a multiplication of matrices. This is Ruy's only API entry point.
+// Should be self-explanatory given the above documentation for each of Matrix,
+// Spec and Context.
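+//
+// A minimal usage sketch (mirroring the setup used in prepacked_cache_test.cc;
+// the lhs_data/rhs_data/dst_data arrays are assumed to be provided by the
+// caller):
+//
+//   ruy::Context context;
+//   ruy::Matrix<float> lhs;
+//   ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
+//   lhs.data = lhs_data;
+//   ruy::Matrix<float> rhs;
+//   ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout);
+//   rhs.data = rhs_data;
+//   ruy::Matrix<float> dst;
+//   ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout);
+//   dst.data = dst_data;
+//   ruy::BasicSpec<float, float> spec;
+//   ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);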
+template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
+ typename DstScalar, typename Spec>
+void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+ const Spec& spec, Context* context, Matrix<DstScalar>* dst) {
+ DispatchMul<CompiledPaths, LhsScalar, RhsScalar, DstScalar, Spec>(
+ lhs, rhs, spec, context, dst);
+}
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_H_
diff --git a/ruy/ruy_advanced.h b/ruy/ruy_advanced.h
new file mode 100644
index 0000000..124ddd2
--- /dev/null
+++ b/ruy/ruy_advanced.h
@@ -0,0 +1,69 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_ADVANCED_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_ADVANCED_H_
+
+#include <cstddef>
+#include <functional>
+
+#include "ruy/context.h"
+#include "ruy/matrix.h"
+#include "ruy/path.h"
+#include "ruy/prepack.h"
+#include "ruy/side_pair.h"
+
+namespace ruy {
+
+// Low-level, explicit pre-packing API.
+//
+// The cost of packing an input matrix (either the LHS or RHS) is amortized
+// across the non-depth dimension of the opposite input matrix. Thus, when the
+// LHS has very few rows or the RHS has very few columns, the cost of packing
+// the opposite input matrix can become significant. See pack.h for further
+// information on packing.
+//
+// This file provides an API allowing a user to explicitly pack a matrix and
+// reuse the pre-packed matrix, avoiding that cost.
+//
+// See example_advanced.cc for example usage.
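+//
+// A minimal sketch of the call sequence (assuming lhs, rhs, dst, spec and
+// context are set up as for Mul, that only the LHS is pre-packed, and that the
+// caller tracks and eventually frees the buffers returned by alloc_fn):
+//
+//   ruy::PrepackedMatrix prepacked_lhs;
+//   auto alloc_fn = [](std::size_t num_bytes) -> void* {
+//     return malloc(num_bytes);
+//   };
+//   ruy::PrePackForMul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst,
+//                                      &prepacked_lhs, nullptr, alloc_fn);
+//   ruy::MulWithPrepacked<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst,
+//                                         &prepacked_lhs, nullptr);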
+
+template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
+ typename DstScalar, typename Spec>
+void PrePackForMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+ const Spec& spec, Context* context, Matrix<DstScalar>* dst,
+ PrepackedMatrix* prepacked_lhs,
+ PrepackedMatrix* prepacked_rhs,
+ std::function<void*(std::size_t)> alloc_fn) {
+ SidePair<PrepackedMatrix*> prepacked(prepacked_lhs, prepacked_rhs);
+ PrePackForMulInternal<CompiledPaths>(lhs, rhs, spec, context, dst, prepacked,
+ alloc_fn);
+}
+
+template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
+ typename DstScalar, typename Spec>
+void MulWithPrepacked(const Matrix<LhsScalar>& lhs,
+ const Matrix<RhsScalar>& rhs, const Spec& spec,
+ Context* context, Matrix<DstScalar>* dst,
+ PrepackedMatrix* prepacked_lhs,
+ PrepackedMatrix* prepacked_rhs) {
+ SidePair<PrepackedMatrix*> prepacked(prepacked_lhs, prepacked_rhs);
+ MulWithPrepackedInternal<CompiledPaths>(lhs, rhs, spec, context, dst,
+ prepacked);
+}
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_ADVANCED_H_
diff --git a/ruy/ruy_test.bzl b/ruy/ruy_test.bzl
new file mode 100644
index 0000000..ef7e8b1
--- /dev/null
+++ b/ruy/ruy_test.bzl
@@ -0,0 +1,34 @@
+# Provides the ruy_test macro for type-parametrized tests.
+"""ruy_test is a macro for building a test with multiple paths corresponding to tuples of types for LHS, RHS, accumulator and destination."""
+
+def ruy_test(name, srcs, lhs_rhs_accum_dst, copts, tags = [], deps = None):
+ for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst:
+ native.cc_test(
+ name = "%s_%s_%s_%s_%s" % (name, lhs, rhs, accum, dst),
+ srcs = srcs,
+ copts = copts + [
+ "-DRUY_TEST_LHSSCALAR=%s" % lhs,
+ "-DRUY_TEST_RHSSCALAR=%s" % rhs,
+ "-DRUY_TEST_ACCUMSCALAR=%s" % accum,
+ "-DRUY_TEST_DSTSCALAR=%s" % dst,
+ ],
+ deps = deps,
+ tags = tags,
+ )
+
+def ruy_benchmark(name, srcs, lhs_rhs_accum_dst, copts, deps = None):
+ tags = ["req_dep=//third_party/gemmlowp:profiler"]
+ for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst:
+ native.cc_binary(
+ name = "%s_%s_%s_%s_%s" % (name, lhs, rhs, accum, dst),
+ testonly = True,
+ srcs = srcs,
+ copts = copts + [
+ "-DRUY_TEST_LHSSCALAR=%s" % lhs,
+ "-DRUY_TEST_RHSSCALAR=%s" % rhs,
+ "-DRUY_TEST_ACCUMSCALAR=%s" % accum,
+ "-DRUY_TEST_DSTSCALAR=%s" % dst,
+ ],
+ deps = deps,
+ tags = tags,
+ )
diff --git a/ruy/ruy_test_ext.bzl b/ruy/ruy_test_ext.bzl
new file mode 100644
index 0000000..263121f
--- /dev/null
+++ b/ruy/ruy_test_ext.bzl
@@ -0,0 +1,19 @@
+"""Allows to specialize the ruy BUILD to availability of external libraries"""
+
+def ruy_test_ext_defines():
+ return select({
+ "//tools/cc_target_os:windows": [],
+ "//tools/cc_target_os:wasm": [],
+ "//tools/cc_target_os:chromiumos": ["RUY_TESTING_ON_CHROMIUMOS"],
+ "//conditions:default": ["RUY_TEST_EXTERNAL_PATHS"],
+ })
+
+def ruy_test_ext_deps():
+ return select({
+ "//tools/cc_target_os:windows": [],
+ "//conditions:default": [
+ "//third_party/eigen3",
+ "//third_party/gemmlowp",
+ "//third_party/lapack:blas",
+ ],
+ })
diff --git a/ruy/ruy_test_ext.bzl.opensource b/ruy/ruy_test_ext.bzl.opensource
new file mode 100644
index 0000000..5701fff
--- /dev/null
+++ b/ruy/ruy_test_ext.bzl.opensource
@@ -0,0 +1,7 @@
+"""Allows to specialize the ruy BUILD to availability of external libraries"""
+
+def ruy_test_ext_defines():
+ return []
+
+def ruy_test_ext_deps():
+ return []
diff --git a/ruy/side_pair.h b/ruy/side_pair.h
new file mode 100644
index 0000000..e62968b
--- /dev/null
+++ b/ruy/side_pair.h
@@ -0,0 +1,64 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIDE_PAIR_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIDE_PAIR_H_
+
+#include "ruy/check_macros.h"
+
+namespace ruy {
+
+// Enumeration of the sides, i.e. the operand 'slots', in a matrix
+// multiplication. The numerical values of these enumeration constants matter
+// because these will be used as indices into the array underlying a SidePair.
+enum class Side {
+ // Left-hand side
+ kLhs = 0,
+ // Right-hand side
+ kRhs = 1
+};
+
+// SidePair is a pair container where the two elements are indexed by a Side
+// enum.
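+//
+// For instance (illustrative only, with hypothetical variables):
+//   SidePair<int> num_elems(num_lhs_elems, num_rhs_elems);
+//   num_elems[Side::kLhs] = 0;  // accesses the first (LHS) element.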
+template <typename T>
+class SidePair final {
+ public:
+ SidePair() {}
+ SidePair(const T& a, const T& b) : elem_{a, b} {}
+ const T& operator[](Side side) const {
+ const int index = static_cast<int>(side);
+ // Technically this check is vacuous, since other values would be
+ // out-of-range for enum Side.
+ RUY_DCHECK(index == 0 || index == 1);
+ return elem_[index];
+ }
+
+ T& operator[](Side side) {
+ const int index = static_cast<int>(side);
+ // Technically this check is vacuous, since other values would be
+ // out-of-range for enum Side.
+ RUY_DCHECK(index == 0 || index == 1);
+ return elem_[index];
+ }
+
+ private:
+ static_assert(static_cast<int>(Side::kLhs) == 0, "");
+ static_assert(static_cast<int>(Side::kRhs) == 1, "");
+ T elem_[2];
+};
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIDE_PAIR_H_
diff --git a/ruy/size_util.h b/ruy/size_util.h
new file mode 100644
index 0000000..2a4bdb9
--- /dev/null
+++ b/ruy/size_util.h
@@ -0,0 +1,93 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIZE_UTIL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIZE_UTIL_H_
+
+#include <type_traits>
+
+#include "ruy/check_macros.h"
+
+#ifdef _WIN32
+#include <intrin.h>
+#endif
+
+namespace ruy {
+
+template <typename Integer>
+inline Integer floor_log2(Integer n) {
+ static_assert(std::is_integral<Integer>::value, "");
+ static_assert(std::is_signed<Integer>::value, "");
+ static_assert(sizeof(Integer) == 4 || sizeof(Integer) == 8, "");
+
+ RUY_DCHECK_GE(n, 1);
+#ifdef _WIN32
+ unsigned long result; // NOLINT[runtime/int]
+ if (sizeof(Integer) == 4) {
+ _BitScanReverse(&result, n);
+ } else {
+ _BitScanReverse64(&result, n);
+ }
+ return result;
+#else
+ if (sizeof(Integer) == 4) {
+ return 31 - __builtin_clz(n);
+ } else {
+ return 63 - __builtin_clzll(n);
+ }
+#endif
+}
+
+template <typename Integer>
+Integer ceil_log2(Integer n) {
+ RUY_DCHECK_GE(n, 1);
+ return n == 1 ? 0 : floor_log2(n - 1) + 1;
+}
+
+template <typename Integer>
+bool is_pot(Integer value) {
+ return (value > 0) && ((value & (value - 1)) == 0);
+}
+
+template <typename Integer>
+Integer pot_log2(Integer n) {
+ RUY_DCHECK(is_pot(n));
+ return floor_log2(n);
+}
+
+template <typename Integer>
+Integer round_down_pot(Integer value) {
+ return static_cast<Integer>(1) << floor_log2(value);
+}
+
+template <typename Integer>
+Integer round_up_pot(Integer value) {
+ return static_cast<Integer>(1) << ceil_log2(value);
+}
+
+template <typename Integer, typename Modulo>
+Integer round_down_pot(Integer value, Modulo modulo) {
+ RUY_DCHECK_EQ(modulo & (modulo - 1), 0);
+ return value & ~(modulo - 1);
+}
+
+template <typename Integer, typename Modulo>
+Integer round_up_pot(Integer value, Modulo modulo) {
+ return round_down_pot(value + modulo - 1, modulo);
+}
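+
+// Illustrative values, following directly from the definitions above:
+//   floor_log2(12) == 3,  ceil_log2(12) == 4,  is_pot(12) == false,
+//   round_down_pot(12) == 8,  round_up_pot(12) == 16,
+//   round_down_pot(13, 8) == 8,  round_up_pot(13, 8) == 16.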
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIZE_UTIL_H_
diff --git a/ruy/size_util_test.cc b/ruy/size_util_test.cc
new file mode 100644
index 0000000..54f0c11
--- /dev/null
+++ b/ruy/size_util_test.cc
@@ -0,0 +1,101 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/size_util.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <limits>
+
+#include "testing/base/public/gunit.h"
+
+namespace ruy {
+namespace {
+
+template <typename Integer>
+void SizeUtilTestValue(Integer value) {
+ if (value == 0) {
+ return;
+ }
+
+ EXPECT_LE(0, floor_log2(value));
+ EXPECT_LE(floor_log2(value), ceil_log2(value));
+ EXPECT_LE(ceil_log2(value), 8 * sizeof(Integer));
+
+ if (is_pot(value)) {
+ EXPECT_EQ(floor_log2(value), ceil_log2(value));
+ EXPECT_EQ(floor_log2(value), pot_log2(value));
+ } else {
+ EXPECT_EQ(floor_log2(value) + 1, ceil_log2(value));
+ }
+ EXPECT_EQ(value >> floor_log2(value), 1);
+ EXPECT_EQ(round_down_pot(value), static_cast<Integer>(1)
+ << floor_log2(value));
+ EXPECT_LE(round_down_pot(value), value);
+ EXPECT_GE(round_down_pot(value), value >> 1);
+ EXPECT_TRUE(is_pot(round_down_pot(value)));
+
+ if (ceil_log2(value) < 8 * sizeof(Integer) - 1) {
+ EXPECT_EQ(value >> ceil_log2(value), is_pot(value) ? 1 : 0);
+ EXPECT_EQ(round_up_pot(value), static_cast<Integer>(1) << ceil_log2(value));
+ EXPECT_GE(round_up_pot(value), value);
+ EXPECT_LE(round_up_pot(value) >> 1, value);
+ EXPECT_TRUE(is_pot(round_up_pot(value)));
+ }
+
+ for (std::uint8_t modulo : {1, 2, 8, 32, 128}) {
+ EXPECT_GE(value, round_down_pot(value, modulo));
+ EXPECT_EQ(round_down_pot(value, modulo) % modulo, 0);
+
+ if (value <= std::numeric_limits<Integer>::max() - modulo) {
+ EXPECT_LE(value, round_up_pot(value, modulo));
+ EXPECT_EQ(round_up_pot(value, modulo) % modulo, 0);
+ }
+ }
+}
+
+template <typename Integer>
+void SizeUtilTest() {
+ for (int exponent = 0; exponent < 8 * sizeof(Integer) - 1; exponent++) {
+ const Integer pot = static_cast<Integer>(1) << exponent;
+ SizeUtilTestValue(pot - 1);
+ SizeUtilTestValue(pot);
+ SizeUtilTestValue(pot + 1);
+ SizeUtilTestValue(pot + 12);
+ SizeUtilTestValue(pot + 123);
+ }
+ SizeUtilTestValue(std::numeric_limits<Integer>::max() - 1);
+ SizeUtilTestValue(std::numeric_limits<Integer>::max());
+}
+
+TEST(SizeUtilTest, Int) { SizeUtilTest<int>(); }
+
+TEST(SizeUtilTest, Long) { SizeUtilTest<long int>(); } // NOLINT
+
+TEST(SizeUtilTest, LongLong) { SizeUtilTest<long long int>(); } // NOLINT
+
+TEST(SizeUtilTest, Int32) { SizeUtilTest<std::int32_t>(); }
+
+TEST(SizeUtilTest, Int64) { SizeUtilTest<std::int64_t>(); }
+
+TEST(SizeUtilTest, Ptrdiff) { SizeUtilTest<std::ptrdiff_t>(); }
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char **argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/spec.h b/ruy/spec.h
new file mode 100644
index 0000000..d96b6a9
--- /dev/null
+++ b/ruy/spec.h
@@ -0,0 +1,118 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SPEC_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SPEC_H_
+
+#include <limits>
+#include <type_traits>
+
+#include "ruy/cpu_cache_size.h"
+#include "ruy/matrix.h"
+
+namespace ruy {
+
+// Our 'general' loop structure (the default) involves multi-threading and
+// complicated loops aiming to optimize cache-friendliness. One may opt out of
+// this and pick the 'simple' loop structure instead, which only performs well
+// for small matrix sizes and only allows using one thread, in exchange for
+// smaller code size.
+enum class LoopStructure { kGeneral, kSimple, kAuto };
+
+// In general we allow zero_points to have any Scalar value. This is called
+// 'asymmetric' quantization. We do take advantage of the optimization
+// opportunities when zero_points happen to be 'symmetric' at runtime (e.g. the
+// int8 value 0 or the uint8 value 128), but we still generate code to handle
+// the general asymmetric case. By choosing kSymmetric here, one opts out of
+// this and supports only the symmetric case, in exchange for smaller code size.
+enum class ZeroPointSupport { kGeneral, kSymmetric };
+
+// In general we allow all Layouts, even if we may use slow paths for some
+// kinds of layouts. By choosing kRCC, one may opt out of this and
+// only keep support for the simplest and most efficient combination of
+// Layouts, in exchange for smaller code size. The case covered by
+// kRCC is where the storage orders are exactly the following:
+// - LHS is RowMajor
+// - RHS is ColMajor
+// - Destination is ColMajor
+enum class LayoutSupport { kGeneral, kRCC };
+
+// A Spec describes everything about a matrix multiplication operation that
+// isn't encoded in the LHS, RHS and destination matrices. Some of that
+// information is encoded as compile-time constants and types (for instance,
+// the choice of accumulator type, AccumScalar). Some of that information is
+// encoded as runtime values (for instance, the optional bias vector).
+template <typename tAccumScalar, typename tDstScalar>
+struct BasicSpec {
+  // Accumulator type. The type of accumulators used to compute the
+  // dot-products before they are ultimately cast to the destination type.
+ using AccumScalar = tAccumScalar;
+ // The destination scalar type.
+ using DstScalar = tDstScalar;
+ // The bias vector data, if not null.
+ const AccumScalar* bias = nullptr;
+ // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
+  // of the multiplier by which accumulators are multiplied before being cast
+ // to the destination type.
+ AccumScalar multiplier_fixedpoint = 0;
+ // Only for non-floating-point cases. The exponent part of the aforementioned
+ // multiplier.
+ int multiplier_exponent = 0;
+ // Per-channel variant of multiplier_fixedpoint. If not nullptr, this must
+ // point to a buffer of as many values as there are rows in the destination
+ // matrix. Each row of the destination matrix will use the corresponding
+ // buffer element instead of multiplier_fixedpoint.
+ const AccumScalar* multiplier_fixedpoint_perchannel = nullptr;
+ // Per-channel variant of multiplier_exponent. If not nullptr, this must
+ // point to a buffer of as many values as there are rows in the destination
+ // matrix. Each row of the destination matrix will use the corresponding
+ // buffer element instead of multiplier_exponent.
+ //
+ // Either none or both of multiplier_exponent_perchannel and
+ // multiplier_fixedpoint_perchannel must be nullptr.
+ const int* multiplier_exponent_perchannel = nullptr;
+ // min clamp bound of destination values.
+ DstScalar clamp_min = std::is_floating_point<DstScalar>::value
+ ? -std::numeric_limits<DstScalar>::infinity()
+ : std::numeric_limits<DstScalar>::lowest();
+ // max clamp bound of destination values.
+ DstScalar clamp_max = std::is_floating_point<DstScalar>::value
+ ? std::numeric_limits<DstScalar>::infinity()
+ : std::numeric_limits<DstScalar>::max();
+ // See above enum LoopStructure
+ static constexpr LoopStructure kLoopStructure = LoopStructure::kAuto;
+ // See above enum LayoutSupport
+ static constexpr LayoutSupport kLayoutSupport = LayoutSupport::kGeneral;
+ // See above enum ZeroPointSupport
+ static constexpr ZeroPointSupport kZeroPointSupport =
+ ZeroPointSupport::kGeneral;
+ // Testing-only, not meant to be used by actual users:
+ // Used for testing of various kernel layouts.
+ using StandardCppKernelLhsLayout = FixedKernelLayout<Order::kColMajor, 1, 1>;
+ using StandardCppKernelRhsLayout = FixedKernelLayout<Order::kColMajor, 1, 1>;
+ // Returns (a reasonable estimate of) the local CPU cache size.
+ // See ruy::LocalDataCacheSize() which returns some coarse, sane default for
+ // each CPU architecture.
+ // This may be overridden, either to provide more accurate/runtime values,
+ // or to test with other values to let testcases have more coverage.
+ static int local_data_cache_size() { return LocalDataCacheSize(); }
+ // Same as local_data_cache_size but for the total data cache size accessible
+ // to each CPU core. See ruy::SharedDataCacheSize().
+ static int shared_data_cache_size() { return SharedDataCacheSize(); }
+};
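+
+// A hypothetical customized Spec, opting into the code-size savings described
+// by the enums above (illustrative sketch only; any subset of these overrides
+// may be used, and the scalar types are just an example):
+//
+//   struct MySpec : ruy::BasicSpec<std::int32_t, std::int8_t> {
+//     static constexpr ruy::LoopStructure kLoopStructure =
+//         ruy::LoopStructure::kSimple;
+//     static constexpr ruy::ZeroPointSupport kZeroPointSupport =
+//         ruy::ZeroPointSupport::kSymmetric;
+//     static constexpr ruy::LayoutSupport kLayoutSupport =
+//         ruy::LayoutSupport::kRCC;
+//   };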
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SPEC_H_
diff --git a/ruy/test.h b/ruy/test.h
new file mode 100644
index 0000000..649a0d9
--- /dev/null
+++ b/ruy/test.h
@@ -0,0 +1,2125 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TEST_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TEST_H_
+
+#include <math.h>
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <ctime>
+#include <iostream>
+#include <iterator>
+#include <limits>
+#include <memory>
+#include <random>
+#include <set>
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <type_traits>
+#include <vector>
+
+#include "testing/base/public/gunit.h" // IWYU pragma: export
+#include "ruy/matrix.h" // IWYU pragma: export
+#include "ruy/platform.h"
+#include "ruy/pmu.h"
+#include "ruy/ruy.h"
+#include "ruy/ruy_advanced.h"
+#include "ruy/spec.h" // IWYU pragma: export
+#include "ruy/time.h"
+
+#ifdef RUY_TEST_EXTERNAL_PATHS
+#define EIGEN_USE_THREADS
+#define EIGEN_USE_CUSTOM_THREAD_POOL
+#include "third_party/eigen3/Eigen/Core"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "third_party/gemmlowp/public/gemmlowp.h"
+#include "third_party/lapack/blas.h"
+#endif
+
+#ifdef RUY_PROFILER
+#include "ruy/profiler/profiler.h"
+#endif
+
+namespace ruy {
+
+const float kClampRatio = 0.1f;
+
+enum class ExternalPath { kNone, kGemmlowp, kEigen, kEigenTensor, kOpenBlas };
+
+inline std::vector<std::string>* CoveredPaths() {
+ static std::vector<std::string> covered_paths;
+ return &covered_paths;
+}
+
+inline const char* PathName(Path path) {
+#define RUY_PATHNAME_CASE(NAME) \
+ case Path::NAME: \
+ return #NAME;
+ switch (path) {
+ RUY_PATHNAME_CASE(kReference)
+ RUY_PATHNAME_CASE(kStandardCpp)
+#if RUY_PLATFORM(NEON)
+ RUY_PATHNAME_CASE(kNeon)
+ RUY_PATHNAME_CASE(kNeonDotprod)
+#elif RUY_PLATFORM(X86)
+ RUY_PATHNAME_CASE(kSse42)
+ RUY_PATHNAME_CASE(kAvx2)
+ RUY_PATHNAME_CASE(kAvx512)
+ RUY_PATHNAME_CASE(kAvxVnni)
+#endif
+ default:
+ RUY_CHECK(false);
+ return nullptr;
+ }
+#undef RUY_PATHNAME_CASE
+}
+
+inline const char* TuningName(Tuning tuning) {
+#define RUY_SUBPATHNAME_CASE(NAME) \
+ case Tuning::NAME: \
+ return #NAME;
+ switch (tuning) {
+ RUY_SUBPATHNAME_CASE(kInOrder)
+ RUY_SUBPATHNAME_CASE(kOutOfOrder)
+ default:
+ RUY_CHECK(false);
+ return nullptr;
+ }
+#undef RUY_SUBPATHNAME_CASE
+}
+
+inline const char* PathName(ExternalPath path) {
+#define RUY_PATHNAME_CASE(NAME) \
+ case ExternalPath::NAME: \
+ return #NAME;
+ switch (path) {
+ RUY_PATHNAME_CASE(kGemmlowp)
+ RUY_PATHNAME_CASE(kEigen)
+ RUY_PATHNAME_CASE(kEigenTensor)
+ RUY_PATHNAME_CASE(kOpenBlas)
+ default:
+ RUY_CHECK(false);
+ return nullptr;
+ }
+#undef RUY_PATHNAME_CASE
+}
+
+inline std::ostream& operator<<(std::ostream& stream, Path path) {
+ return stream << PathName(path);
+}
+
+inline std::ostream& operator<<(std::ostream& stream,
+ ExternalPath external_path) {
+ return stream << PathName(external_path);
+}
+
+template <typename ContainerType>
+std::string Join(const ContainerType& container) {
+ if (container.empty()) {
+ return "<empty>";
+ }
+ std::ostringstream stream;
+ auto it = container.begin();
+ stream << *it++;
+ for (; it != container.end(); ++it) {
+ stream << ", ";
+ stream << *it;
+ }
+ return stream.str();
+}
+
+struct LogCoveredPathsOnDestruction final {
+ ~LogCoveredPathsOnDestruction() {
+ std::cerr << "Covered paths: " << Join(*CoveredPaths()) << std::endl;
+
+ // When testing on ARM64 ChromiumOS emulator, make sure that we covered
+ // the dotprod path. We're getting such coverage at the moment thanks to
+ // using a sufficiently recent emulator, and we don't want to regress that.
+#if RUY_PLATFORM(ARM_64) && defined RUY_TESTING_ON_CHROMIUMOS
+ bool found_dotprod = false;
+ for (const std::string& covered_path : *CoveredPaths()) {
+ if (covered_path == "kNeonDotprod") {
+ found_dotprod = true;
+ }
+ }
+ if (!found_dotprod) {
+ std::cerr
+ << "Error: we haven't tested the kNeonDotprod path as we should "
+ "have. At the moment, this is required on ChromiumOS as this is "
+ "what we run emulator tests in, that currently supports "
+ "dot-product "
+ "instructions, and we care very much about not regressing that. "
+ "If this test was run in an emulator, please upgrade to a newer "
+ "emulator version. If this test was run on an actual device, and "
+ "you need to be able to run ruy tests on devices not supporting "
+ "dot-product instructions, get in touch with us.\n"
+ << std::endl;
+ abort();
+ }
+#endif
+ }
+ static void Singleton() { static LogCoveredPathsOnDestruction singleton; }
+};
+
+enum class RandomRange {
+ kGeneral,
+ kAvoidMinValue,
+ kOffCenterAvoidMinValue,
+ kReasonableSrcZeroPoint,
+ kReasonableDstZeroPoint,
+ kBias
+};
+
+template <typename Scalar,
+ bool IsFloatingPoint = std::is_floating_point<Scalar>::value>
+struct RandomRangeBounds {};
+
+template <typename Scalar>
+struct RandomRangeBounds<Scalar, true> {
+ static Scalar GetMinBound(RandomRange range) {
+ switch (range) {
+ case RandomRange::kGeneral:
+ return -1;
+ case RandomRange::kAvoidMinValue:
+ return -1;
+ case RandomRange::kOffCenterAvoidMinValue:
+ return -1;
+ case RandomRange::kReasonableSrcZeroPoint:
+ return 0;
+ case RandomRange::kReasonableDstZeroPoint:
+ return 0;
+ case RandomRange::kBias:
+ return -1;
+ default:
+ RUY_CHECK(false);
+ return 0;
+ }
+ }
+ static Scalar GetMaxBound(RandomRange range) {
+ switch (range) {
+ case RandomRange::kGeneral:
+ return 1;
+ case RandomRange::kAvoidMinValue:
+ return 1;
+ case RandomRange::kOffCenterAvoidMinValue:
+ return 1;
+ case RandomRange::kReasonableSrcZeroPoint:
+ return 0;
+ case RandomRange::kReasonableDstZeroPoint:
+ return 0;
+ case RandomRange::kBias:
+ return 1;
+ default:
+ RUY_CHECK(false);
+ return 0;
+ }
+ }
+};
+
+template <typename Scalar>
+Scalar WeightedSum(Scalar s1, float weight1, Scalar s2, float weight2) {
+ float sum = s1 * weight1 + s2 * weight2;
+ float clamped = std::min<float>(
+ std::numeric_limits<Scalar>::max(),
+ std::max<float>(std::numeric_limits<Scalar>::lowest(), sum));
+ return static_cast<Scalar>(clamped);
+}
+
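+// Maps param in [0, 1] linearly onto [lowest, max] of the Scalar range; for
+// instance, Parametrized<std::uint8_t>(0.4) is approximately 102.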
+template <typename Scalar>
+Scalar Parametrized(float param) {
+ return WeightedSum(std::numeric_limits<Scalar>::max(), param,
+ std::numeric_limits<Scalar>::lowest(), 1 - param);
+}
+
+template <typename Scalar>
+struct RandomRangeBounds<Scalar, false> {
+ static Scalar GetMinBound(RandomRange range) {
+ static constexpr double offcenteredness =
+ 0.02; // Shift lower limit by about 5 for range of 255.
+ switch (range) {
+ case RandomRange::kGeneral:
+ return std::numeric_limits<Scalar>::lowest();
+ case RandomRange::kAvoidMinValue:
+ return 1 + std::numeric_limits<Scalar>::lowest();
+ case RandomRange::kOffCenterAvoidMinValue:
+ return 1 + std::numeric_limits<Scalar>::lowest() +
+ static_cast<Scalar>(
+ offcenteredness * std::numeric_limits<Scalar>::max() -
+ offcenteredness *
+ (std::numeric_limits<Scalar>::lowest() + 1));
+ case RandomRange::kReasonableSrcZeroPoint:
+ return std::numeric_limits<Scalar>::lowest();
+ case RandomRange::kReasonableDstZeroPoint:
+ return Parametrized<Scalar>(0.4);
+ case RandomRange::kBias:
+ return std::is_same<Scalar, std::int32_t>::value
+ ? static_cast<Scalar>(-10000)
+ : 0;
+ default:
+ RUY_CHECK(false);
+ return 0;
+ }
+ }
+ static Scalar GetMaxBound(RandomRange range) {
+ switch (range) {
+ case RandomRange::kGeneral:
+ return std::numeric_limits<Scalar>::max();
+ case RandomRange::kAvoidMinValue:
+ return std::numeric_limits<Scalar>::max();
+ case RandomRange::kOffCenterAvoidMinValue:
+ return std::numeric_limits<Scalar>::max();
+ case RandomRange::kReasonableSrcZeroPoint:
+ return std::numeric_limits<Scalar>::max();
+ case RandomRange::kReasonableDstZeroPoint:
+ return Parametrized<Scalar>(0.6);
+ case RandomRange::kBias:
+ return std::is_same<Scalar, std::int32_t>::value
+ ? static_cast<Scalar>(10000)
+ : 0;
+ default:
+ RUY_CHECK(false);
+ return 0;
+ }
+ }
+};
+
+inline std::default_random_engine& global_random_engine() {
+ static std::default_random_engine engine;
+ return engine;
+}
+
+template <typename Scalar>
+struct UniformRandomDistribution {
+ UniformRandomDistribution(RandomRange range)
+ : dist(RandomRangeBounds<Scalar>::GetMinBound(range),
+ RandomRangeBounds<Scalar>::GetMaxBound(range)) {}
+ Scalar Get() { return dist(global_random_engine()); }
+ // std::uniform_int_distribution is specified not to support char types,
+ // only short and wider types. MSVC actually generates an error on
+ // std::uniform_int_distribution<std::int8_t>.
+ using StdDistType = typename std::conditional<
+ std::is_floating_point<Scalar>::value,
+ std::uniform_real_distribution<Scalar>,
+ std::uniform_int_distribution<std::int32_t>>::type;
+ StdDistType dist;
+};
+
+template <typename Scalar>
+void MakeRandomScalar(UniformRandomDistribution<Scalar>* uniform_dist,
+ Scalar* dst) {
+ *dst = uniform_dist->Get();
+}
+
+template <typename Scalar>
+void MakeRandomVector(UniformRandomDistribution<Scalar>* uniform_dist, int size,
+ std::vector<Scalar>* dst) {
+ dst->resize(size);
+ for (auto& x : *dst) {
+ MakeRandomScalar(uniform_dist, &x);
+ }
+}
+
+template <typename Scalar>
+void MakeRandomScalar(RandomRange range, Scalar* dst) {
+ UniformRandomDistribution<Scalar> dist(range);
+ *dst = dist.Get();
+ if (range == RandomRange::kReasonableDstZeroPoint ||
+ range == RandomRange::kReasonableSrcZeroPoint) {
+ if (global_random_engine()() & 1) {
+ *dst = SymmetricZeroPoint<Scalar>();
+ }
+ }
+}
+
+template <typename Scalar>
+void MakeRandomVector(RandomRange range, int size, std::vector<Scalar>* dst) {
+ UniformRandomDistribution<Scalar> dist(range);
+ dst->resize(size);
+ for (auto& x : *dst) {
+ MakeRandomScalar(&dist, &x);
+ }
+}
+
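+// LayoutStyle controls whether MakeLayout below produces a fully packed
+// stride or a deliberately padded one. For a 3x4 column-major matrix, for
+// instance, kPackedLinear gives stride 3 (== rows) while kLinear pads it to
+// 4, so that tests also exercise non-packed strides.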
+enum class LayoutStyle { kPackedLinear, kLinear };
+
+inline void MakeLayout(int rows, int cols, Order order,
+ LayoutStyle layout_style, Layout* layout) {
+ layout->rows = rows;
+ layout->cols = cols;
+ layout->order = order;
+
+ const int packed_stride = order == Order::kColMajor ? rows : cols;
+
+ RUY_CHECK(layout_style == LayoutStyle::kPackedLinear ||
+ layout_style == LayoutStyle::kLinear);
+ if (layout_style == LayoutStyle::kPackedLinear) {
+ layout->stride = packed_stride;
+ } else {
+ layout->stride = packed_stride + 1;
+ }
+}
+
+template <typename Scalar>
+struct StorageMatrix {
+ StorageMatrix() = default;
+ StorageMatrix(const StorageMatrix&) = delete;
+ void operator=(const StorageMatrix&) = delete;
+ std::vector<Scalar> data;
+ Matrix<Scalar> matrix;
+};
+
+template <typename Scalar>
+void VerifyConsistentFields(const StorageMatrix<Scalar>& storage_matrix) {
+ if (storage_matrix.data.empty()) {
+ RUY_CHECK_EQ(storage_matrix.matrix.data.get(), nullptr);
+ RUY_CHECK_EQ(storage_matrix.matrix.layout.rows, 0);
+ RUY_CHECK_EQ(storage_matrix.matrix.layout.cols, 0);
+ } else {
+ RUY_CHECK_EQ(storage_matrix.matrix.data.get(), storage_matrix.data.data());
+ RUY_CHECK_EQ(FlatSize(storage_matrix.matrix.layout),
+ storage_matrix.data.size());
+ }
+}
+
+template <typename Scalar>
+void MakeRandom(int rows, int cols, Order order, Scalar zero_point,
+ LayoutStyle layout_style, RandomRange range,
+ StorageMatrix<Scalar>* storage_matrix) {
+ MakeLayout(rows, cols, order, layout_style, &storage_matrix->matrix.layout);
+ storage_matrix->matrix.zero_point = zero_point;
+ UniformRandomDistribution<Scalar> data_dist(range);
+ MakeRandomVector(&data_dist, FlatSize(storage_matrix->matrix.layout),
+ &storage_matrix->data);
+ storage_matrix->matrix.data = storage_matrix->data.data();
+ VerifyConsistentFields(*storage_matrix);
+}
+
+template <typename Scalar>
+struct TestResult {
+ void operator=(const TestResult&) = delete;
+ void operator=(const TestResult&&) = delete;
+ StorageMatrix<Scalar> storage_matrix;
+ Path path = Path::kNone;
+ Tuning tuning = Tuning::kAuto;
+ ExternalPath external_path = ExternalPath::kNone;
+ float latency;
+ float l1_refill_rate;
+ float l2_refill_rate;
+ float l3_refill_rate;
+ float l1tlb_refill_rate;
+ float l2tlb_refill_rate;
+ float mispred_rate;
+ float frontend_stall_rate;
+ float backend_stall_rate;
+
+ // Per-path data for pre-packing.
+ // This is not used by external paths or by Path::kReference.
+ Allocator allocator;
+ PrepackedMatrix prepacked_lhs;
+ PrepackedMatrix prepacked_rhs;
+ bool use_prepacked_lhs = false;
+ bool use_prepacked_rhs = false;
+};
+
+template <typename Scalar>
+std::string PathName(const TestResult<Scalar>& result) {
+ std::string pathname;
+ if (result.path != Path::kNone) {
+ pathname.assign(PathName(result.path));
+ } else if (result.external_path != ExternalPath::kNone) {
+ pathname.assign(PathName(result.external_path));
+ } else {
+ RUY_CHECK(false);
+ }
+ if (result.tuning != Tuning::kAuto) {
+ pathname.append("/");
+ pathname.append(TuningName(result.tuning));
+ }
+ return pathname;
+}
+
+enum class ExpectedOutcome { kSuccess, kDeath };
+
+template <typename tLhsScalar, typename tRhsScalar, typename SpecType>
+struct TestSet final {
+ using LhsScalar = tLhsScalar;
+ using RhsScalar = tRhsScalar;
+ using AccumScalar = typename SpecType::AccumScalar;
+ using DstScalar = typename SpecType::DstScalar;
+ using Spec = SpecType;
+ using TestResultType = TestResult<DstScalar>;
+
+ void Run() {
+ MakeZeroPoints();
+ MakeLhsRhs();
+ MakeSpec();
+ MakeOtherParams();
+ MakeResultPaths();
+ MakePrepackedMatrices();
+ Eval();
+ Verify();
+ }
+
+ private:
+ void MakeZeroPoints();
+ void MakeLhsRhs();
+ void MakeSpec();
+ void MakeResultPaths();
+ void MakePrepackedMatrices();
+ void MakeOtherParams();
+ void EvalAndVerify();
+ void Eval();
+ void Verify();
+
+ void EvalResult(TestResultType* result);
+ void EvalRuy(TestResultType* result);
+ void DoMul(TestResultType* result);
+ void Benchmark(TestResultType* result);
+ void VerifyTestResults() const;
+
+ public:
+ enum class LifeStage {
+ kInitial,
+ kHasZeroPoints,
+ kHasLhsRhs,
+ kHasSpec,
+ kHasOtherParams,
+ kHasResultPaths,
+ kHasPrepackedMatrices,
+ kEvaluated,
+ kFinal
+ };
+
+ ~TestSet() {
+ RUY_CHECK_EQ(life_stage, LifeStage::kFinal);
+ LogCoveredPathsOnDestruction::Singleton();
+ }
+
+ LifeStage life_stage = LifeStage::kInitial;
+
+ int rows = 0;
+ int cols = 0;
+ int depth = 0;
+ Order lhs_order = Order::kRowMajor;
+ Order rhs_order = Order::kColMajor;
+ Order dst_order = Order::kColMajor;
+ LayoutStyle layout_style = LayoutStyle::kPackedLinear;
+ ExpectedOutcome expected_outcome = ExpectedOutcome::kSuccess;
+
+ bool use_specified_zero_points = false;
+ LhsScalar lhs_zero_point = 0;
+ RhsScalar rhs_zero_point = 0;
+ DstScalar dst_zero_point = 0;
+
+ std::vector<AccumScalar> per_channel_multiplier_fixedpoint;
+ std::vector<int> per_channel_multiplier_exponent;
+
+ StorageMatrix<LhsScalar> lhs;
+ StorageMatrix<RhsScalar> rhs;
+ Spec spec;
+ std::vector<AccumScalar> bias_data;
+ std::vector<std::unique_ptr<TestResultType>> results;
+
+ std::vector<Path> paths;
+ std::vector<ExternalPath> external_paths;
+
+ bool benchmark = false;
+ bool perchannel = false;
+ int max_num_threads = 0;
+ bool benchmark_prepack_lhs = false;
+ bool benchmark_prepack_rhs = false;
+};
+
+inline PmuEvents& GlobalPmuEvents() {
+ static PmuEvents pmu;
+ return pmu;
+}
+
+inline Context& GlobalContext() {
+ // Ensure that GlobalPmuEvents is constructed before we create any context.
+ // This ensures that pmu counters are opened before we create any worker
+ // thread, which is necessary to count events from worker threads.
+ GlobalPmuEvents();
+
+ static Context context;
+ return context;
+}
+
+#if defined(__has_feature)
+#if __has_feature(thread_sanitizer)
+#define RUY_TSAN
+#endif
+#if __has_feature(address_sanitizer)
+#define RUY_ASAN
+#endif
+#endif // defined(__has_feature)
+
+template <typename LhsScalar, typename RhsScalar, typename SpecType>
+void TestSet<LhsScalar, RhsScalar, SpecType>::DoMul(TestResultType* result) {
+ Context* context = &GlobalContext();
+
+ if (!result->use_prepacked_lhs && !result->use_prepacked_rhs) {
+ Mul<kAllPaths>(lhs.matrix, rhs.matrix, spec, context,
+ &result->storage_matrix.matrix);
+ return;
+ }
+
+ // If we prepacked an input matrix, null out its data pointer to check
+ // that we don't access any data through it.
+ Matrix<LhsScalar> null_data_lhs = lhs.matrix;
+ Matrix<RhsScalar> null_data_rhs = rhs.matrix;
+ if (result->use_prepacked_lhs) {
+ null_data_lhs.data = nullptr;
+ }
+ if (result->use_prepacked_rhs) {
+ null_data_rhs.data = nullptr;
+ }
+
+ // Do the multiplication with pre-packed matrices.
+ PrepackedMatrix* prepacked_lhs_ptr =
+ result->use_prepacked_lhs ? &result->prepacked_lhs : nullptr;
+ PrepackedMatrix* prepacked_rhs_ptr =
+ result->use_prepacked_rhs ? &result->prepacked_rhs : nullptr;
+ MulWithPrepacked<kAllPaths>(null_data_lhs, null_data_rhs, spec, context,
+ &result->storage_matrix.matrix, prepacked_lhs_ptr,
+ prepacked_rhs_ptr);
+}
+
+// When building for WAsm, ASSERT_DEATH is not defined.
+#ifdef ASSERT_DEATH
+#define RUY_ASSERT_DEATH(CONDITION, MESSAGE) ASSERT_DEATH(CONDITION, MESSAGE)
+#else
+#define RUY_ASSERT_DEATH(CONDITION, MESSAGE)
+#endif
+
+template <typename LhsScalar, typename RhsScalar, typename SpecType>
+void TestSet<LhsScalar, RhsScalar, SpecType>::EvalRuy(TestResultType* result) {
+ GlobalContext().explicit_tuning = result->tuning;
+ if (max_num_threads) {
+ GlobalContext().max_num_threads = max_num_threads;
+ } else if (benchmark) {
+ GlobalContext().max_num_threads = 1;
+ } else {
+ GlobalContext().max_num_threads = 1 + global_random_engine()() % 8;
+ }
+ GlobalContext().SetRuntimeEnabledPaths(result->path);
+ if (expected_outcome == ExpectedOutcome::kSuccess) {
+ DoMul(result);
+ RUY_CHECK_EQ(GlobalContext().last_taken_path, result->path);
+ } else if (expected_outcome == ExpectedOutcome::kDeath) {
+ // TODO(benoitjacob) TSan and ASan seem to be breaking ASSERT_DEATH.
+ // Report a bug?
+#if (!defined NDEBUG) && (!defined RUY_ASAN) && (!defined RUY_TSAN)
+ RUY_ASSERT_DEATH(DoMul(result), "");
+#endif
+ } else {
+ RUY_CHECK(false);
+ }
+ GlobalContext().explicit_tuning = Tuning::kAuto;
+ GlobalContext().max_num_threads = 1;
+}
+
+#ifdef RUY_TEST_EXTERNAL_PATHS
+
+template <typename Scalar, gemmlowp::MapOrder tOrder>
+void WrapGemmlowp(const Matrix<Scalar>& src,
+ gemmlowp::MatrixMap<const Scalar, tOrder>* dst) {
+ RUY_CHECK(src.layout.order == (tOrder == gemmlowp::MapOrder::ColMajor
+ ? Order::kColMajor
+ : Order::kRowMajor));
+ *dst = gemmlowp::MatrixMap<const Scalar, tOrder>(
+ src.data.get(), src.layout.rows, src.layout.cols, src.layout.stride);
+}
+
+template <typename Scalar, gemmlowp::MapOrder tOrder>
+void WrapGemmlowpMutable(Matrix<Scalar>* src,
+ gemmlowp::MatrixMap<Scalar, tOrder>* dst) {
+ RUY_CHECK(src->layout.order == (tOrder == gemmlowp::MapOrder::ColMajor
+ ? Order::kColMajor
+ : Order::kRowMajor));
+ *dst = gemmlowp::MatrixMap<Scalar, tOrder>(
+ src->data.get(), src->layout.rows, src->layout.cols, src->layout.stride);
+}
+
+template <Order tOrder>
+struct GemmlowpOrder {};
+
+template <>
+struct GemmlowpOrder<Order::kColMajor> {
+ static constexpr gemmlowp::MapOrder kValue = gemmlowp::MapOrder::ColMajor;
+};
+
+template <>
+struct GemmlowpOrder<Order::kRowMajor> {
+ static constexpr gemmlowp::MapOrder kValue = gemmlowp::MapOrder::RowMajor;
+};
+
+inline gemmlowp::GemmContext& GlobalGemmlowpContext() {
+ static gemmlowp::GemmContext context;
+ return context;
+}
+
+template <Order LhsOrder, Order RhsOrder, Order DstOrder, typename LhsScalar,
+ typename RhsScalar, typename DstScalar, typename Spec>
+void EvalGemmlowp(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+ const Spec& spec, int max_num_threads,
+ Matrix<DstScalar>* dst) {
+ static constexpr gemmlowp::MapOrder kGemmlowpLhsOrder =
+ GemmlowpOrder<LhsOrder>::kValue;
+ static constexpr gemmlowp::MapOrder kGemmlowpRhsOrder =
+ GemmlowpOrder<RhsOrder>::kValue;
+ static constexpr gemmlowp::MapOrder kGemmlowpDstOrder =
+ GemmlowpOrder<DstOrder>::kValue;
+ gemmlowp::MatrixMap<const LhsScalar, kGemmlowpLhsOrder> gemmlowp_lhs;
+ gemmlowp::MatrixMap<const RhsScalar, kGemmlowpRhsOrder> gemmlowp_rhs;
+ gemmlowp::MatrixMap<DstScalar, kGemmlowpDstOrder> gemmlowp_dst;
+ WrapGemmlowp(lhs, &gemmlowp_lhs);
+ WrapGemmlowp(rhs, &gemmlowp_rhs);
+ WrapGemmlowpMutable(dst, &gemmlowp_dst);
+
+ gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage;
+ quantize_down_stage.result_offset_after_shift = dst->zero_point;
+ quantize_down_stage.result_fixedpoint_multiplier = spec.multiplier_fixedpoint;
+ quantize_down_stage.result_exponent = spec.multiplier_exponent;
+ gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC<
+ gemmlowp::VectorShape::Col>
+ quantize_down_stage_pc;
+ quantize_down_stage_pc.result_offset_after_shift = dst->zero_point;
+ using ColVectorMap =
+ gemmlowp::VectorMap<const std::int32_t, gemmlowp::VectorShape::Col>;
+ quantize_down_stage_pc.result_fixedpoint_multiplier =
+ ColVectorMap(spec.multiplier_fixedpoint_perchannel, lhs.layout.rows);
+ quantize_down_stage_pc.result_exponent =
+ ColVectorMap(spec.multiplier_exponent_perchannel, lhs.layout.rows);
+
+ gemmlowp::OutputStageClamp clamp_stage;
+ clamp_stage.min = spec.clamp_min;
+ clamp_stage.max = spec.clamp_max;
+ using OutputStageSaturatingCast = typename std::conditional<
+ std::is_same<DstScalar, std::uint8_t>::value,
+ gemmlowp::OutputStageSaturatingCastToUint8,
+ gemmlowp::OutputStageSaturatingCastToInt16>::type;
+ OutputStageSaturatingCast saturating_cast_stage;
+
+ GlobalGemmlowpContext().set_max_num_threads(max_num_threads ? max_num_threads
+ : 1);
+ if (spec.bias) {
+ using ColVectorMap =
+ gemmlowp::VectorMap<const std::int32_t, gemmlowp::VectorShape::Col>;
+ gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_add_stage;
+ bias_add_stage.bias_vector = ColVectorMap(spec.bias, dst->layout.rows);
+#ifndef GEMMLOWP_SSE4 // gemmlowp perchannel stuff does not build on SSE
+ if (spec.multiplier_exponent_perchannel) {
+ const auto& output_pipeline =
+ std::make_tuple(bias_add_stage, quantize_down_stage_pc, clamp_stage,
+ saturating_cast_stage);
+ gemmlowp::GemmWithOutputPipeline<
+ LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
+ &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
+ -lhs.zero_point, -rhs.zero_point, output_pipeline);
+ } else // NOLINT[readability/braces]
+#endif
+ {
+ const auto& output_pipeline =
+ std::make_tuple(bias_add_stage, quantize_down_stage, clamp_stage,
+ saturating_cast_stage);
+ gemmlowp::GemmWithOutputPipeline<
+ LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
+ &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
+ -lhs.zero_point, -rhs.zero_point, output_pipeline);
+ }
+ } else {
+#ifndef GEMMLOWP_SSE4 // gemmlowp perchannel stuff does not build on SSE
+ if (spec.multiplier_exponent_perchannel) {
+ const auto& output_pipeline = std::make_tuple(
+ quantize_down_stage_pc, clamp_stage, saturating_cast_stage);
+ gemmlowp::GemmWithOutputPipeline<
+ LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
+ &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
+ -lhs.zero_point, -rhs.zero_point, output_pipeline);
+ } else // NOLINT[readability/braces]
+#endif
+ {
+ const auto& output_pipeline = std::make_tuple(
+ quantize_down_stage, clamp_stage, saturating_cast_stage);
+ gemmlowp::GemmWithOutputPipeline<
+ LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
+ &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
+ -lhs.zero_point, -rhs.zero_point, output_pipeline);
+ }
+ }
+}
+
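+// Mash packs the three storage orders into a 3-bit switch key; for instance,
+// Mash(Order::kRowMajor, Order::kColMajor, Order::kColMajor) evaluates to 4.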
+inline constexpr int Mash(Order LhsOrder, Order RhsOrder, Order DstOrder) {
+ return (LhsOrder == Order::kRowMajor ? 4 : 0) +
+ (RhsOrder == Order::kRowMajor ? 2 : 0) +
+ (DstOrder == Order::kRowMajor ? 1 : 0);
+}
+
+template <typename LhsScalar, typename RhsScalar, typename DstScalar,
+ typename Spec>
+void EvalGemmlowp(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+ const Spec& spec, int max_num_threads,
+ Matrix<DstScalar>* dst) {
+ int index = Mash(lhs.layout.order, rhs.layout.order, dst->layout.order);
+ switch (index) {
+#define EVALGEMMLOWP_CASE3(LHS, RHS, DST) \
+ case Mash(LHS, RHS, DST): \
+ return EvalGemmlowp<LHS, RHS, DST>(lhs, rhs, spec, max_num_threads, dst);
+#define EVALGEMMLOWP_CASE2(LHS, RHS) \
+ EVALGEMMLOWP_CASE3(LHS, RHS, Order::kColMajor) \
+ EVALGEMMLOWP_CASE3(LHS, RHS, Order::kRowMajor)
+#define EVALGEMMLOWP_CASE1(LHS) \
+ EVALGEMMLOWP_CASE2(LHS, Order::kColMajor) \
+ EVALGEMMLOWP_CASE2(LHS, Order::kRowMajor)
+
+ EVALGEMMLOWP_CASE1(Order::kColMajor)
+ EVALGEMMLOWP_CASE1(Order::kRowMajor)
+
+#undef EVALGEMMLOWP_CASE1
+#undef EVALGEMMLOWP_CASE2
+#undef EVALGEMMLOWP_CASE3
+
+ default:
+ RUY_CHECK(false);
+ }
+}
+
+template <Order tOrder>
+struct EigenOrder {};
+
+template <>
+struct EigenOrder<Order::kColMajor> {
+ static constexpr int kValue = Eigen::ColMajor;
+};
+
+template <>
+struct EigenOrder<Order::kRowMajor> {
+ static constexpr int kValue = Eigen::RowMajor;
+};
+
+template <Order LhsOrder, Order RhsOrder, Order DstOrder, typename LhsScalar,
+ typename RhsScalar, typename DstScalar, typename Spec>
+void EvalEigen(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+ const Spec& spec, int max_num_threads, Matrix<DstScalar>* dst) {
+ RUY_CHECK_EQ(lhs.zero_point, 0);
+ RUY_CHECK_EQ(rhs.zero_point, 0);
+ RUY_CHECK_EQ(dst->zero_point, 0);
+ RUY_CHECK_EQ(spec.multiplier_fixedpoint, 0);
+ RUY_CHECK_EQ(spec.multiplier_exponent, 0);
+
+ static constexpr int kEigenLhsOrder = EigenOrder<LhsOrder>::kValue;
+ static constexpr int kEigenRhsOrder = EigenOrder<RhsOrder>::kValue;
+ static constexpr int kEigenDstOrder = EigenOrder<DstOrder>::kValue;
+
+ using EigenLhsType = typename Eigen::Matrix<LhsScalar, Eigen::Dynamic,
+ Eigen::Dynamic, kEigenLhsOrder>::
+ template StridedConstMapType<Eigen::OuterStride<Eigen::Dynamic>>::type;
+ using EigenRhsType = typename Eigen::Matrix<RhsScalar, Eigen::Dynamic,
+ Eigen::Dynamic, kEigenRhsOrder>::
+ template StridedConstMapType<Eigen::OuterStride<Eigen::Dynamic>>::type;
+ using EigenDstType = typename Eigen::Matrix<DstScalar, Eigen::Dynamic,
+ Eigen::Dynamic, kEigenDstOrder>::
+ template StridedMapType<Eigen::OuterStride<Eigen::Dynamic>>::type;
+ using EigenBiasType =
+ typename Eigen::Matrix<DstScalar, Eigen::Dynamic, 1>::ConstMapType;
+
+ EigenLhsType eigen_lhs(lhs.data.get(), lhs.layout.rows, lhs.layout.cols,
+ Eigen::OuterStride<Eigen::Dynamic>(lhs.layout.stride));
+ EigenRhsType eigen_rhs(rhs.data.get(), rhs.layout.rows, rhs.layout.cols,
+ Eigen::OuterStride<Eigen::Dynamic>(rhs.layout.stride));
+ EigenDstType eigen_dst(
+ dst->data.get(), dst->layout.rows, dst->layout.cols,
+ Eigen::OuterStride<Eigen::Dynamic>(dst->layout.stride));
+ Eigen::setNbThreads(max_num_threads ? max_num_threads : 1);
+
+ if (spec.bias) {
+ EigenBiasType eigen_bias(spec.bias, dst->layout.rows);
+ if (spec.clamp_max == std::numeric_limits<DstScalar>::infinity() &&
+ spec.clamp_min == -std::numeric_limits<DstScalar>::infinity()) {
+ eigen_dst.noalias() = (eigen_lhs * eigen_rhs).colwise() + eigen_bias;
+ } else {
+ eigen_dst.noalias() = ((eigen_lhs * eigen_rhs).colwise() + eigen_bias)
+ .cwiseMin(spec.clamp_max)
+ .cwiseMax(spec.clamp_min);
+ }
+ } else {
+ if (spec.clamp_max == std::numeric_limits<DstScalar>::infinity() &&
+ spec.clamp_min == -std::numeric_limits<DstScalar>::infinity()) {
+ eigen_dst.noalias() = eigen_lhs * eigen_rhs;
+ } else {
+ eigen_dst.noalias() = (eigen_lhs * eigen_rhs)
+ .cwiseMin(spec.clamp_max)
+ .cwiseMax(spec.clamp_min);
+ }
+ }
+}
+
+template <typename LhsScalar, typename RhsScalar, typename DstScalar,
+ typename Spec>
+void EvalEigen(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+ const Spec& spec, int max_num_threads, Matrix<DstScalar>* dst) {
+ int index = Mash(lhs.layout.order, rhs.layout.order, dst->layout.order);
+ switch (index) {
+#define EVALEIGEN_CASE3(LHS, RHS, DST) \
+ case Mash(LHS, RHS, DST): \
+ return EvalEigen<LHS, RHS, DST>(lhs, rhs, spec, max_num_threads, dst);
+#define EVALEIGEN_CASE2(LHS, RHS) \
+ EVALEIGEN_CASE3(LHS, RHS, Order::kColMajor) \
+ EVALEIGEN_CASE3(LHS, RHS, Order::kRowMajor)
+#define EVALEIGEN_CASE1(LHS) \
+ EVALEIGEN_CASE2(LHS, Order::kColMajor) \
+ EVALEIGEN_CASE2(LHS, Order::kRowMajor)
+
+ EVALEIGEN_CASE1(Order::kColMajor)
+ EVALEIGEN_CASE1(Order::kRowMajor)
+
+#undef EVALEIGEN_CASE1
+#undef EVALEIGEN_CASE2
+#undef EVALEIGEN_CASE3
+
+ default:
+ RUY_CHECK(false);
+ }
+}
+
+template <Order LhsOrder, Order RhsOrder, Order DstOrder, typename Scalar,
+ typename Spec>
+void EvalEigenTensor(const Matrix<Scalar>& lhs, const Matrix<Scalar>& rhs,
+ const Spec& spec, int max_num_threads,
+ Matrix<Scalar>* dst) {
+ RUY_CHECK_EQ(lhs.zero_point, 0);
+ RUY_CHECK_EQ(rhs.zero_point, 0);
+ RUY_CHECK_EQ(dst->zero_point, 0);
+ RUY_CHECK_EQ(spec.multiplier_fixedpoint, 0);
+ RUY_CHECK_EQ(spec.multiplier_exponent, 0);
+
+ // Eigen::TensorMap only supports packed layouts
+ RUY_CHECK(IsPacked(lhs.layout));
+ RUY_CHECK(IsPacked(rhs.layout));
+ RUY_CHECK(IsPacked(dst->layout));
+
+ using TensorLhsType =
+ Eigen::TensorMap<Eigen::Tensor<const Scalar, 2, Eigen::ColMajor>>;
+ using TensorRhsType =
+ Eigen::TensorMap<Eigen::Tensor<const Scalar, 2, Eigen::ColMajor>>;
+ using TensorDstType =
+ Eigen::TensorMap<Eigen::Tensor<Scalar, 2, Eigen::ColMajor>>;
+ using TensorBiasType =
+ Eigen::TensorMap<Eigen::Tensor<const Scalar, 1, Eigen::ColMajor>>;
+
+ const bool tr = DstOrder == Order::kRowMajor;
+ const auto& contract_lhs = tr ? rhs : lhs;
+ const auto& contract_rhs = tr ? lhs : rhs;
+
+ TensorLhsType tensor_lhs(
+ contract_lhs.data.get(),
+ LhsOrder == Order::kColMajor ? contract_lhs.layout.rows
+ : contract_lhs.layout.cols,
+ LhsOrder == Order::kColMajor ? contract_lhs.layout.cols
+ : contract_lhs.layout.rows);
+ TensorRhsType tensor_rhs(
+ contract_rhs.data.get(),
+ RhsOrder == Order::kColMajor ? contract_rhs.layout.rows
+ : contract_rhs.layout.cols,
+ RhsOrder == Order::kColMajor ? contract_rhs.layout.cols
+ : contract_rhs.layout.rows);
+ TensorDstType tensor_dst(
+ dst->data.get(),
+ DstOrder == Order::kColMajor ? dst->layout.rows : dst->layout.cols,
+ DstOrder == Order::kColMajor ? dst->layout.cols : dst->layout.rows);
+ using DimPair =
+ typename Eigen::Tensor<Scalar, 1, 0, Eigen::Index>::DimensionPair;
+ Eigen::array<DimPair, 1> contract_dims(
+ {DimPair((LhsOrder == Order::kColMajor) ? 1 : 0,
+ (RhsOrder == Order::kColMajor) ? 0 : 1)});
+ Eigen::array<int, 2> shuffle(DstOrder == Order::kColMajor ? 0 : 1,
+ DstOrder == Order::kColMajor ? 1 : 0);
+ static Eigen::ThreadPool pool(max_num_threads ? max_num_threads : 1);
+ static Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
+ if (spec.bias) {
+ TensorBiasType tensor_bias(spec.bias, dst->layout.rows);
+ Eigen::array<int, 2> bias_2d_shape(tr ? 1 : dst->layout.rows,
+ tr ? dst->layout.rows : 1);
+ Eigen::array<int, 2> bcast(tr ? dst->layout.cols : 1,
+ tr ? 1 : dst->layout.cols);
+ if (spec.clamp_max == std::numeric_limits<Scalar>::infinity() &&
+ spec.clamp_min == -std::numeric_limits<Scalar>::infinity()) {
+      tensor_dst.device(device) =
+          tensor_lhs.contract(tensor_rhs, contract_dims) +
+          tensor_bias.reshape(bias_2d_shape).broadcast(bcast);
+ } else {
+ tensor_dst.device(device) =
+ (tensor_lhs.contract(tensor_rhs, contract_dims) +
+ tensor_bias.reshape(bias_2d_shape).broadcast(bcast))
+ .cwiseMin(spec.clamp_max)
+ .cwiseMax(spec.clamp_min);
+ }
+ } else {
+ if (spec.clamp_max == std::numeric_limits<Scalar>::infinity() &&
+ spec.clamp_min == -std::numeric_limits<Scalar>::infinity()) {
+ tensor_dst.device(device) =
+ tensor_lhs.contract(tensor_rhs, contract_dims);
+ } else {
+ tensor_dst.device(device) = tensor_lhs.contract(tensor_rhs, contract_dims)
+ .cwiseMin(spec.clamp_max)
+ .cwiseMax(spec.clamp_min);
+ }
+ }
+}
+
+template <typename Scalar, typename Spec>
+void EvalEigenTensor(const Matrix<Scalar>& lhs, const Matrix<Scalar>& rhs,
+ const Spec& spec, int max_num_threads,
+ Matrix<Scalar>* dst) {
+ int index = Mash(lhs.layout.order, rhs.layout.order, dst->layout.order);
+ switch (index) {
+#define EVALEIGENTENSOR_CASE3(LHS, RHS, DST) \
+ case Mash(LHS, RHS, DST): \
+ return EvalEigenTensor<LHS, RHS, DST>(lhs, rhs, spec, max_num_threads, dst);
+#define EVALEIGENTENSOR_CASE2(LHS, RHS) \
+ EVALEIGENTENSOR_CASE3(LHS, RHS, Order::kColMajor) \
+ EVALEIGENTENSOR_CASE3(LHS, RHS, Order::kRowMajor)
+#define EVALEIGENTENSOR_CASE1(LHS) \
+ EVALEIGENTENSOR_CASE2(LHS, Order::kColMajor) \
+ EVALEIGENTENSOR_CASE2(LHS, Order::kRowMajor)
+
+ EVALEIGENTENSOR_CASE1(Order::kColMajor)
+ EVALEIGENTENSOR_CASE1(Order::kRowMajor)
+
+#undef EVALEIGENTENSOR_CASE1
+#undef EVALEIGENTENSOR_CASE2
+#undef EVALEIGENTENSOR_CASE3
+
+ default:
+ RUY_CHECK(false);
+ }
+}
+
+template <typename Scalar>
+struct GenericBlasGemm {};
+
+template <>
+struct GenericBlasGemm<lapack::doublereal> {
+ static void Run(char* transa, char* transb, lapack::integer* m,
+ lapack::integer* n, lapack::integer* k,
+ lapack::doublereal* alpha, lapack::doublereal* a,
+ lapack::integer* lda, lapack::doublereal* b,
+ lapack::integer* ldb, lapack::doublereal* beta,
+ lapack::doublereal* c, lapack::integer* ldc) {
+ dgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
+ }
+};
+
+template <>
+struct GenericBlasGemm<lapack::real> {
+ static void Run(char* transa, char* transb, lapack::integer* m,
+ lapack::integer* n, lapack::integer* k, lapack::real* alpha,
+ lapack::real* a, lapack::integer* lda, lapack::real* b,
+ lapack::integer* ldb, lapack::real* beta, lapack::real* c,
+ lapack::integer* ldc) {
+ sgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
+ }
+};
+
+template <typename Scalar, typename Spec>
+void EvalOpenBlas(const Matrix<Scalar>& lhs, const Matrix<Scalar>& rhs,
+ const Spec& spec, int max_num_threads, Matrix<Scalar>* dst) {
+ RUY_CHECK_EQ(lhs.zero_point, 0);
+ RUY_CHECK_EQ(rhs.zero_point, 0);
+ RUY_CHECK_EQ(dst->zero_point, 0);
+ RUY_CHECK_EQ(spec.multiplier_fixedpoint, 0);
+ RUY_CHECK_EQ(spec.multiplier_exponent, 0);
+
+ Matrix<Scalar> gemm_lhs;
+ Matrix<Scalar> gemm_rhs;
+ Matrix<Scalar> gemm_dst;
+ gemm_dst = *dst;
+
+  // Use Transpose to reduce to the all-column-major case, relying on the
+  // identity (A*B)^T = B^T * A^T when the destination is row-major.
+  // Notice that ruy::Matrix merely holds a pointer and does not own data,
+  // so Transpose is cheap -- no actual matrix data is being transposed here.
+ if (dst->layout.order == Order::kColMajor) {
+ gemm_lhs = lhs;
+ gemm_rhs = rhs;
+ } else {
+ gemm_lhs = rhs;
+ gemm_rhs = lhs;
+ Transpose(&gemm_lhs);
+ Transpose(&gemm_rhs);
+ Transpose(&gemm_dst);
+ }
+ bool transposed_lhs = false;
+ bool transposed_rhs = false;
+
+ if (gemm_lhs.layout.order == Order::kRowMajor) {
+ Transpose(&gemm_lhs);
+ transposed_lhs = true;
+ }
+ if (gemm_rhs.layout.order == Order::kRowMajor) {
+ Transpose(&gemm_rhs);
+ transposed_rhs = true;
+ }
+
+ RUY_CHECK_EQ(gemm_lhs.layout.order, Order::kColMajor);
+ RUY_CHECK_EQ(gemm_rhs.layout.order, Order::kColMajor);
+ RUY_CHECK_EQ(gemm_dst.layout.order, Order::kColMajor);
+
+ char transa = transposed_lhs ? 'T' : 'N';
+ char transb = transposed_rhs ? 'T' : 'N';
+ int m = gemm_lhs.layout.rows;
+ int n = gemm_rhs.layout.cols;
+ int k = gemm_lhs.layout.cols;
+ float alpha = 1;
+ Scalar* a = gemm_lhs.data.get();
+ int lda = gemm_lhs.layout.stride;
+ Scalar* b = gemm_rhs.data.get();
+ int ldb = gemm_rhs.layout.stride;
+ float beta = 0;
+ Scalar* c = gemm_dst.data.get();
+ int ldc = gemm_dst.layout.stride;
+ GenericBlasGemm<Scalar>::Run(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b,
+ &ldb, &beta, c, &ldc);
+
+ // BLAS does not allow us to express the bias-addition and clamping, so
+ // we use Eigen for that.
+
+ using EigenDstType =
+ typename Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>::
+ template StridedMapType<Eigen::OuterStride<Eigen::Dynamic>>::type;
+ using EigenBiasType =
+ typename Eigen::Matrix<Scalar, Eigen::Dynamic, 1>::ConstMapType;
+
+ EigenDstType eigen_dst(
+ gemm_dst.data.get(), gemm_dst.layout.rows, gemm_dst.layout.cols,
+ Eigen::OuterStride<Eigen::Dynamic>(gemm_dst.layout.stride));
+ Eigen::setNbThreads(max_num_threads ? max_num_threads : 1);
+
+ if (spec.bias) {
+ EigenBiasType eigen_bias(spec.bias, dst->layout.rows);
+ if (spec.clamp_max == std::numeric_limits<Scalar>::infinity() &&
+ spec.clamp_min == -std::numeric_limits<Scalar>::infinity()) {
+ eigen_dst.noalias() = eigen_dst.colwise() + eigen_bias;
+ } else {
+ eigen_dst.noalias() = (eigen_dst.colwise() + eigen_bias)
+ .cwiseMin(spec.clamp_max)
+ .cwiseMax(spec.clamp_min);
+ }
+ } else {
+ if (spec.clamp_max == std::numeric_limits<Scalar>::infinity() &&
+ spec.clamp_min == -std::numeric_limits<Scalar>::infinity()) {
+ } else {
+ eigen_dst.noalias() =
+ eigen_dst.cwiseMin(spec.clamp_max).cwiseMax(spec.clamp_min);
+ }
+ }
+}
+
+template <typename TestSetType>
+struct SupportsGemmlowp {
+ static constexpr bool kValue =
+ std::is_same<typename TestSetType::LhsScalar, std::uint8_t>::value &&
+ std::is_same<typename TestSetType::RhsScalar, std::uint8_t>::value;
+};
+
+template <typename TestSetType>
+struct UsesSingleScalarType {
+ static constexpr bool kValue =
+ std::is_same<typename TestSetType::DstScalar,
+ typename TestSetType::LhsScalar>::value &&
+ std::is_same<typename TestSetType::DstScalar,
+ typename TestSetType::RhsScalar>::value &&
+ std::is_same<typename TestSetType::DstScalar,
+ typename TestSetType::AccumScalar>::value;
+};
+
+template <typename TestSetType,
+ bool IsFloatingPoint =
+ std::is_floating_point<typename TestSetType::AccumScalar>::value,
+ bool EnableGemmlowp = SupportsGemmlowp<TestSetType>::kValue,
+ bool SingleScalarType = UsesSingleScalarType<TestSetType>::kValue>
+struct EvalExternalPathImpl {
+ using DstScalar = typename TestSetType::DstScalar;
+ static void Run(TestSetType*, TestResult<DstScalar>*) { RUY_CHECK(false); }
+};
+
+template <typename TestSetType>
+struct EvalExternalPathImpl<TestSetType, true, false, true> {
+ using DstScalar = typename TestSetType::DstScalar;
+ static void Run(TestSetType* test_set, TestResult<DstScalar>* test_result) {
+ if (test_result->external_path == ExternalPath::kEigen) {
+ EvalEigen(test_set->lhs.matrix, test_set->rhs.matrix, test_set->spec,
+ test_set->max_num_threads, &test_result->storage_matrix.matrix);
+ } else if (test_result->external_path == ExternalPath::kEigenTensor) {
+ EvalEigenTensor(test_set->lhs.matrix, test_set->rhs.matrix,
+ test_set->spec, test_set->max_num_threads,
+ &test_result->storage_matrix.matrix);
+ } else if (test_result->external_path == ExternalPath::kOpenBlas) {
+ EvalOpenBlas(test_set->lhs.matrix, test_set->rhs.matrix, test_set->spec,
+ test_set->max_num_threads,
+ &test_result->storage_matrix.matrix);
+ } else {
+ RUY_CHECK(false);
+ }
+ }
+};
+
+template <typename TestSetType, bool SingleScalarType>
+struct EvalExternalPathImpl<TestSetType, false, true, SingleScalarType> {
+ using DstScalar = typename TestSetType::DstScalar;
+ static void Run(TestSetType* test_set, TestResult<DstScalar>* test_result) {
+ if (test_result->external_path == ExternalPath::kGemmlowp) {
+ EvalGemmlowp(test_set->lhs.matrix, test_set->rhs.matrix, test_set->spec,
+ test_set->max_num_threads,
+ &test_result->storage_matrix.matrix);
+ } else {
+ RUY_CHECK(false);
+ }
+ }
+};
+
+template <typename TestSetType>
+void EvalExternalPath(
+ TestSetType* test_set,
+ TestResult<typename TestSetType::DstScalar>* test_result) {
+ EvalExternalPathImpl<TestSetType>::Run(test_set, test_result);
+}
+
+#endif // RUY_TEST_EXTERNAL_PATHS
+
+template <typename Scalar>
+bool Agree(const Matrix<Scalar>& matrix1, const Matrix<Scalar>& matrix2,
+ int depth) {
+ RUY_CHECK_EQ(matrix1.layout.rows, matrix2.layout.rows);
+ RUY_CHECK_EQ(matrix1.layout.cols, matrix2.layout.cols);
+ RUY_CHECK_EQ(matrix1.zero_point, matrix2.zero_point);
+ const int size = matrix1.layout.rows * matrix1.layout.cols;
+ double tolerated_max_diff = 0;
+ double tolerated_mean_diff = 0;
+ if (std::is_floating_point<Scalar>::value) {
+    // The sqrt(depth) factor follows from the central limit theorem: rounding
+    // errors accumulate roughly like the square root of the number of
+    // accumulation terms.
+    // TODO: replace the hardcoded 64 factor below by something more
+    // principled.
+ double max_abs_val = 0;
+ for (int row = 0; row < matrix1.layout.rows; row++) {
+ for (int col = 0; col < matrix1.layout.cols; col++) {
+ max_abs_val =
+ std::max(max_abs_val,
+ std::abs(static_cast<double>(Element(matrix1, row, col))));
+ max_abs_val =
+ std::max(max_abs_val,
+ std::abs(static_cast<double>(Element(matrix2, row, col))));
+ }
+ }
+ tolerated_max_diff = max_abs_val * std::numeric_limits<Scalar>::epsilon() *
+ 64 * std::sqrt(static_cast<float>(depth));
+ tolerated_mean_diff = tolerated_max_diff / std::sqrt(size);
+ } else if (RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)) {
+ tolerated_max_diff = 1;
+ // totally empirical
+ tolerated_mean_diff = std::min(1.0, 2.0 * std::pow(size, -0.2));
+ }
+ double sum_diff = 0;
+ for (int row = 0; row < matrix1.layout.rows; row++) {
+ for (int col = 0; col < matrix1.layout.cols; col++) {
+ double elem1 = Element(matrix1, row, col);
+ double elem2 = Element(matrix2, row, col);
+ double diff = elem1 - elem2;
+
+ sum_diff += diff;
+ // Test (std::abs(diff) > tolerated_max_diff), but also true if diff is
+ // NaN.
+ if (!(std::abs(diff) <= tolerated_max_diff)) {
+ return false;
+ }
+ }
+ }
+ double mean_diff = sum_diff / size;
+ if (std::abs(mean_diff) > tolerated_mean_diff) {
+ return false;
+ }
+ return true;
+}
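+
+// Illustrative example of the float tolerance above: with float epsilon
+// ~= 1.19e-7 and depth = 1024, tolerated_max_diff is about
+// 64 * 32 * 1.19e-7 ~= 2.4e-4 times the largest absolute value found in
+// either matrix, and tolerated_mean_diff further divides that by sqrt(size).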
+
+template <typename Scalar>
+bool Agree(const StorageMatrix<Scalar>& storage_matrix1,
+ const StorageMatrix<Scalar>& storage_matrix2, int depth) {
+ VerifyConsistentFields(storage_matrix1);
+ VerifyConsistentFields(storage_matrix2);
+ return Agree(storage_matrix1.matrix, storage_matrix2.matrix, depth);
+}
+
+template <typename Scalar>
+bool Agree(const TestResult<Scalar>& result1, const TestResult<Scalar>& result2,
+ int depth) {
+ return Agree(result1.storage_matrix, result2.storage_matrix, depth);
+}
+
+struct Stats {
+ double median;
+ double mean;
+ double min;
+ double max;
+};
+
+inline std::string StatsAsString(const Stats& stats) {
+ char buf[256];
+ snprintf(buf, sizeof(buf), "(median = %g, mean = %g, min = %g, max = %g)",
+ stats.median, stats.mean, stats.min, stats.max);
+ return std::string(buf);
+}
+
+template <typename Scalar>
+void GetMatrixStats(const Matrix<Scalar>& matrix, Stats* stats) {
+ double min = std::numeric_limits<double>::infinity();
+ double max = -std::numeric_limits<double>::infinity();
+ double sum = 0;
+ std::vector<double> allvals;
+ for (int row = 0; row < matrix.layout.rows; row++) {
+ for (int col = 0; col < matrix.layout.cols; col++) {
+ double val = Element(matrix, row, col);
+ min = std::min(min, val);
+ max = std::max(max, val);
+ sum += val;
+ allvals.push_back(val);
+ }
+ }
+ std::sort(allvals.begin(), allvals.end());
+ stats->min = min;
+ stats->max = max;
+ stats->mean = sum / allvals.size();
+ stats->median = allvals[allvals.size() / 2];
+}
+
+struct ErrorAnalysis {
+ Stats stats_good;
+ Stats stats_bad;
+ // The below is to help document departure from bit exactness. It's probably
+ // not going to be relevant to floating-point.
+ std::set<int> error_rows;
+ std::set<int> error_cols;
+ int row_of_first_error = 0;
+ int col_of_first_error = 0;
+ double first_error_good_value = 0;
+ double first_error_bad_value = 0;
+};
+
+template <typename TestSetType>
+void AnalyzeTestError(const TestSetType& test_set, int first_bad_result_index,
+ ErrorAnalysis* error_analysis) {
+ const auto& good_matrix = test_set.results[0]->storage_matrix.matrix;
+ const auto& bad_matrix =
+ test_set.results[first_bad_result_index]->storage_matrix.matrix;
+ GetMatrixStats(good_matrix, &error_analysis->stats_good);
+ GetMatrixStats(bad_matrix, &error_analysis->stats_bad);
+ bool found_first_error = false;
+ for (int row = 0; row < good_matrix.layout.rows; row++) {
+ for (int col = 0; col < good_matrix.layout.cols; col++) {
+ if (Element(good_matrix, row, col) != Element(bad_matrix, row, col)) {
+ if (!found_first_error) {
+ found_first_error = true;
+ error_analysis->row_of_first_error = row;
+ error_analysis->col_of_first_error = col;
+ error_analysis->first_error_good_value =
+ Element(good_matrix, row, col);
+ error_analysis->first_error_bad_value = Element(bad_matrix, row, col);
+ }
+ error_analysis->error_rows.insert(row);
+ error_analysis->error_cols.insert(col);
+ }
+ }
+ }
+}
+
+template <typename TestSetType>
+void ComputeReasonableMultiplier(
+ const Matrix<typename TestSetType::LhsScalar>& lhs,
+ const Matrix<typename TestSetType::RhsScalar>& rhs, double* multiplier) {
+ using LhsScalar = typename TestSetType::LhsScalar;
+ using RhsScalar = typename TestSetType::RhsScalar;
+ using DstScalar = typename TestSetType::DstScalar;
+ if (std::is_floating_point<DstScalar>::value ||
+ std::is_same<DstScalar, std::int32_t>::value) {
+ *multiplier = 0;
+ return;
+ }
+ *multiplier = static_cast<double>(std::numeric_limits<DstScalar>::max()) /
+ (static_cast<double>(lhs.layout.cols) *
+ std::numeric_limits<LhsScalar>::max() *
+ std::numeric_limits<RhsScalar>::max());
+}
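+
+// Worked example (illustrative): for std::uint8_t LHS/RHS/destination and a
+// depth (lhs.layout.cols) of 1000, the multiplier is
+// 255 / (1000 * 255 * 255) ~= 3.9e-6, small enough that rescaled accumulators
+// stay within the destination range.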
+
+inline void QuantizeMultiplier(double multiplier_double,
+ std::int32_t* multiplier_fixedpoint,
+ int* multiplier_exponent) {
+ RUY_CHECK_GT(multiplier_double, 0);
+ if (multiplier_double == 0.) {
+ *multiplier_fixedpoint = 0;
+ *multiplier_exponent = 0;
+ return;
+ }
+ const double q = std::frexp(multiplier_double, multiplier_exponent);
+ auto q_fixed = static_cast<std::int64_t>(std::round(q * (1ll << 31)));
+ RUY_CHECK_LE(q_fixed, (1ll << 31));
+ if (q_fixed == (1ll << 31)) {
+ q_fixed /= 2;
+ ++*multiplier_exponent;
+ }
+ RUY_CHECK_LE(q_fixed, std::numeric_limits<std::int32_t>::max());
+ *multiplier_fixedpoint = static_cast<std::int32_t>(q_fixed);
+}
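+
+// Worked example (illustrative): multiplier_double = 0.75 gives q = 0.75 and
+// *multiplier_exponent = 0 from std::frexp, hence *multiplier_fixedpoint =
+// round(0.75 * 2^31) = 1610612736. A small multiplier such as 3.9e-6 instead
+// yields a mantissa in [0.5, 1) with *multiplier_exponent = -17.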
+
+template <typename TestSetType>
+void SwitchMultiplierToPerChannel(TestSetType* test_set) {
+ test_set->per_channel_multiplier_fixedpoint.resize(test_set->rows);
+ test_set->per_channel_multiplier_exponent.resize(test_set->rows);
+ for (int i = 0; i < test_set->rows; i++) {
+ // multipliers typically range in [2^30 ; 2^31 - 1].
+ // Values in [0, 2^30 - 1] are normally unused, but harmless.
+ // Thus a good way to randomize multipliers is to subtract from them
+ // a random value smaller than 2^30 but still significant compared to it.
+ std::int32_t nudged_multiplier = test_set->spec.multiplier_fixedpoint -
+ (global_random_engine()() % (1 << 26));
+ int nudged_exponent =
+ test_set->spec.multiplier_exponent - 1 + (global_random_engine()() % 4);
+ test_set->per_channel_multiplier_fixedpoint[i] = nudged_multiplier;
+ test_set->per_channel_multiplier_exponent[i] = nudged_exponent;
+ }
+ test_set->spec.multiplier_fixedpoint_perchannel =
+ test_set->per_channel_multiplier_fixedpoint.data();
+ test_set->spec.multiplier_exponent_perchannel =
+ test_set->per_channel_multiplier_exponent.data();
+ test_set->spec.multiplier_fixedpoint = 0;
+ test_set->spec.multiplier_exponent = 0;
+}
+
+template <
+ typename TestSetType,
+ bool IsApplicable =
+ std::is_same<typename TestSetType::AccumScalar, std::int32_t>::value &&
+ !std::is_same<typename TestSetType::DstScalar, std::int32_t>::value>
+struct MakeSpecMultiplierFieldsImpl {};
+
+template <typename TestSetType>
+struct MakeSpecMultiplierFieldsImpl<TestSetType, true> {
+ static void Run(TestSetType* test_set) {
+ double multiplier;
+ ComputeReasonableMultiplier<TestSetType>(test_set->lhs.matrix,
+ test_set->rhs.matrix, &multiplier);
+ QuantizeMultiplier(multiplier, &test_set->spec.multiplier_fixedpoint,
+ &test_set->spec.multiplier_exponent);
+ if (!test_set->benchmark) {
+ test_set->perchannel = global_random_engine()() & 1;
+ }
+ if (test_set->perchannel) {
+ SwitchMultiplierToPerChannel(test_set);
+ }
+ }
+};
+
+template <typename TestSetType>
+struct MakeSpecMultiplierFieldsImpl<TestSetType, false> {
+ static void Run(TestSetType* test_set) {
+ test_set->spec.multiplier_fixedpoint = 0;
+ test_set->spec.multiplier_exponent = 0;
+ }
+};
+
+template <typename Spec>
+void MakeSpecClampFields(Spec* spec) {
+ using AccumScalar = typename Spec::AccumScalar;
+ using DstScalar = typename Spec::DstScalar;
+
+ if (std::is_same<AccumScalar, std::int32_t>::value) {
+ // Returning raw accumulators, clamping is not supported.
+ spec->clamp_min = std::numeric_limits<DstScalar>::lowest();
+ spec->clamp_max = std::numeric_limits<DstScalar>::max();
+ return;
+ }
+
+ if (getenv("BENCHMARK_ONLY_MATMUL")) {
+ if (std::is_floating_point<DstScalar>::value) {
+ spec->clamp_min = -std::numeric_limits<DstScalar>::infinity();
+ spec->clamp_max = std::numeric_limits<DstScalar>::infinity();
+ } else {
+ spec->clamp_min = std::numeric_limits<DstScalar>::lowest();
+ spec->clamp_max = std::numeric_limits<DstScalar>::max();
+ }
+ return;
+ }
+
+ spec->clamp_min = std::numeric_limits<DstScalar>::lowest() + 1;
+ spec->clamp_max = std::numeric_limits<DstScalar>::max() - 1;
+}
+
+template <typename LhsScalar, typename RhsScalar, typename SpecType>
+void TestSet<LhsScalar, RhsScalar, SpecType>::MakeZeroPoints() {
+ RUY_CHECK_EQ(life_stage, LifeStage::kInitial);
+ if (!benchmark && !use_specified_zero_points) {
+ MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &lhs_zero_point);
+ MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &rhs_zero_point);
+ // If destination is std::int32_t, no dst_zero_point is necessary.
+ if (std::is_same<DstScalar, std::int32_t>::value) {
+ dst_zero_point = 0;
+ } else {
+ MakeRandomScalar(RandomRange::kReasonableDstZeroPoint, &dst_zero_point);
+ }
+ }
+ life_stage = LifeStage::kHasZeroPoints;
+}
+
+template <typename LhsScalar, typename RhsScalar, typename SpecType>
+void TestSet<LhsScalar, RhsScalar, SpecType>::MakeLhsRhs() {
+ RUY_CHECK_EQ(life_stage, LifeStage::kHasZeroPoints);
+ MakeRandom(rows, depth, lhs_order, lhs_zero_point, layout_style,
+ RandomRange::kOffCenterAvoidMinValue, &lhs);
+ MakeRandom(depth, cols, rhs_order, rhs_zero_point, layout_style,
+ RandomRange::kGeneral, &rhs);
+ life_stage = LifeStage::kHasLhsRhs;
+}
+
+template <typename LhsScalar, typename RhsScalar, typename SpecType>
+void TestSet<LhsScalar, RhsScalar, SpecType>::MakeSpec() {
+ RUY_CHECK_EQ(life_stage, LifeStage::kHasLhsRhs);
+
+ if (!getenv("BENCHMARK_ONLY_MATMUL") &&
+ (benchmark || (global_random_engine()() & 1))) {
+ MakeRandomVector(RandomRange::kBias, rows, &bias_data);
+ spec.bias = bias_data.data();
+ }
+ if (lhs.matrix.zero_point == std::numeric_limits<LhsScalar>::lowest() &&
+ rhs.matrix.zero_point == std::numeric_limits<RhsScalar>::lowest()) {
+ lhs.matrix.zero_point += 1;
+ }
+ MakeSpecMultiplierFieldsImpl<TestSet>::Run(this);
+ MakeSpecClampFields(&spec);
+ life_stage = LifeStage::kHasSpec;
+}
+
+inline int GetIntEnvVarOrZero(const char* name) {
+ const char* val = getenv(name);
+ if (!val) {
+ return 0;
+ }
+ return std::stoi(val);
+}
+
+inline float GetFloatEnvVarOrZero(const char* name) {
+ const char* val = getenv(name);
+ if (!val) {
+ return 0;
+ }
+ return std::stof(val);
+}
+
+inline int GetHexIntEnvVarOrZero(const char* name) {
+ const char* val = getenv(name);
+ if (!val) {
+ return 0;
+ }
+ return std::stoi(val, nullptr, 16);
+}
+
+inline bool GetBoolEnvVarOrFalse(const char* name) {
+ return static_cast<bool>(GetIntEnvVarOrZero(name));
+}
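+
+// Usage sketch (illustrative; the binary name is hypothetical): the knobs read
+// by the code below are plain environment variables, e.g.
+//   PATHS=ffff THREADS=4 RUY_BENCHMARK_REPEATS=8 ./ruy_benchmark
+// PATHS is parsed as hex (GetHexIntEnvVarOrZero), THREADS as a decimal int,
+// and boolean knobs such as NOEXT as 0/1.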
+
+template <typename LhsScalar, typename RhsScalar, typename SpecType>
+void TestSet<LhsScalar, RhsScalar, SpecType>::MakeOtherParams() {
+ RUY_CHECK_EQ(life_stage, LifeStage::kHasSpec);
+ if (max_num_threads == 0) {
+ max_num_threads = GetIntEnvVarOrZero("THREADS");
+ }
+ life_stage = LifeStage::kHasOtherParams;
+}
+
+inline std::vector<Path> PathsBitfieldAsVector(Path paths_bitfield) {
+ std::vector<Path> result;
+ std::uint32_t remaining_paths = static_cast<std::uint32_t>(paths_bitfield);
+ std::uint32_t test_bit = 1;
+ while (remaining_paths) {
+ if (remaining_paths & test_bit) {
+ result.push_back(static_cast<Path>(test_bit));
+ }
+ remaining_paths &= ~test_bit;
+ test_bit <<= 1;
+ }
+ return result;
+}
+
+inline std::vector<Tuning> EnumerateTuningsForPath(Path path, bool benchmark) {
+ if (benchmark) {
+ return {Tuning::kAuto};
+ }
+#if RUY_PLATFORM(ARM)
+ if (path == Path::kNeon || path == Path::kNeonDotprod) {
+ return {Tuning::kInOrder, Tuning::kOutOfOrder, Tuning::kAuto};
+ }
+#endif
+ return {Tuning::kAuto};
+}
+
+template <typename LhsScalar, typename RhsScalar, typename SpecType>
+void TestSet<LhsScalar, RhsScalar, SpecType>::MakePrepackedMatrices() {
+ RUY_CHECK_EQ(life_stage, LifeStage::kHasResultPaths);
+
+ // Prepacked matrices are Path-dependent, so create them for each test result.
+ for (auto& result : results) {
+ // If this result uses an external path, then skip this entirely.
+ if (result->path == Path::kNone) {
+ continue;
+ }
+ // Pre-packing doesn't make sense for Path::kReference.
+ // TODO(silvasean): Make Path::kReference an ExternalPath?
+ if (result->path == Path::kReference) {
+ continue;
+ }
+
+ // Determine whether we should create/use prepacked matrices.
+ if (benchmark) {
+ // For benchmarking, do as requested.
+ result->use_prepacked_lhs = benchmark_prepack_lhs;
+ result->use_prepacked_rhs = benchmark_prepack_rhs;
+ } else {
+ // When testing, randomly pre-pack sometimes. But don't do it too often.
+ result->use_prepacked_lhs = (global_random_engine()() & 7) == 0;
+ result->use_prepacked_rhs = (global_random_engine()() & 7) == 0;
+ }
+
+ // Create the pre-packed matrices.
+ PrepackedMatrix* prepacked_lhs_ptr =
+ result->use_prepacked_lhs ? &result->prepacked_lhs : nullptr;
+ PrepackedMatrix* prepacked_rhs_ptr =
+ result->use_prepacked_rhs ? &result->prepacked_rhs : nullptr;
+ auto alloc_fn = [&result](std::size_t num_bytes) {
+ return result->allocator.AllocateBytes(num_bytes);
+ };
+ // Use a dst with a null data pointer to check that the pre-packing
+ // invocation doesn't write into it.
+ Matrix<DstScalar> null_data_dst = result->storage_matrix.matrix;
+ null_data_dst.data = nullptr;
+ GlobalContext().SetRuntimeEnabledPaths(result->path);
+ PrePackForMul<kAllPaths>(lhs.matrix, rhs.matrix, spec, &GlobalContext(),
+ &null_data_dst, prepacked_lhs_ptr,
+ prepacked_rhs_ptr, alloc_fn);
+ RUY_CHECK_EQ(GlobalContext().last_taken_path, result->path);
+ }
+
+ life_stage = LifeStage::kHasPrepackedMatrices;
+}
+
+template <typename LhsScalar, typename RhsScalar, typename SpecType>
+void TestSet<LhsScalar, RhsScalar, SpecType>::MakeResultPaths() {
+ RUY_CHECK_EQ(life_stage, LifeStage::kHasOtherParams);
+
+ Path paths_bitfield = static_cast<Path>(GetHexIntEnvVarOrZero("PATHS"));
+
+ if (paths_bitfield == Path::kNone) {
+ // Use a dummy Context just to perform the resolution of specific runtime
+ // enabled paths.
+ Context context;
+ paths_bitfield = context.GetRuntimeEnabledPaths();
+ }
+
+ // Trim bits that don't correspond to a compiled path,
+ // to allow specifying e.g. ffff to mean 'all paths' regardless of whether all
+ // those bits exist as actual paths.
+ paths_bitfield = paths_bitfield & kAllPaths;
+ RUY_CHECK_NE(paths_bitfield, Path::kNone);
+ paths = PathsBitfieldAsVector(paths_bitfield);
+
+#ifdef RUY_TEST_EXTERNAL_PATHS
+
+ using TestSetType = TestSet<LhsScalar, RhsScalar, SpecType>;
+
+ if (!GetBoolEnvVarOrFalse("NOEXT")) {
+ if (SupportsGemmlowp<TestSetType>::kValue) {
+#ifdef GEMMLOWP_SSE4
+ const bool gemmlowp_supported = !spec.multiplier_fixedpoint_perchannel;
+#else
+ const bool gemmlowp_supported = true;
+#endif
+ if (gemmlowp_supported) {
+ external_paths.push_back(ExternalPath::kGemmlowp);
+ }
+ }
+ if (UsesSingleScalarType<TestSetType>::kValue &&
+ std::is_floating_point<AccumScalar>::value) {
+ external_paths.push_back(ExternalPath::kEigen);
+ if (layout_style == LayoutStyle::kPackedLinear) {
+ external_paths.push_back(ExternalPath::kEigenTensor);
+ }
+// We link against a generic BLAS target that only maps to OpenBLAS on specific
+// architectures.
+#if RUY_PLATFORM(ARM_32) || RUY_PLATFORM(ARM_64)
+ // OpenBLAS multi-threading is disabled, so avoid mixing single-threaded
+ // and multi-threaded benchmark results.
+ if (max_num_threads == 1 && !getenv("NO_OPENBLAS")) {
+ external_paths.push_back(ExternalPath::kOpenBlas);
+ }
+#endif
+ }
+ }
+
+#endif // RUY_TEST_EXTERNAL_PATHS
+
+ for (Path path : paths) {
+ for (Tuning tuning : EnumerateTuningsForPath(path, benchmark)) {
+ results.emplace_back(new TestResultType);
+ TestResultType& result = *results.back();
+ result.path = path;
+ result.tuning = tuning;
+ MakeRandom(rows, cols, dst_order, dst_zero_point, layout_style,
+ RandomRange::kGeneral, &result.storage_matrix);
+ }
+ }
+
+ for (ExternalPath external_path : external_paths) {
+ results.emplace_back(new TestResultType);
+ TestResultType& result = *results.back();
+ result.external_path = external_path;
+ MakeRandom(rows, cols, dst_order, dst_zero_point, layout_style,
+ RandomRange::kGeneral, &result.storage_matrix);
+ }
+
+ life_stage = LifeStage::kHasResultPaths;
+}
+
+template <typename LhsScalar, typename RhsScalar, typename SpecType>
+void TestSet<LhsScalar, RhsScalar, SpecType>::EvalResult(
+ TestResult<typename SpecType::DstScalar>* result) {
+ RUY_CHECK(result->path != Path::kNone ||
+ result->external_path != ExternalPath::kNone);
+ if (result->path != Path::kNone) {
+ EvalRuy(result);
+ } else {
+#ifdef RUY_TEST_EXTERNAL_PATHS
+ using TestSetType = TestSet<LhsScalar, RhsScalar, SpecType>;
+ EvalExternalPath(this, result);
+#endif
+ }
+ const std::string& pathname = PathName(*result);
+ if (std::find(CoveredPaths()->begin(), CoveredPaths()->end(), pathname) ==
+ CoveredPaths()->end()) {
+ CoveredPaths()->push_back(pathname);
+ }
+}
+
+using f32 = float;
+using f64 = double;
+using u8 = std::uint8_t;
+using i8 = std::int8_t;
+using u16 = std::uint16_t;
+using i16 = std::int16_t;
+using u32 = std::uint32_t;
+using i32 = std::int32_t;
+using u64 = std::uint64_t;
+using i64 = std::int64_t;
+
+template <typename Scalar>
+const char* TypeName() {
+ return nullptr;
+}
+
+#define RUY_TYPENAME(TYPE) \
+ template <> \
+ const char* TypeName<TYPE>() { \
+ return #TYPE; \
+ }
+
+RUY_TYPENAME(f32)
+RUY_TYPENAME(f64)
+RUY_TYPENAME(u8)
+RUY_TYPENAME(i8)
+RUY_TYPENAME(u16)
+RUY_TYPENAME(i16)
+RUY_TYPENAME(u32)
+RUY_TYPENAME(i32)
+RUY_TYPENAME(u64)
+RUY_TYPENAME(i64)
+
+#undef RUY_TYPENAME
+
+template <typename Scalar>
+const char* SymmetryName(const Matrix<Scalar>& matrix) {
+ if (matrix.zero_point == SymmetricZeroPoint<Scalar>()) {
+ return "symm";
+ } else {
+ return "asymm";
+ }
+}
+
+template <typename Scalar>
+int StorageSize(const Matrix<Scalar>& matrix) {
+ return sizeof(Scalar) * FlatSize(matrix.layout);
+}
+
+// Helper that replicates a buffer and gives out pointers to the replicas.
+// This is useful when one wants to traverse data so that it is cold in cache.
+// By having a sufficiently large value of num_repeats, one can ensure that the
+// working set covered by the replicas is greater than the cache size.
+template <typename T>
+class RepeatedBuffer {
+ public:
+ RepeatedBuffer() = default;
+ void Init(const T* elems, std::size_t num_elems, int num_repeats) {
+ buffers_.clear();
+ allocator_.FreeAll();
+ for (int i = 0; i < num_repeats; i++) {
+ T* p;
+ allocator_.Allocate(num_elems, &p);
+ memcpy(p, elems, num_elems * sizeof(T));
+ buffers_.push_back(p);
+ }
+ }
+ T* Next() {
+ T* ret = buffers_[current_];
+ current_ = (current_ + 1) % buffers_.size();
+ return ret;
+ }
+
+ private:
+ Allocator allocator_;
+ std::vector<T*> buffers_;
+ int current_ = 0;
+};
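+
+// Usage note: the cold-cache benchmark below replicates each source buffer
+// across a working set larger than the caches (Init), then hands out a
+// different replica on every iteration (Next), so that each matmul reads
+// cache-cold data.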
+
+template <typename LhsScalar, typename RhsScalar, typename SpecType>
+void TestSet<LhsScalar, RhsScalar, SpecType>::Benchmark(
+ TestResult<typename SpecType::DstScalar>* result) {
+ using DstScalar = typename SpecType::DstScalar;
+
+ const bool cold = getenv("RUY_BENCHMARK_COLD");
+ LhsScalar* orig_lhs_data = lhs.matrix.data.get();
+ RhsScalar* orig_rhs_data = rhs.matrix.data.get();
+ DstScalar* orig_dst_data = result->storage_matrix.matrix.data.get();
+ void* orig_prepacked_lhs_data = result->prepacked_lhs.data;
+ void* orig_prepacked_rhs_data = result->prepacked_rhs.data;
+
+ int num_matmul_sets = 0;
+
+ RepeatedBuffer<LhsScalar> cold_lhs;
+ RepeatedBuffer<RhsScalar> cold_rhs;
+ RepeatedBuffer<DstScalar> cold_dst;
+ RepeatedBuffer<char> cold_prepacked_lhs;
+ RepeatedBuffer<char> cold_prepacked_rhs;
+
+ if (cold) {
+ const int kWorkingSetSize = 100 << 20;
+ const int each_matmul_set_size = StorageSize(lhs.matrix) +
+ StorageSize(rhs.matrix) +
+ StorageSize(result->storage_matrix.matrix);
+ num_matmul_sets =
+ (kWorkingSetSize + each_matmul_set_size - 1) / each_matmul_set_size;
+
+ cold_lhs.Init(lhs.matrix.data.get(), FlatSize(lhs.matrix.layout),
+ num_matmul_sets);
+ cold_rhs.Init(rhs.matrix.data.get(), FlatSize(rhs.matrix.layout),
+ num_matmul_sets);
+ cold_dst.Init(result->storage_matrix.matrix.data.get(),
+ FlatSize(result->storage_matrix.matrix.layout),
+ num_matmul_sets);
+ if (benchmark_prepack_lhs) {
+ cold_prepacked_lhs.Init(static_cast<char*>(result->prepacked_lhs.data),
+ result->prepacked_lhs.data_size, num_matmul_sets);
+ }
+ if (benchmark_prepack_rhs) {
+ cold_prepacked_rhs.Init(static_cast<char*>(result->prepacked_rhs.data),
+ result->prepacked_rhs.data_size, num_matmul_sets);
+ }
+ }
+ const bool record_pmu = GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU");
+ int repeats = GetIntEnvVarOrZero("RUY_BENCHMARK_REPEATS");
+ if (!repeats) {
+ repeats = 4;
+ }
+ float benchmark_min_secs = GetFloatEnvVarOrZero("RUY_BENCHMARK_MIN_SECS");
+ if (!benchmark_min_secs) {
+ benchmark_min_secs = 0.5;
+ }
+#ifdef RUY_PROFILER
+ {
+ const char* lhstype = TypeName<LhsScalar>();
+ const char* lhssymm = SymmetryName(lhs.matrix);
+ const char* rhstype = TypeName<RhsScalar>();
+ const char* rhssymm = SymmetryName(rhs.matrix);
+
+ printf("Profiling path=%s shape=(%dx%dx%d) lhs=(%s,%s) rhs=(%s,%s)\n",
+ PathName(*result).c_str(), rows, depth, cols, lhstype, lhssymm,
+ rhstype, rhssymm);
+ ruy::profiler::ScopeProfile profile;
+#endif
+
+ float latency = std::numeric_limits<float>::infinity();
+ float l1_refill_rate = std::numeric_limits<float>::infinity();
+ float l2_refill_rate = std::numeric_limits<float>::infinity();
+ float l3_refill_rate = std::numeric_limits<float>::infinity();
+ float l1tlb_refill_rate = std::numeric_limits<float>::infinity();
+ float l2tlb_refill_rate = std::numeric_limits<float>::infinity();
+ float mispred_rate = std::numeric_limits<float>::infinity();
+ float frontend_stall_rate = std::numeric_limits<float>::infinity();
+ float backend_stall_rate = std::numeric_limits<float>::infinity();
+
+ for (int repeat = 0; repeat < repeats; repeat++) {
+ auto& pmu_events = GlobalPmuEvents();
+ if (record_pmu) {
+ pmu_events.StartRecording();
+ }
+ TimePoint time_start = Now();
+ TimePoint t = time_start;
+ int iters = 0;
+ int iters_at_a_time = 1;
+ while (ToFloatSeconds(t - time_start) < benchmark_min_secs) {
+ for (int i = 0; i < iters_at_a_time; i++) {
+ if (cold) {
+ lhs.matrix.data = cold_lhs.Next();
+ rhs.matrix.data = cold_rhs.Next();
+ result->storage_matrix.matrix.data = cold_dst.Next();
+ if (benchmark_prepack_lhs) {
+ result->prepacked_lhs.data = cold_prepacked_lhs.Next();
+ }
+ if (benchmark_prepack_rhs) {
+ result->prepacked_rhs.data = cold_prepacked_rhs.Next();
+ }
+ }
+ EvalResult(result);
+ iters++;
+ }
+ iters_at_a_time *= 2;
+ t = Now();
+ }
+ latency = std::min(
+ latency, static_cast<float>(ToFloatSeconds(t - time_start) / iters));
+ if (record_pmu) {
+ pmu_events.StopRecording();
+ const float normalization_factor =
+ 1.0f / (static_cast<float>(iters) * rows * cols * depth);
+ l1_refill_rate = std::min(
+ l1_refill_rate, pmu_events.L1RefillCount() * normalization_factor);
+ l2_refill_rate = std::min(
+ l2_refill_rate, pmu_events.L2RefillCount() * normalization_factor);
+ l3_refill_rate = std::min(
+ l3_refill_rate, pmu_events.L3RefillCount() * normalization_factor);
+ l1tlb_refill_rate =
+ std::min(l1tlb_refill_rate,
+ pmu_events.L1TLBRefillCount() * normalization_factor);
+ l2tlb_refill_rate =
+ std::min(l2tlb_refill_rate,
+ pmu_events.L2TLBRefillCount() * normalization_factor);
+ mispred_rate =
+ std::min(mispred_rate, pmu_events.BranchMispredictionCount() *
+ normalization_factor);
+ frontend_stall_rate =
+ std::min(frontend_stall_rate,
+ pmu_events.FrontendStallCount() * normalization_factor);
+ backend_stall_rate =
+ std::min(backend_stall_rate,
+ pmu_events.BackendStallCount() * normalization_factor);
+ }
+ }
+ result->latency = latency;
+ if (record_pmu) {
+ result->l1_refill_rate = l1_refill_rate;
+ result->l2_refill_rate = l2_refill_rate;
+ result->l3_refill_rate = l3_refill_rate;
+ result->l1tlb_refill_rate = l1tlb_refill_rate;
+ result->l2tlb_refill_rate = l2tlb_refill_rate;
+ result->mispred_rate = mispred_rate;
+ result->frontend_stall_rate = frontend_stall_rate;
+ result->backend_stall_rate = backend_stall_rate;
+ }
+
+#ifdef RUY_PROFILER
+ }
+ fflush(stdout);
+#endif
+
+ if (cold) {
+ lhs.matrix.data = orig_lhs_data;
+ rhs.matrix.data = orig_rhs_data;
+ memcpy(orig_dst_data, result->storage_matrix.matrix.data.get(),
+ StorageSize(result->storage_matrix.matrix));
+ result->storage_matrix.matrix.data = orig_dst_data;
+ result->prepacked_lhs.data = orig_prepacked_lhs_data;
+ result->prepacked_rhs.data = orig_prepacked_rhs_data;
+ }
+}
+
+template <typename LhsScalar, typename RhsScalar, typename SpecType>
+void TestSet<LhsScalar, RhsScalar, SpecType>::Eval() {
+ RUY_CHECK_EQ(life_stage, LifeStage::kHasPrepackedMatrices);
+ for (auto& result : results) {
+ if (benchmark) {
+ Benchmark(result.get());
+ } else {
+ EvalResult(result.get());
+ }
+ }
+ life_stage = LifeStage::kEvaluated;
+}
+
+template <typename Scalar>
+std::string DumpRegion(const Matrix<Scalar>& matrix, int center_row,
+ int center_col) {
+ static constexpr int kRadius = 20;
+ int first_row = std::max(0, center_row - kRadius);
+ int last_row = std::min(matrix.layout.rows - 1, center_row + kRadius);
+ int first_col = std::max(0, center_col - kRadius);
+ int last_col = std::min(matrix.layout.cols - 1, center_col + kRadius);
+ std::ostringstream stream;
+ for (int row = first_row; row <= last_row; row++) {
+ for (int col = first_col; col <= last_col; col++) {
+ stream << static_cast<double>(Element(matrix, row, col)) << " ";
+ }
+ stream << "\n";
+ }
+ return stream.str();
+}
+
+template <typename LhsScalar, typename RhsScalar, typename SpecType>
+void TestSet<LhsScalar, RhsScalar, SpecType>::VerifyTestResults() const {
+ const int depth = lhs.matrix.layout.cols;
+ for (int i = 0; i < results.size() - 1; i++) {
+ if (!Agree(*results[i], *results[i + 1], depth)) {
+ std::string paths_in_agreement;
+ paths_in_agreement.append(PathName(*results[0]));
+ for (int j = 1; j <= i; j++) {
+ paths_in_agreement.append(", ");
+ paths_in_agreement.append(PathName(*results[j]));
+ }
+ ErrorAnalysis error_analysis;
+ AnalyzeTestError(*this, i + 1, &error_analysis);
+ std::cerr << "Error: path (" << PathName(*results[i + 1])
+ << ") disagrees with the other paths (" << paths_in_agreement
+ << "), which agree with each other." << std::endl;
+ std::cerr << "Shape: rows = " << rows << ", cols = " << cols
+ << ", depth = " << depth << std::endl;
+ std::cerr << "Stats of the good result matrix: "
+ << StatsAsString(error_analysis.stats_good) << std::endl;
+ std::cerr << "Stats of the bad result matrix: "
+ << StatsAsString(error_analysis.stats_bad) << std::endl;
+ if (error_analysis.error_rows.size() < rows) {
+ std::cerr << "Rows containing errors: "
+ << Join(error_analysis.error_rows) << std::endl;
+ } else {
+ std::cerr << "Errors found in ALL rows." << std::endl;
+ }
+ if (error_analysis.error_cols.size() < cols) {
+ std::cerr << "Cols containing errors: "
+ << Join(error_analysis.error_cols) << std::endl;
+ } else {
+ std::cerr << "Errors found in ALL cols." << std::endl;
+ }
+ std::cerr << "The first error occurs at row "
+ << error_analysis.row_of_first_error << ", col "
+ << error_analysis.col_of_first_error << std::endl;
+ std::cerr << "Good value: " << error_analysis.first_error_good_value
+ << std::endl;
+ std::cerr << "Bad value : " << error_analysis.first_error_bad_value
+ << std::endl;
+ std::cerr << "Region of Good result matrix around first error:\n\n"
+ << DumpRegion(results[0]->storage_matrix.matrix,
+ error_analysis.row_of_first_error,
+ error_analysis.col_of_first_error)
+ << std::endl;
+ std::cerr << "Region of Bad result matrix around first error:\n\n"
+ << DumpRegion(results[i + 1]->storage_matrix.matrix,
+ error_analysis.row_of_first_error,
+ error_analysis.col_of_first_error)
+ << std::endl;
+ RUY_CHECK(false);
+ }
+ }
+}
+
+template <typename LhsScalar, typename RhsScalar, typename SpecType>
+void TestSet<LhsScalar, RhsScalar, SpecType>::Verify() {
+ RUY_CHECK_EQ(life_stage, LifeStage::kEvaluated);
+ if (expected_outcome == ExpectedOutcome::kSuccess) {
+ VerifyTestResults();
+ }
+ life_stage = LifeStage::kFinal;
+}
+
+template <typename TestSetType>
+void TestRCC(int rows, int depth, int cols, ExpectedOutcome expected_outcome) {
+ TestSetType test_set;
+ test_set.rows = rows;
+ test_set.depth = depth;
+ test_set.cols = cols;
+ test_set.lhs_order = Order::kRowMajor;
+ test_set.rhs_order = Order::kColMajor;
+ test_set.dst_order = Order::kColMajor;
+ test_set.layout_style = LayoutStyle::kPackedLinear;
+ test_set.expected_outcome = expected_outcome;
+ test_set.Run();
+}
+
+template <typename TestSetType>
+void TestRCC(int rows, int depth, int cols) {
+ TestRCC<TestSetType>(rows, depth, cols, ExpectedOutcome::kSuccess);
+}
+
+template <typename TestSetType>
+void TestNonRCC(int rows, int depth, int cols,
+ ExpectedOutcome expected_outcome) {
+ TestSetType test_set;
+ test_set.rows = rows;
+ test_set.depth = depth;
+ test_set.cols = cols;
+ test_set.lhs_order = Order::kColMajor;
+ test_set.rhs_order = Order::kColMajor;
+ test_set.dst_order = Order::kColMajor;
+ test_set.layout_style = LayoutStyle::kPackedLinear;
+ test_set.expected_outcome = expected_outcome;
+ test_set.Run();
+}
+
+template <typename TestSetType>
+void TestLinearAllOrders(int rows, int depth, int cols,
+ ExpectedOutcome expected_outcome) {
+ const std::vector<Order> orders{Order::kColMajor, Order::kRowMajor};
+
+ for (Order lhs_order : orders) {
+ for (Order rhs_order : orders) {
+ for (Order dst_order : orders) {
+ TestSetType test_set;
+ test_set.rows = rows;
+ test_set.depth = depth;
+ test_set.cols = cols;
+ test_set.lhs_order = lhs_order;
+ test_set.rhs_order = rhs_order;
+ test_set.dst_order = dst_order;
+ test_set.layout_style = LayoutStyle::kLinear;
+ test_set.expected_outcome = expected_outcome;
+ test_set.Run();
+ }
+ }
+ }
+}
+
+template <typename TestSetType>
+void TestLinearAllOrders(int rows, int depth, int cols) {
+ TestLinearAllOrders<TestSetType>(rows, depth, cols,
+ ExpectedOutcome::kSuccess);
+}
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TEST_H_
diff --git a/ruy/test_fast.cc b/ruy/test_fast.cc
new file mode 100644
index 0000000..d1c1308
--- /dev/null
+++ b/ruy/test_fast.cc
@@ -0,0 +1,110 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This test contains cheap test cases and completes in a few seconds.
+
+#include <vector>
+
+#include "ruy/test.h"
+
+namespace ruy {
+
+using LhsScalar = RUY_TEST_LHSSCALAR;
+using RhsScalar = RUY_TEST_RHSSCALAR;
+using AccumScalar = RUY_TEST_ACCUMSCALAR;
+using DstScalar = RUY_TEST_DSTSCALAR;
+
+using TestSetType =
+ TestSet<LhsScalar, RhsScalar, BasicSpec<AccumScalar, DstScalar>>;
+
+TEST(RuyTest, TestSquareMuls) {
+ const std::vector<int> sizes{
+ // small sizes
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ 10,
+ // multiples of 16
+ 16,
+ 32,
+ 48,
+ 64,
+ // power-of-two-minus-1 sizes
+ 15,
+ 31,
+ 63,
+ // power-of-two-plus-1 sizes
+ 17,
+ 33,
+ 65,
+ };
+
+ for (int size : sizes) {
+ TestRCC<TestSetType>(size, size, size);
+ TestLinearAllOrders<TestSetType>(size, size, size);
+ }
+}
+
+TEST(RuyTest, TestMiscMuls) {
+ const int shapes[][3] = {
+ {2, 3, 4}, {7, 6, 5}, {12, 23, 6}, {19, 3, 11}, {3, 10, 17},
+ {30, 21, 43}, {7, 57, 9}, {49, 69, 71}, {38, 111, 29}, {87, 98, 76},
+ {16, 96, 16}, {16, 88, 16}, {16, 84, 16}, {16, 92, 16}, {16, 82, 16},
+ {16, 81, 16}, {16, 95, 16}, {3, 128, 5}};
+ for (const auto& shape : shapes) {
+ TestLinearAllOrders<TestSetType>(shape[0], shape[1], shape[2]);
+ }
+}
+
+TEST(RuyTest, TestDeepMuls) {
+ // TODO(b/137649322): clarify what's the max allowed matrix size.
+ TestRCC<TestSetType>(1, 32767, 1);
+ TestLinearAllOrders<TestSetType>(5, 5001, 4);
+ TestLinearAllOrders<TestSetType>(9, 1025, 10);
+}
+
+TEST(RuyTest, TestShallowMuls) {
+ TestLinearAllOrders<TestSetType>(101, 1, 103);
+ TestLinearAllOrders<TestSetType>(71, 2, 53);
+ TestLinearAllOrders<TestSetType>(51, 3, 73);
+ TestLinearAllOrders<TestSetType>(51, 4, 43);
+}
+
+TEST(RuyTest, TestNarrowMuls) {
+ for (int width : {1, 2, 3, 4, 5, 8}) {
+ TestLinearAllOrders<TestSetType>(width, 12, 13);
+ TestLinearAllOrders<TestSetType>(15, 19, width);
+ TestLinearAllOrders<TestSetType>(width, 123, 137);
+ TestLinearAllOrders<TestSetType>(158, 119, width);
+ }
+}
+
+TEST(RuyTest, TestGEMV) {
+ for (int size = 1; size < 1024; size *= 2) {
+ for (int depth = 1; depth < 500; depth += 47) {
+ TestLinearAllOrders<TestSetType>(size, depth, 1);
+ }
+ }
+ TestLinearAllOrders<TestSetType>(5, 5001, 1);
+ TestLinearAllOrders<TestSetType>(8193, 17, 1);
+}
+
+} // namespace ruy
diff --git a/ruy/test_slow.cc b/ruy/test_slow.cc
new file mode 100644
index 0000000..9f0f218
--- /dev/null
+++ b/ruy/test_slow.cc
@@ -0,0 +1,71 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This test contains more expensive test cases.
+
+#include "ruy/test.h"
+
+namespace ruy {
+
+using LhsScalar = RUY_TEST_LHSSCALAR;
+using RhsScalar = RUY_TEST_RHSSCALAR;
+using AccumScalar = RUY_TEST_ACCUMSCALAR;
+using DstScalar = RUY_TEST_DSTSCALAR;
+
+using TestSetType =
+ TestSet<LhsScalar, RhsScalar, BasicSpec<AccumScalar, DstScalar>>;
+
+TEST(RuyTest, TestBigNarrowMuls) {
+ for (int width : {1, 2, 3, 4, 5, 8}) {
+ TestRCC<TestSetType>(width, 401, 601);
+ TestRCC<TestSetType>(587, 443, width);
+ }
+ TestRCC<TestSetType>(7, 45984,
+ 5); // Large enough to trigger row-sum overflows.
+ TestRCC<TestSetType>(512, 256, 16);
+}
+
+TEST(RuyTest, TestBigShallowMuls) {
+ TestLinearAllOrders<TestSetType>(501, 1, 321);
+ TestLinearAllOrders<TestSetType>(301, 5, 403);
+ TestLinearAllOrders<TestSetType>(256, 32, 512);
+}
+
+TEST(RuyTest, TestBigMuls) {
+ TestRCC<TestSetType>(225, 303, 199);
+ TestLinearAllOrders<TestSetType>(256, 192, 128);
+}
+
+TEST(RuyTest, TestBigPowerOfTwoDepthWithAvoidAliasing) {
+ // Important to test some power-of-two depths: that's when the
+ // RUY_AVOID_ALIASING optimization kicks in and makes packed matrices
+ // strided, exposing bugs in kernels mixing up size and stride.
+ // Moreover, it's important that the test matrices be sufficiently wide
+ // that they will result in multiple blocks, exposing bugs in the
+ // computation of the base address of each block.
+ TestLinearAllOrders<TestSetType>(70, 1024, 80);
+ TestLinearAllOrders<TestSetType>(60, 2048, 70);
+ TestLinearAllOrders<TestSetType>(40, 4096, 50);
+}
+
+TEST(RuyTest, TestGEMV) {
+ for (int size = 1025; size <= 1409; size += 384) {
+ for (int depth = 350; depth < 500; depth += 47) {
+ TestLinearAllOrders<TestSetType>(size, depth, 1);
+ }
+ }
+}
+
+} // namespace ruy
diff --git a/ruy/test_special_specs.cc b/ruy/test_special_specs.cc
new file mode 100644
index 0000000..a621d0e
--- /dev/null
+++ b/ruy/test_special_specs.cc
@@ -0,0 +1,163 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This test covers non-basic specs.
+
+#include "ruy/test.h"
+
+namespace ruy {
+
+template <typename AccumScalar, typename DstScalar,
+ LoopStructure tLoopStructure>
+struct LoopStructureSpec : BasicSpec<AccumScalar, DstScalar> {
+ static constexpr LoopStructure kLoopStructure = tLoopStructure;
+};
+
+template <typename AccumScalar, typename DstScalar,
+ ZeroPointSupport tZeroPointSupport>
+struct ZeroPointSupportSpec : BasicSpec<AccumScalar, DstScalar> {
+ static constexpr ZeroPointSupport kZeroPointSupport = tZeroPointSupport;
+};
+
+template <typename AccumScalar, typename DstScalar>
+struct RCCSpec : BasicSpec<AccumScalar, DstScalar> {
+ static constexpr LayoutSupport kLayoutSupport = LayoutSupport::kRCC;
+};
+
+template <typename AccumScalar, typename DstScalar, typename LhsKernelLayout,
+ typename RhsKernelLayout>
+struct StandardCppKernelLayoutSpec : BasicSpec<AccumScalar, DstScalar> {
+ using StandardCppKernelLhsLayout = LhsKernelLayout;
+ using StandardCppKernelRhsLayout = RhsKernelLayout;
+ static int local_data_cache_size() { return 1; }
+ static int shared_data_cache_size() { return 1; }
+};
+
+using LhsScalar = RUY_TEST_LHSSCALAR;
+using RhsScalar = RUY_TEST_RHSSCALAR;
+using AccumScalar = RUY_TEST_ACCUMSCALAR;
+using DstScalar = RUY_TEST_DSTSCALAR;
+
+template <LoopStructure tLoopStructure>
+void TestLoopStructure() {
+ using SpecType = LoopStructureSpec<AccumScalar, DstScalar, tLoopStructure>;
+ using TestSetType = TestSet<LhsScalar, RhsScalar, SpecType>;
+ for (int size = 1; size < 10; size++) {
+ TestLinearAllOrders<TestSetType>(size, size, size);
+ }
+ TestLinearAllOrders<TestSetType>(3, 5, 78);
+ TestLinearAllOrders<TestSetType>(19, 91, 7);
+ TestLinearAllOrders<TestSetType>(71, 26, 44);
+ TestLinearAllOrders<TestSetType>(81, 93, 72);
+}
+
+TEST(TestSpecialSpecs, LoopStructure) {
+ static_assert(BasicSpec<std::uint8_t, std::int32_t>::kLoopStructure ==
+ LoopStructure::kAuto,
+ "");
+ static_assert(BasicSpec<float, float>::kLoopStructure == LoopStructure::kAuto,
+ "");
+ TestLoopStructure<LoopStructure::kSimple>();
+ TestLoopStructure<LoopStructure::kGeneral>();
+}
+
+template <ZeroPointSupport tZeroPointSupport>
+void TestZeroPointSupport(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point,
+ DstScalar dst_zero_point,
+ ExpectedOutcome expected_outcome) {
+ using SpecType =
+ ZeroPointSupportSpec<AccumScalar, DstScalar, tZeroPointSupport>;
+ using TestSetType = TestSet<LhsScalar, RhsScalar, SpecType>;
+ TestSetType test_set;
+ test_set.rows = 11;
+ test_set.depth = 12;
+ test_set.cols = 13;
+ test_set.lhs_order = Order::kRowMajor;
+ test_set.rhs_order = Order::kColMajor;
+ test_set.dst_order = Order::kColMajor;
+ test_set.layout_style = LayoutStyle::kPackedLinear;
+ test_set.expected_outcome = expected_outcome;
+ test_set.lhs_zero_point = lhs_zero_point;
+ test_set.rhs_zero_point = rhs_zero_point;
+ test_set.dst_zero_point = dst_zero_point;
+ test_set.use_specified_zero_points = true;
+ test_set.Run();
+}
+
+TEST(TestSpecialSpecs, ZeroPointSupport) {
+ // Sanity check
+ RUY_CHECK_EQ(SymmetricZeroPoint<std::uint8_t>(), 128);
+ RUY_CHECK_EQ(SymmetricZeroPoint<std::int8_t>(), 0);
+
+ if (std::is_floating_point<LhsScalar>::value) {
+ return;
+ }
+
+ TestZeroPointSupport<ZeroPointSupport::kGeneral>(
+ SymmetricZeroPoint<LhsScalar>(), SymmetricZeroPoint<RhsScalar>(),
+ SymmetricZeroPoint<DstScalar>(), ExpectedOutcome::kSuccess);
+ TestZeroPointSupport<ZeroPointSupport::kGeneral>(
+ SymmetricZeroPoint<LhsScalar>() - 1, SymmetricZeroPoint<RhsScalar>(),
+ SymmetricZeroPoint<DstScalar>(), ExpectedOutcome::kSuccess);
+ TestZeroPointSupport<ZeroPointSupport::kSymmetric>(
+ SymmetricZeroPoint<LhsScalar>(), SymmetricZeroPoint<RhsScalar>(),
+ SymmetricZeroPoint<DstScalar>(), ExpectedOutcome::kSuccess);
+ TestZeroPointSupport<ZeroPointSupport::kSymmetric>(
+ SymmetricZeroPoint<LhsScalar>() + 1, SymmetricZeroPoint<RhsScalar>(),
+ SymmetricZeroPoint<DstScalar>(), ExpectedOutcome::kDeath);
+ TestZeroPointSupport<ZeroPointSupport::kSymmetric>(
+ SymmetricZeroPoint<LhsScalar>(), SymmetricZeroPoint<RhsScalar>() + 1,
+ SymmetricZeroPoint<DstScalar>(), ExpectedOutcome::kDeath);
+ TestZeroPointSupport<ZeroPointSupport::kSymmetric>(
+ SymmetricZeroPoint<LhsScalar>(), SymmetricZeroPoint<RhsScalar>(),
+ SymmetricZeroPoint<DstScalar>() - 1, ExpectedOutcome::kDeath);
+}
+
+TEST(TestSpecialSpecs, RCC) {
+ using RCCSpec = RCCSpec<AccumScalar, DstScalar>;
+ using RCCTestSet = TestSet<LhsScalar, RhsScalar, RCCSpec>;
+ TestRCC<RCCTestSet>(81, 93, 72);
+ TestNonRCC<RCCTestSet>(81, 93, 72, ExpectedOutcome::kDeath);
+}
+
+template <typename LhsKernelLayout, typename RhsKernelLayout>
+void TestStandardCppKernelLayout() {
+ using SpecType =
+ StandardCppKernelLayoutSpec<AccumScalar, DstScalar, LhsKernelLayout,
+ RhsKernelLayout>;
+ using TestSetType = TestSet<LhsScalar, RhsScalar, SpecType>;
+ for (int size = 1; size < 10; size++) {
+ TestLinearAllOrders<TestSetType>(size, size, size);
+ }
+ TestLinearAllOrders<TestSetType>(87, 34, 56);
+ TestLinearAllOrders<TestSetType>(123, 234, 78);
+}
+
+TEST(TestSpecialSpecs, StandardCppKernelLayoutTrivial1x1) {
+ TestStandardCppKernelLayout<FixedKernelLayout<Order::kColMajor, 1, 1>,
+ FixedKernelLayout<Order::kColMajor, 1, 1>>();
+}
+
+TEST(TestSpecialSpecs, StandardCppKernelLayoutSquare4x4) {
+ TestStandardCppKernelLayout<FixedKernelLayout<Order::kRowMajor, 4, 4>,
+ FixedKernelLayout<Order::kRowMajor, 4, 4>>();
+}
+
+TEST(TestSpecialSpecs, StandardCppKernelLayoutRectangular4x8) {
+ TestStandardCppKernelLayout<FixedKernelLayout<Order::kColMajor, 1, 4>,
+ FixedKernelLayout<Order::kColMajor, 1, 8>>();
+}
+
+} // namespace ruy
diff --git a/ruy/thread_pool.cc b/ruy/thread_pool.cc
new file mode 100644
index 0000000..d09bf1e
--- /dev/null
+++ b/ruy/thread_pool.cc
@@ -0,0 +1,200 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/thread_pool.h"
+
+#include <atomic>
+#include <chrono> // NOLINT(build/c++11)
+#include <condition_variable> // NOLINT(build/c++11)
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
+#include <mutex> // NOLINT(build/c++11)
+#include <thread> // NOLINT(build/c++11)
+
+#include "ruy/check_macros.h"
+#include "ruy/wait.h"
+
+namespace ruy {
+
+// A worker thread.
+class Thread {
+ public:
+ enum class State {
+ Startup, // The initial state before the thread main loop runs.
+ Ready, // Is not working, has not yet received new work to do.
+ HasWork, // Has work to do.
+ ExitAsSoonAsPossible // Should exit at earliest convenience.
+ };
+
+ explicit Thread(BlockingCounter* counter_to_decrement_when_ready)
+ : task_(nullptr),
+ state_(State::Startup),
+ counter_to_decrement_when_ready_(counter_to_decrement_when_ready) {
+ thread_.reset(new std::thread(ThreadFunc, this));
+ }
+
+ ~Thread() {
+ ChangeState(State::ExitAsSoonAsPossible);
+ thread_->join();
+ }
+
+ // Changes State; may be called from either the worker thread
+ // or the master thread. Not all state transitions are legal, and the
+ // assertions below enforce that only legal transitions occur.
+ //
+ // The Task argument is to be used only with new_state==HasWork.
+ // It specifies the Task being handed to this Thread.
+ void ChangeState(State new_state, Task* task = nullptr) {
+ state_mutex_.lock();
+ State old_state = state_.load(std::memory_order_relaxed);
+ RUY_DCHECK_NE(old_state, new_state);
+ switch (old_state) {
+ case State::Startup:
+ RUY_DCHECK_EQ(new_state, State::Ready);
+ break;
+ case State::Ready:
+ RUY_DCHECK(new_state == State::HasWork ||
+ new_state == State::ExitAsSoonAsPossible);
+ break;
+ case State::HasWork:
+ RUY_DCHECK(new_state == State::Ready ||
+ new_state == State::ExitAsSoonAsPossible);
+ break;
+ default:
+ abort();
+ }
+ switch (new_state) {
+ case State::Ready:
+ if (task_) {
+ // Doing work is part of reverting to 'ready' state.
+ task_->Run();
+ task_ = nullptr;
+ }
+ break;
+ case State::HasWork:
+ RUY_DCHECK(!task_);
+ task_ = task;
+ break;
+ default:
+ break;
+ }
+ state_.store(new_state, std::memory_order_relaxed);
+ state_cond_.notify_all();
+ state_mutex_.unlock();
+ if (new_state == State::Ready) {
+ counter_to_decrement_when_ready_->DecrementCount();
+ }
+ }
+
+ static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); }
+
+ // Called by the master thread to give this thread work to do.
+ void StartWork(Task* task) { ChangeState(State::HasWork, task); }
+
+ private:
+ // Thread entry point.
+ void ThreadFuncImpl() {
+ ChangeState(State::Ready);
+
+ // Thread main loop
+ while (true) {
+ // In the 'Ready' state, we have nothing to do but to wait until
+ // we switch to another state.
+ const auto& condition = [this]() {
+ return state_.load(std::memory_order_acquire) != State::Ready;
+ };
+ Wait(condition, &state_cond_, &state_mutex_);
+
+ // Act on new state.
+ switch (state_.load(std::memory_order_acquire)) {
+ case State::HasWork:
+ // Got work to do! So do it, and then revert to 'Ready' state.
+ ChangeState(State::Ready);
+ break;
+ case State::ExitAsSoonAsPossible:
+ return;
+ default:
+ abort();
+ }
+ }
+ }
+
+ // The underlying thread.
+ std::unique_ptr<std::thread> thread_;
+
+ // The task to be worked on.
+ Task* task_;
+
+ // The condition variable and mutex guarding state changes.
+ std::condition_variable state_cond_;
+ std::mutex state_mutex_;
+
+ // The state enum tells if we're currently working, waiting for work, etc.
+ // Its concurrent accesses by the worker and main threads are guarded by
+ // state_mutex_, and can thus use memory_order_relaxed. This still needs
+ // to be a std::atomic because it is also read by the spin-wait in Wait().
+ std::atomic<State> state_;
+
+ // Pointer to the master thread's BlockingCounter object, used to notify the
+ // master thread when this thread switches to the 'Ready' state.
+ BlockingCounter* const counter_to_decrement_when_ready_;
+};
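+
+// Note on the state machine above: a Thread goes Startup -> Ready when its
+// main loop starts, then cycles Ready -> HasWork -> Ready each time the main
+// thread calls StartWork(), and finally Ready -> ExitAsSoonAsPossible when
+// the destructor runs. Every transition into the Ready state decrements the
+// master thread's BlockingCounter.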
+
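+// Runs 'task_count' tasks: task #0 on the calling thread, the rest on worker
+// threads. 'stride' is the size in bytes of the caller's concrete TaskType
+// (see Execute() in thread_pool.h), so the reinterpret_cast arithmetic below
+// steps through the caller's array of derived Task objects.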
+void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) {
+ RUY_DCHECK_GE(task_count, 1);
+
+ // Case of 1 thread: just run the single task on the current thread.
+ if (task_count == 1) {
+ (tasks + 0)->Run();
+ return;
+ }
+
+ // Task #0 will be run on the current thread.
+ CreateThreads(task_count - 1);
+ counter_to_decrement_when_ready_.Reset(task_count - 1);
+ for (int i = 1; i < task_count; i++) {
+ auto task_address = reinterpret_cast<std::uintptr_t>(tasks) + i * stride;
+ threads_[i - 1]->StartWork(reinterpret_cast<Task*>(task_address));
+ }
+
+ // Execute task #0 immediately on the current thread.
+ (tasks + 0)->Run();
+
+ // Wait for the threads submitted above to finish.
+ counter_to_decrement_when_ready_.Wait();
+}
+
+// Ensures that the pool has at least the given count of threads.
+// If any new thread has to be created, this function waits for it to
+// be ready.
+void ThreadPool::CreateThreads(int threads_count) {
+ if (threads_.size() >= threads_count) {
+ return;
+ }
+ counter_to_decrement_when_ready_.Reset(threads_count - threads_.size());
+ while (threads_.size() < threads_count) {
+ threads_.push_back(new Thread(&counter_to_decrement_when_ready_));
+ }
+ counter_to_decrement_when_ready_.Wait();
+}
+
+ThreadPool::~ThreadPool() {
+ for (auto w : threads_) {
+ delete w;
+ }
+}
+
+} // end namespace ruy
diff --git a/ruy/thread_pool.h b/ruy/thread_pool.h
new file mode 100644
index 0000000..04c201c
--- /dev/null
+++ b/ruy/thread_pool.h
@@ -0,0 +1,102 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file is a fork of gemmlowp's multi_thread_gemm.h, under Apache 2.0
+// license.
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_
+
+#include <vector>
+
+#include "ruy/blocking_counter.h"
+
+namespace ruy {
+
+// A workload for a thread.
+struct Task {
+ virtual ~Task() {}
+ virtual void Run() = 0;
+};
+
+class Thread;
+
+// A simple pool of threads that only allows the very
+// specific parallelization pattern that we use here:
+// One thread, which we call the 'main thread', calls Execute, distributing
+// a Task each to N threads, being N-1 'worker threads' and the main thread
+// itself. After the main thread has completed its own Task, it waits for
+// the worker threads to have all completed. That is the only synchronization
+// performed by this ThreadPool.
+//
+// In particular, there is a naive 1:1 mapping of Tasks to threads.
+// This ThreadPool considers it outside of its own scope to try to work
+// with fewer threads than there are Tasks. The idea is that such N:M mappings
+// of tasks to threads can be implemented as a higher-level feature on top of
+// the present low-level 1:1 threadpool. For example, a user might have a
+// Task subclass referencing a shared atomic counter indexing into a vector of
+// finer-granularity subtasks. Different threads would then concurrently
+// increment this atomic counter, getting each their own subtasks to work on.
+// That approach is the one used in ruy's multi-thread matrix multiplication
+// implementation --- see ruy's TrMulTask.
+class ThreadPool {
+ public:
+ ThreadPool() {}
+
+ ~ThreadPool();
+
+ // Executes task_count tasks on task_count threads.
+ // Grows the threadpool as needed to have at least (task_count-1) threads.
+ // The 0-th task is run on the thread on which Execute is called: that
+ // is by definition what we call the "main thread". Synchronization of all
+ // threads is performed before this function returns.
+ //
+ // As explained in the class comment, there is a 1:1 mapping of tasks to
+ // threads. If you need something smarter than that, for instance if you
+ // want to run an unbounded number of tasks on a bounded number of threads,
+ // then you need something higher-level than this ThreadPool, that can
+ // be layered on top of it by appropriately subclassing Tasks.
+ //
+ // TaskType must be a subclass of ruy::Task. That is implicitly enforced by
+ // the static_cast below, which only compiles if TaskType derives from Task.
+ template <typename TaskType>
+ void Execute(int task_count, TaskType* tasks) {
+ ExecuteImpl(task_count, sizeof(TaskType), static_cast<Task*>(tasks));
+ }
+
+ private:
+ // Ensures that the pool has at least the given count of threads.
+ // If any new thread has to be created, this function waits for it to
+ // be ready.
+ void CreateThreads(int threads_count);
+
+ // Non-templatized implementation of the public Execute method.
+ // See the inline implementation of Execute for how this is used.
+ void ExecuteImpl(int task_count, int stride, Task* tasks);
+
+ // copy construction disallowed
+ ThreadPool(const ThreadPool&) = delete;
+
+ // The threads in this pool. They are owned by the pool:
+ // the pool creates threads and destroys them in its destructor.
+ std::vector<Thread*> threads_;
+
+ // The BlockingCounter used to wait for the threads.
+ BlockingCounter counter_to_decrement_when_ready_;
+};
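+
+// Illustrative sketch (an editor's example, not part of ruy): one way to layer
+// the N:M task-to-thread mapping described in the class comment above on top
+// of this 1:1 ThreadPool. Each of N tasks pulls fine-grained work items off a
+// shared atomic counter; ruy's real implementation of this idea is TrMulTask.
+// The names CounterTask and ProcessWorkItem are hypothetical, and <atomic> is
+// assumed to be included.
+//
+//   struct CounterTask final : Task {
+//     CounterTask(std::atomic<int>* next_work_item, int num_work_items)
+//         : next_work_item_(next_work_item), num_work_items_(num_work_items) {}
+//     void Run() override {
+//       // Threads race on the shared counter; each index is claimed by
+//       // exactly one thread, so N tasks drain M >= N work items.
+//       for (int i = next_work_item_->fetch_add(1); i < num_work_items_;
+//            i = next_work_item_->fetch_add(1)) {
+//         ProcessWorkItem(i);  // hypothetical per-item work
+//       }
+//     }
+//     std::atomic<int>* next_work_item_;
+//     int num_work_items_;
+//   };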
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_
diff --git a/ruy/time.h b/ruy/time.h
new file mode 100644
index 0000000..9dba75e
--- /dev/null
+++ b/ruy/time.h
@@ -0,0 +1,81 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TIME_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TIME_H_
+
+#include <chrono> // NOLINT(build/c++11)
+#include <cstdint> // IWYU pragma: keep
+#include <ratio> // NOLINT(build/c++11)
+
+#ifdef __linux__
+#include <sys/time.h>
+// IWYU pragma: no_include <type_traits>
+
+#include <ctime>
+#endif
+
+namespace ruy {
+
+using InternalDefaultClock = std::chrono::steady_clock;
+
+using TimePoint = InternalDefaultClock::time_point;
+using Duration = InternalDefaultClock::duration;
+
+template <typename RepresentationType>
+Duration DurationFromSeconds(RepresentationType representation) {
+ return std::chrono::duration_cast<Duration>(
+ std::chrono::duration<RepresentationType>(representation));
+}
+
+template <typename RepresentationType>
+Duration DurationFromMilliseconds(RepresentationType representation) {
+ return std::chrono::duration_cast<Duration>(
+ std::chrono::duration<RepresentationType, std::milli>(representation));
+}
+
+template <typename RepresentationType>
+Duration DurationFromNanoseconds(RepresentationType representation) {
+ return std::chrono::duration_cast<Duration>(
+ std::chrono::duration<RepresentationType, std::nano>(representation));
+}
+
+inline float ToFloatSeconds(const Duration& duration) {
+ return std::chrono::duration_cast<std::chrono::duration<float>>(duration)
+ .count();
+}
+
+inline std::int64_t ToInt64Nanoseconds(const Duration& duration) {
+ return std::chrono::duration_cast<
+ std::chrono::duration<std::int64_t, std::nano>>(duration)
+ .count();
+}
+
+inline TimePoint Now() { return InternalDefaultClock::now(); }
+
+inline TimePoint CoarseNow() {
+#ifdef __linux__
+ timespec t;
+ clock_gettime(CLOCK_MONOTONIC_COARSE, &t);
+ return TimePoint(
+ DurationFromNanoseconds(1000000000LL * t.tv_sec + t.tv_nsec));
+#else
+ return Now();
+#endif
+}
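+
+// Usage sketch (illustrative): measuring a wall-clock interval with the
+// helpers above, as the benchmarking code in test.h does:
+//   TimePoint t0 = Now();
+//   // ... work ...
+//   const float elapsed_seconds = ToFloatSeconds(Now() - t0);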
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TIME_H_
diff --git a/ruy/trace.cc b/ruy/trace.cc
new file mode 100644
index 0000000..1822cdb
--- /dev/null
+++ b/ruy/trace.cc
@@ -0,0 +1,325 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/trace.h"
+
+#include <algorithm>
+#include <cerrno> // IWYU pragma: keep
+#include <cstdio>
+#include <cstdlib>
+#include <string>
+#include <vector>
+
+#include "ruy/check_macros.h"
+#include "ruy/side_pair.h"
+#include "ruy/time.h"
+
+namespace ruy {
+
+#ifdef RUY_TRACE
+
+enum class TraceEvent : std::uint8_t {
+ kNone,
+ kThreadStart,
+ kThreadLoopStart,
+ kThreadEnd,
+ kBlockReserved,
+ kBlockPackedLhs,
+ kBlockPackedRhs,
+ kBlockFinished
+};
+
+struct TraceEntry {
+ TimePoint time_point;
+ TraceEvent event;
+  // ruy-internal thread id, i.e. a contiguous index into the array of
+  // threads, with 0 designating the main thread.
+ std::uint16_t thread_id = 0;
+ // Additional parameters whose meaning depends on the 'event' type.
+ std::uint32_t params[1];
+};
+
+struct Trace {
+ BlockMap block_map;
+ // During recording, to avoid having to use locks or atomics, we let
+ // each thread append to its own specific vector.
+ std::vector<std::vector<TraceEntry>> thread_specific_entries;
+ // Global vector of entries into which we coalesce thread_specific_entries
+ // after recording is finished, when dumping a trace. See
+ // AggregateThreadSpecificEntries.
+ std::vector<TraceEntry> entries;
+ TimePoint time_start;
+ TimePoint time_execute;
+ TimePoint time_end;
+};
+
+namespace {
+
+// Coalesce Trace::thread_specific_entries into Trace::entries.
+void AggregateThreadSpecificEntries(Trace* trace) {
+ RUY_CHECK(trace->entries.empty());
+ for (auto& thread_specific_entries_vector : trace->thread_specific_entries) {
+ for (const TraceEntry& entry : thread_specific_entries_vector) {
+ trace->entries.push_back(entry);
+ }
+ thread_specific_entries_vector.clear();
+ }
+}
+
+// Sort Trace::entries by ascending time. In case of equal timepoints,
+// sort by some semi-arbitrary ordering of event types.
+void Sort(Trace* trace) {
+ std::sort(std::begin(trace->entries), std::end(trace->entries),
+ [](const TraceEntry& a, const TraceEntry& b) -> bool {
+ return a.time_point < b.time_point ||
+ (a.time_point == b.time_point &&
+ static_cast<int>(a.event) < static_cast<int>(b.event));
+ });
+}
+
+// Dump a trace. Assumes that AggregateThreadSpecificEntries and Sort have
+// already been called on it.
+//
+// On some architectures, long long int is not the same type as std::int64_t,
+// and time is printed with %lld, so static_casts are necessary.
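+//
+// For illustration, the emitted text format looks roughly like the following
+// (all values are made up):
+//   thread_count:4
+//   rows:256
+//   cols:256
+//   Execute: 12345
+//   ThreadStart: 13000, 0
+//   BlockReserved: 14000, 0, 0, 0, 0, 0, 0, 64, 64
+//   BlockPackedLhs: 15000, 0, 0, 0, 64
+//   BlockFinished: 20000, 0, 0, 0, 0
+//   ThreadEnd: 21000, 0
+//   End: 30000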
+void Dump(const Trace& trace) {
+ const char* trace_filename = getenv("RUY_TRACE_FILE");
+ FILE* trace_file = trace_filename ? fopen(trace_filename, "w") : stderr;
+ if (!trace_file) {
+ fprintf(stderr, "Failed to open %s for write, errno=%d\n", trace_filename,
+ errno);
+ RUY_CHECK(false);
+ }
+ fprintf(trace_file, "thread_count:%d\n", trace.block_map.thread_count);
+ fprintf(trace_file, "rows:%d\n", trace.block_map.dims[Side::kLhs]);
+ fprintf(trace_file, "cols:%d\n", trace.block_map.dims[Side::kRhs]);
+ fprintf(trace_file, "Execute: %lld\n",
+ static_cast<long long int>(
+ ToInt64Nanoseconds(trace.time_execute - trace.time_start)));
+ for (const TraceEntry& entry : trace.entries) {
+ long long int time = static_cast<long long int>(
+ ToInt64Nanoseconds(entry.time_point - trace.time_start));
+ switch (entry.event) {
+ case TraceEvent::kThreadStart:
+ fprintf(trace_file, "ThreadStart: %lld, %d\n", time, entry.thread_id);
+ break;
+ case TraceEvent::kThreadLoopStart:
+ fprintf(trace_file, "ThreadLoopStart: %lld, %d\n", time,
+ entry.thread_id);
+ break;
+ case TraceEvent::kThreadEnd:
+ fprintf(trace_file, "ThreadEnd: %lld, %d\n", time, entry.thread_id);
+ break;
+ case TraceEvent::kBlockReserved: {
+ std::uint32_t block_id = entry.params[0];
+ SidePair<int> block;
+ GetBlockByIndex(trace.block_map, block_id, &block);
+ SidePair<int> start, end;
+ GetBlockMatrixCoords(trace.block_map, block, &start, &end);
+ fprintf(trace_file,
+ "BlockReserved: %lld, %d, %d, %d, %d, %d, %d, %d, %d\n", time,
+ entry.thread_id, block_id, block[Side::kLhs], block[Side::kRhs],
+ start[Side::kLhs], start[Side::kRhs], end[Side::kLhs],
+ end[Side::kRhs]);
+ break;
+ }
+ case TraceEvent::kBlockPackedLhs: {
+ std::uint32_t block = entry.params[0];
+ int start, end;
+ GetBlockMatrixCoords(Side::kLhs, trace.block_map, block, &start, &end);
+ fprintf(trace_file, "BlockPackedLhs: %lld, %d, %d, %d, %d\n", time,
+ entry.thread_id, block, start, end);
+ break;
+ }
+ case TraceEvent::kBlockPackedRhs: {
+ std::uint32_t block = entry.params[0];
+ int start, end;
+ GetBlockMatrixCoords(Side::kRhs, trace.block_map, block, &start, &end);
+ fprintf(trace_file, "BlockPackedRhs: %lld, %d, %d, %d, %d\n", time,
+ entry.thread_id, block, start, end);
+ break;
+ }
+ case TraceEvent::kBlockFinished: {
+ std::uint32_t block_id = entry.params[0];
+ SidePair<int> block;
+ GetBlockByIndex(trace.block_map, block_id, &block);
+ fprintf(trace_file, "BlockFinished: %lld, %d, %d, %d, %d\n", time,
+ entry.thread_id, block_id, block[Side::kLhs],
+ block[Side::kRhs]);
+ break;
+ }
+ default:
+ RUY_CHECK(false);
+ }
+ }
+ fprintf(trace_file, "End: %lld\n",
+ static_cast<long long int>(
+ ToInt64Nanoseconds(trace.time_end - trace.time_start)));
+ if (trace_filename) {
+ fclose(trace_file);
+ }
+}
+
+} // anonymous namespace
+
+// Get a Trace object to record to, or null if tracing is not enabled.
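+//
+// Tracing is driven by environment variables. An illustrative invocation
+// (assuming the library was built with RUY_TRACE defined, and using a
+// hypothetical output path and binary name) might be:
+//
+//   RUY_TRACE=1 RUY_TRACE_FILE=/tmp/ruy_trace.txt \
+//   RUY_TRACE_FILTER_ROWS=256 RUY_TRACE_FILTER_DEPTH=256 \
+//   RUY_TRACE_FILTER_COLS=256 ./some_ruy_benchmark
+//
+// If RUY_TRACE_FILE is unset, the trace is dumped to stderr.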
+Trace* NewTraceOrNull(TracingContext* tracing, int rows, int depth, int cols) {
+ if (!tracing->initialized) {
+ tracing->initialized = true;
+ tracing->enabled = getenv("RUY_TRACE");
+ if (!tracing->enabled) {
+ return nullptr;
+ }
+ if (getenv("RUY_TRACE_FILTER_ROWS")) {
+ tracing->filter_shape_rows = std::stoi(getenv("RUY_TRACE_FILTER_ROWS"));
+ }
+ if (getenv("RUY_TRACE_FILTER_DEPTH")) {
+ tracing->filter_shape_depth = std::stoi(getenv("RUY_TRACE_FILTER_DEPTH"));
+ }
+ if (getenv("RUY_TRACE_FILTER_COLS")) {
+ tracing->filter_shape_cols = std::stoi(getenv("RUY_TRACE_FILTER_COLS"));
+ }
+ }
+ if (!tracing->enabled) {
+ return nullptr;
+ }
+ if (tracing->filter_shape_rows && rows != tracing->filter_shape_rows) {
+ return nullptr;
+ }
+ if (tracing->filter_shape_depth && depth != tracing->filter_shape_depth) {
+ return nullptr;
+ }
+ if (tracing->filter_shape_cols && cols != tracing->filter_shape_cols) {
+ return nullptr;
+ }
+ // Delete any existing trace.
+ delete tracing->trace;
+ // Create a new one.
+ tracing->trace = new Trace;
+ return tracing->trace;
+}
+
+// The trace recorded on a context is finalized and dumped by
+// this TracingContext destructor.
+//
+// The idea of dumping in the context destructor is that typically one wants to
+// run many matrix multiplications, e.g. to hit a steady state in terms of
+// performance characteristics, but only trace the last repetition of the
+// workload, when that steady state was attained.
+TracingContext::~TracingContext() {
+ if (trace) {
+ AggregateThreadSpecificEntries(trace);
+ Sort(trace);
+ Dump(*trace);
+ }
+ delete trace;
+}
+
+void TraceRecordStart(Trace* trace) {
+ if (trace) {
+ trace->time_start = Now();
+ }
+}
+
+void TraceRecordExecute(const BlockMap& block_map, Trace* trace) {
+ if (trace) {
+ trace->time_execute = Now();
+ trace->block_map = block_map;
+ trace->thread_specific_entries.resize(block_map.thread_count);
+ for (int thread = 0; thread < block_map.thread_count; thread++) {
+ trace->thread_specific_entries[thread].clear();
+ // Reserve some large size to avoid frequent heap allocations
+ // affecting the recorded timings.
+ trace->thread_specific_entries[thread].reserve(16384);
+ }
+ }
+}
+
+void TraceRecordEnd(Trace* trace) {
+ if (trace) {
+ trace->time_end = Now();
+ }
+}
+
+void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace) {
+ if (trace) {
+ TraceEntry entry;
+ entry.event = TraceEvent::kThreadStart;
+ entry.time_point = Now();
+ entry.thread_id = thread_id;
+ trace->thread_specific_entries[thread_id].push_back(entry);
+ }
+}
+
+void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace) {
+ if (trace) {
+ TraceEntry entry;
+ entry.event = TraceEvent::kThreadLoopStart;
+ entry.time_point = Now();
+ entry.thread_id = thread_id;
+ trace->thread_specific_entries[thread_id].push_back(entry);
+ }
+}
+
+void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id,
+ Trace* trace) {
+ if (trace) {
+ TraceEntry entry;
+ entry.event = TraceEvent::kBlockReserved;
+ entry.time_point = Now();
+ entry.thread_id = thread_id;
+ entry.params[0] = block_id;
+ trace->thread_specific_entries[thread_id].push_back(entry);
+ }
+}
+
+void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block,
+ Trace* trace) {
+ if (trace) {
+ TraceEntry entry;
+ entry.event = side == Side::kLhs ? TraceEvent::kBlockPackedLhs
+ : TraceEvent::kBlockPackedRhs;
+ entry.time_point = Now();
+ entry.thread_id = thread_id;
+ entry.params[0] = block;
+ trace->thread_specific_entries[thread_id].push_back(entry);
+ }
+}
+
+void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id,
+ Trace* trace) {
+ if (trace) {
+ TraceEntry entry;
+ entry.event = TraceEvent::kBlockFinished;
+ entry.time_point = Now();
+ entry.thread_id = thread_id;
+ entry.params[0] = block_id;
+ trace->thread_specific_entries[thread_id].push_back(entry);
+ }
+}
+
+void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace) {
+ if (trace) {
+ TraceEntry entry;
+ entry.event = TraceEvent::kThreadEnd;
+ entry.time_point = Now();
+ entry.thread_id = thread_id;
+ trace->thread_specific_entries[thread_id].push_back(entry);
+ }
+}
+
+#endif
+
+} // namespace ruy
diff --git a/ruy/trace.h b/ruy/trace.h
new file mode 100644
index 0000000..d2cc51d
--- /dev/null
+++ b/ruy/trace.h
@@ -0,0 +1,73 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRACE_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRACE_H_
+
+#include <cstdint>
+
+#include "ruy/block_map.h"
+#include "ruy/side_pair.h"
+
+namespace ruy {
+
+struct Trace;
+
+#ifdef RUY_TRACE
+
+struct TracingContext {
+ bool initialized = false;
+ bool enabled = false;
+ int filter_shape_rows = 0;
+ int filter_shape_cols = 0;
+ int filter_shape_depth = 0;
+ Trace* trace = nullptr;
+ ~TracingContext();
+};
+
+Trace* NewTraceOrNull(TracingContext* context, int rows, int depth, int cols);
+void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace);
+void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace);
+void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id,
+ Trace* trace);
+void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block,
+ Trace* trace);
+void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id,
+ Trace* trace);
+void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace);
+void TraceRecordStart(Trace* trace);
+void TraceRecordExecute(const BlockMap& block_map, Trace* trace);
+void TraceRecordEnd(Trace* trace);
+
+#else
+
+struct TracingContext {};
+
+inline Trace* NewTraceOrNull(TracingContext*, int, int, int) { return nullptr; }
+inline void TraceRecordThreadStart(std::uint32_t, Trace*) {}
+inline void TraceRecordThreadLoopStart(std::uint32_t, Trace*) {}
+inline void TraceRecordBlockReserved(std::uint32_t, std::uint32_t, Trace*) {}
+inline void TraceRecordBlockPacked(std::uint32_t, Side, int, Trace*) {}
+inline void TraceRecordBlockFinished(std::uint32_t, std::uint32_t, Trace*) {}
+inline void TraceRecordThreadEnd(std::uint32_t, Trace*) {}
+inline void TraceRecordStart(Trace*) {}
+inline void TraceRecordExecute(const BlockMap&, Trace*) {}
+inline void TraceRecordEnd(Trace*) {}
+
+#endif
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRACE_H_
diff --git a/ruy/trmul.cc b/ruy/trmul.cc
new file mode 100644
index 0000000..a3ba46a
--- /dev/null
+++ b/ruy/trmul.cc
@@ -0,0 +1,401 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/trmul.h"
+
+#include <atomic>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <vector>
+
+#include "ruy/allocator.h"
+#include "ruy/block_map.h"
+#include "ruy/check_macros.h"
+#include "ruy/common.h"
+#include "ruy/internal_matrix.h"
+#include "ruy/matrix.h"
+#include "ruy/opt_set.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/side_pair.h"
+#include "ruy/size_util.h"
+#include "ruy/spec.h"
+#include "ruy/thread_pool.h"
+#include "ruy/trace.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+namespace {
+
+enum class PackingStatus : std::uint8_t { kNotStarted, kInProgress, kFinished };
+
+struct TrMulTask final : Task {
+ TrMulTask(TrMulParams* params_, const BlockMap& block_map_,
+ std::atomic<int>* atomic_block_id_, int thread_id_,
+ bool need_atomics_,
+ SidePair<std::atomic<PackingStatus>*> packing_status_,
+ TuningResolver* tuning_resolver_, Allocator* local_allocator_,
+ Trace* trace_)
+ : params(params_),
+ block_map(block_map_),
+ atomic_block_id(atomic_block_id_),
+ thread_id(thread_id_),
+ need_atomics(need_atomics_),
+ packing_status(packing_status_),
+ tuning_resolver(tuning_resolver_),
+ local_allocator(local_allocator_),
+ trace(trace_),
+ local_packed{nullptr, nullptr} {}
+
+ void Run() override {
+ TraceRecordThreadStart(thread_id, trace);
+
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ if (!params->is_prepacked[side]) {
+ const int size = NumBlocksPerSide(side, block_map);
+ local_allocator->Allocate(size, &local_packed[side]);
+ memset(local_packed[side], 0, size * sizeof(bool));
+ }
+ }
+
+ const int num_blocks = NumBlocks(block_map);
+
+ const Tuning tuning = tuning_resolver->Resolve();
+
+ TraceRecordThreadLoopStart(thread_id, trace);
+
+ SidePair<int> block;
+ SidePair<int> start;
+ SidePair<int> end;
+
+ // Each thread starts by initially reserving the block whose id
+ // is the thread id.
+ int block_id = thread_id;
+ TraceRecordBlockReserved(thread_id, block_id, trace);
+
+ while (block_id < num_blocks) {
+ // Reserve the next block to handle. In order to hide the latency
+ // (typically comparable to an access to the level of data cache that
+ // is shared among CPU cores, e.g. 60 cycles on an ARM CPU as of 2019)
+ // of this atomic operation, we structure this code so as to avoid
+      // immediately depending on the `next_block_id` result.
+ const int next_block_id =
+ atomic_block_id->fetch_add(1, std::memory_order_relaxed);
+ TraceRecordBlockReserved(thread_id, next_block_id, trace);
+ // Get coordinates of the current block to handle, in "block space".
+ GetBlockByIndex(block_map, block_id, &block);
+ // Get coordinates of the current block to handle, in matrix space.
+ GetBlockMatrixCoords(block_map, block, &start, &end);
+ // Maybe pack the current LHS/RHS block, if not already packed.
+ EnsurePacked(block, start, end, tuning);
+ // Actually do matrix multiplication work
+ params->RunKernel(tuning, start, end);
+ TraceRecordBlockFinished(thread_id, block_id, trace);
+ // Move on to the next block as obtained by the atomic increment
+ // at the start of this while loop iteration.
+ block_id = next_block_id;
+ }
+
+ local_allocator->FreeAll();
+
+ TraceRecordThreadEnd(thread_id, trace);
+ }
+
+ private:
+ // Tries to pack a block, without blocking.
+ // If the block was already packed, returns true.
+  // If packing of the block had not yet started, packs it and returns true.
+ // If the block was being packed by another thread, returns false.
+ bool TryPack(Side side, int block, int start, int end, Tuning tuning) {
+ if (params->is_prepacked[side]) {
+ return true;
+ }
+ if (!local_packed[side][block]) {
+ if (need_atomics) {
+ // Explanation of this compare_exchange_strong operation:
+ // This atomically performs all of the following:
+ // 1. Read `status` with "acquire" memory order.
+ // * That this read uses "acquire" is because both memory orders
+ // specified have "acquire" as their read-component.
+ // 2. Compare (bitwise) with `exchanged_status`.
+ // 3. If equal, stores the value kInProgress to `status` with "release"
+ // memory order, and returns true, so we take this 'if' branch.
+ // * That this store uses "release" is because of the _rel part in
+ // memory_order_acq_rel passed as the first memory order argument.
+ // 4. If not equal, stores the loaded value of `status` to
+ // `exchanged_status` with "relaxed" semantics, and returns false,
+ // so we take the 'else' branch.
+ // * That this store uses "relaxed" is because the second memory
+ // order argument, memory_order_acquire, implies no particular
+ // store semantics. "relaxed" is acceptable here because this
+ // stores to a local stack variable.
+ //
+ // Rationale for compare_exchange_strong as opposed to
+ // compare_exchange_weak:
+ // The spurious-failure case with compare_exchange_weak will actually
+ // happen a lot here, because the atomic 'status' bytes are stored
+ // contiguously in arrays and neighboring values will be accessed
+ // by multiple threads concurrently. On a typical ARM CPU, an exclusives
+ // reservation granule is 64 bytes, so a lot of false-sharing may
+ // happen. Using compare_exchange_weak would thus result in often having
+ // TryPack return 'false' when it could instead have done the packing
+ // work and returned 'true'. Heuristically, that is not a good thing.
+ // Moreover, this changes the TryPack contract, loosening it and making
+ // it harder for the caller to reason about. Finally, the overhead of
+ // atomic operations is mitigated by the enclosing check on
+ // local_packed, so maybe the overhead of compare_exchange_strong isn't
+        // such a problem. But we don't really know for sure; that would be
+        // interesting to experiment with further.
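+        //
+        // As a rough mental model, the compare_exchange_strong call below
+        // behaves like this illustrative pseudocode, executed atomically
+        // (memory-ordering details omitted):
+        //   if (status == exchanged_status) {
+        //     status = PackingStatus::kInProgress;  // and return true
+        //   } else {
+        //     exchanged_status = status;            // and return false
+        //   }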
+ PackingStatus exchanged_status = PackingStatus::kNotStarted;
+ std::atomic<PackingStatus>& status = packing_status[side][block];
+ if (status.compare_exchange_strong(
+ exchanged_status, PackingStatus::kInProgress,
+ std::memory_order_acq_rel, std::memory_order_acquire)) {
+ // In this branch, the status was kNotStarted and we just atomically
+ // changed it to kInProgress as we are about to handle the packing
+ // ourselves.
+ params->RunPack(side, tuning, start, end);
+ TraceRecordBlockPacked(thread_id, side, block, trace);
+ status.store(PackingStatus::kFinished, std::memory_order_release);
+ } else if (exchanged_status == PackingStatus::kInProgress) {
+ // Another thread is currently packing this block.
+ return false;
+ }
+ RUY_DCHECK(status.load(std::memory_order_acquire) ==
+ PackingStatus::kFinished);
+ } else {
+ // Single-threaded case: no need for expensive atomics, local_packed
+ // is the truth already.
+ params->RunPack(side, tuning, start, end);
+ TraceRecordBlockPacked(thread_id, side, block, trace);
+ }
+ local_packed[side][block] = true;
+ }
+ return true;
+ }
+
+ // Ensures that both the LHS and RHS blocks required by the specified block
+  // are packed. In the event that they are already being packed by another
+  // thread, this function may perform the packing of some other block while
+ // waiting for that other thread to finish packing the requested block.
+ void EnsurePacked(const SidePair<int>& block, const SidePair<int>& start,
+ const SidePair<int>& end, Tuning tuning) {
+#if RUY_OPT_ENABLED(RUY_OPT_PACK_AHEAD)
+ SidePair<int> next_runahead_block{block[Side::kLhs] + 1,
+ block[Side::kRhs] + 1};
+ Side next_runahead_side = Side::kLhs;
+#endif
+ while (true) {
+ bool both_sides_packed = true;
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ both_sides_packed &=
+ TryPack(side, block[side], start[side], end[side], tuning);
+ }
+ if (both_sides_packed) {
+ break;
+ }
+#if RUY_OPT_ENABLED(RUY_OPT_PACK_AHEAD)
+ const Side runahead_side = next_runahead_side;
+ const int runahead_block = next_runahead_block[runahead_side];
+ next_runahead_side =
+ next_runahead_side == Side::kLhs ? Side::kRhs : Side::kLhs;
+ if (runahead_block >= NumBlocksPerSide(runahead_side, block_map)) {
+ continue;
+ }
+ int runahead_block_start, runahead_block_end;
+ GetBlockMatrixCoords(runahead_side, block_map, runahead_block,
+ &runahead_block_start, &runahead_block_end);
+ TryPack(runahead_side, runahead_block, runahead_block_start,
+ runahead_block_end, tuning);
+ next_runahead_block[runahead_side] = runahead_block + 1;
+#endif
+ }
+ }
+
+ TrMulParams* params;
+ const BlockMap& block_map;
+ std::atomic<int>* atomic_block_id;
+ int thread_id;
+ bool need_atomics;
+ SidePair<std::atomic<PackingStatus>*> packing_status;
+ TuningResolver* tuning_resolver;
+ Allocator* local_allocator;
+ Trace* trace;
+
+ // Local indicators of packedness to avoid the overhead of atomic ops.
+ SidePair<bool*> local_packed;
+};
+
+void AllocatePMatrix(Allocator* allocator, PMatrix* packed) {
+ packed->data = allocator->AllocateBytes(DataSize(*packed));
+ packed->sums = allocator->AllocateBytes(SumsSize(*packed));
+}
+
+int GetThreadCount(Context* context, int rows, int cols, int depth) {
+#if RUY_PLATFORM(EMSCRIPTEN)
+ // b/139927184, std::thread constructor raises exception
+ return 1;
+#endif
+  // Empirically determined rule for a reasonable number of
+ // threads to use. This is proportional to the number of arithmetic ops
+ // in this Mul (product of the 3 sizes).
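+  //
+  // For example (purely illustrative numbers): with rows = cols = depth = 512,
+  // each ceil_log2 term below is 9, so guess_log2 = max(0, 27 - 15) = 12 and
+  // the guess is 1 << 12 = 4096 threads, which is then capped by
+  // context->max_num_threads.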
+ static constexpr int kDivisorLog2 = 15;
+ const int guess_log2 = std::max(
+ 0, ceil_log2(rows) + ceil_log2(cols) + ceil_log2(depth) - kDivisorLog2);
+ return std::min(1 << guess_log2, context->max_num_threads);
+}
+
+LoopStructure GetLoopStructure(int tentative_thread_count, int rows, int cols,
+ int depth, int lhs_scalar_size,
+ int rhs_scalar_size, int local_data_cache_size,
+ int shared_data_cache_size) {
+ if (tentative_thread_count == 1) {
+ const BlockMapTraversalOrder traversal_order =
+ GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size,
+ local_data_cache_size, shared_data_cache_size);
+ // If we are in the GEMV case or the block_map would be using linear
+ // traversal anyway, use the simple loop.
+ if ((cols == 1) || traversal_order == BlockMapTraversalOrder::kLinear) {
+ return LoopStructure::kSimple;
+ }
+ }
+ return LoopStructure::kGeneral;
+}
+
+} // namespace
+
+void TrMul(TrMulParams* params, Context* context) {
+ profiler::ScopeLabel label(
+ "TrMul (Path=0x%x, max_num_threads=%d, is_prepacked=(%d,%d))",
+ static_cast<int>(params->path), context->max_num_threads,
+ params->is_prepacked[Side::kLhs], params->is_prepacked[Side::kRhs]);
+
+ PMatrix& packed_lhs = params->packed[Side::kLhs];
+ PMatrix& packed_rhs = params->packed[Side::kRhs];
+ DMatrix& lhs = params->src[Side::kLhs];
+ DMatrix& rhs = params->src[Side::kRhs];
+
+ const int rows = lhs.layout.cols;
+ const int cols = rhs.layout.cols;
+ const int depth = lhs.layout.rows;
+
+ const int tentative_thread_count = GetThreadCount(context, rows, cols, depth);
+ const auto loop_structure = GetLoopStructure(
+ tentative_thread_count, rows, cols, depth, lhs.data_type.size,
+ rhs.data_type.size, params->local_data_cache_size,
+ params->shared_data_cache_size);
+ Allocator* allocator = context->GetMainAllocator();
+
+ // Allocate packed matrices
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ if (!params->is_prepacked[side]) {
+ AllocatePMatrix(allocator, &params->packed[side]);
+ }
+ }
+
+ // Case of running this TrMul as a simple loop.
+ // This is a good place to start reading this function: all the rest
+ // of this function is just an optimized, but functionally equivalent,
+ // version of that.
+ if (loop_structure == LoopStructure::kSimple) {
+ profiler::ScopeLabel label_simple("TrMulImpl, simple loop");
+ Tuning tuning = context->GetMainThreadTuning();
+
+ const SidePair<int> origin{0, 0};
+ const SidePair<int> rounded_dims{packed_lhs.layout.cols,
+ packed_rhs.layout.cols};
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ if (!params->is_prepacked[side]) {
+ params->RunPack(side, tuning, origin[side], rounded_dims[side]);
+ }
+ }
+ params->RunKernel(tuning, origin, rounded_dims);
+
+ allocator->FreeAll();
+ return;
+ }
+
+ profiler::ScopeLabel label_general("TrMulImpl, general case");
+
+ auto* trace = NewTraceOrNull(&context->tracing, rows, depth, cols);
+ TraceRecordStart(trace);
+
+ // Initialize block map.
+ BlockMap block_map;
+ MakeBlockMap(packed_lhs.layout.cols, packed_rhs.layout.cols, depth,
+ packed_lhs.layout.kernel.cols, packed_rhs.layout.kernel.cols,
+ packed_lhs.data_type.size, packed_rhs.data_type.size,
+ tentative_thread_count, params->path,
+ params->local_data_cache_size, params->shared_data_cache_size,
+ &block_map);
+
+ // Initialize per-thread state.
+ const int thread_count = block_map.thread_count;
+ const bool need_atomics = thread_count > 1;
+ context->EnsureNPerThreadStates(thread_count);
+ for (auto& per_thread_state : context->per_thread_states) {
+ per_thread_state->tuning_resolver.SetTuning(context->explicit_tuning);
+ }
+
+ // In the need_atomics case, allocate and initialize atomic values tracking
+ // the packing status of blocks.
+ SidePair<std::atomic<PackingStatus>*> packing_status{nullptr, nullptr};
+ if (need_atomics) {
+ for (Side side : {Side::kLhs, Side::kRhs}) {
+ if (!params->is_prepacked[side]) {
+ const int size = NumBlocksPerSide(side, block_map);
+ allocator->Allocate(size, &packing_status[side]);
+ for (int i = 0; i < size; i++) {
+ packing_status[side][i].store(PackingStatus::kNotStarted,
+ std::memory_order_relaxed);
+ }
+ }
+ }
+ }
+
+  // Create the atomic block id. Allocate it using the Allocator so that
+ // we get the alignment ensuring that it sits alone in its exclusives
+ // reservation granule.
+ std::atomic<int>* atomic_block_id;
+ allocator->Allocate(1, &atomic_block_id);
+
+ // Create task objects.
+ TrMulTask* tasks;
+ allocator->Allocate(thread_count, &tasks);
+
+ atomic_block_id->store(thread_count);
+
+ for (int i = 0; i < thread_count; i++) {
+ new (tasks + i) TrMulTask(params, block_map, atomic_block_id, i,
+ need_atomics, packing_status,
+ &context->per_thread_states[i]->tuning_resolver,
+ &context->per_thread_states[i]->allocator, trace);
+ }
+
+ // Do the computation.
+ TraceRecordExecute(block_map, trace);
+ context->workers_pool.Execute(thread_count, tasks);
+
+ // Finish up.
+ for (int i = 0; i < thread_count; i++) {
+ tasks[i].~TrMulTask();
+ }
+
+ allocator->FreeAll();
+ TraceRecordEnd(trace);
+}
+
+} // namespace ruy
diff --git a/ruy/trmul.h b/ruy/trmul.h
new file mode 100644
index 0000000..f50bb0c
--- /dev/null
+++ b/ruy/trmul.h
@@ -0,0 +1,38 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// As a matrix multiplication library, Ruy offers a Mul entry point, performing
+// matrix multiplication. For implementation purposes, it is much nicer to
+// be dealing with the transpose-and-multiply operation, doing
+// Destination = Transpose(LHS) * RHS
+// Indeed, the latter is performing dot-products between the *columns* of LHS
+// and the columns of RHS, whereas a plain matrix multiplication is performing
+// dot-products between the *rows* of LHS and the columns of RHS.
+// That is why TrMul is nicer to implement, allowing for a more symmetric
+// treatment of LHS and RHS.
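+//
+// For example (illustrative): a user-level Mul of a (rows x depth) LHS by a
+// (depth x cols) RHS is handled here as a TrMul taking the transposed LHS,
+// which is why, in trmul.cc, `rows` is read from lhs.layout.cols and `depth`
+// from lhs.layout.rows.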
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_H_
+
+#include "ruy/context.h"
+#include "ruy/trmul_params.h"
+
+namespace ruy {
+
+void TrMul(TrMulParams* params, Context* context);
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_H_
diff --git a/ruy/trmul_params.h b/ruy/trmul_params.h
new file mode 100644
index 0000000..47537b7
--- /dev/null
+++ b/ruy/trmul_params.h
@@ -0,0 +1,67 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_PARAMS_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_PARAMS_H_
+
+#include "ruy/internal_matrix.h"
+#include "ruy/side_pair.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+using RunKernelFn = void(Tuning, const SidePair<PMatrix>&, void*,
+ const SidePair<int>&, const SidePair<int>&, DMatrix*);
+
+using RunPackFn = void(Tuning, const DMatrix&, PMatrix*, int, int);
+
+// Type-erased data needed for implementing TrMul.
+struct TrMulParams {
+ TrMulParams() : run_pack{nullptr, nullptr}, is_prepacked{false, false} {}
+ // Helper functions for invoking the function pointers.
+ void RunPack(Side side, Tuning tuning, int start, int end) {
+ run_pack[side](tuning, src[side], &packed[side], start, end);
+ }
+ void RunKernel(Tuning tuning, const SidePair<int>& start,
+ const SidePair<int>& end) {
+ run_kernel(tuning, packed, spec, start, end, &dst);
+ }
+
+  // Path id; can be useful for some fine-tuning, e.g. to guess reasonable
+  // cache sizes when they are not runtime-detectable.
+ Path path;
+
+ // See Spec::local_data_cache_size().
+ int local_data_cache_size = 0;
+ // See Spec::shared_data_cache_size().
+ int shared_data_cache_size = 0;
+
+ // Function pointers to type-erased entry points for kernels and packers.
+ SidePair<RunPackFn*> run_pack;
+ RunKernelFn* run_kernel = nullptr;
+
+ // Matrices and packed matrices.
+ SidePair<DMatrix> src;
+ DMatrix dst;
+ SidePair<PMatrix> packed;
+ SidePair<bool> is_prepacked;
+
+ // Type-erased Spec.
+ void* spec = nullptr;
+};
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_PARAMS_H_
diff --git a/ruy/tune.cc b/ruy/tune.cc
new file mode 100644
index 0000000..a89242f
--- /dev/null
+++ b/ruy/tune.cc
@@ -0,0 +1,161 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/tune.h"
+
+#include <algorithm>
+#include <cstdint>
+
+namespace ruy {
+
+#ifdef RUY_IMPLEMENT_TUNING
+
+namespace {
+
+void PoorlyOrderedKernel(int iters) {
+ asm volatile(
+ "mov w0, %w[iters]\n"
+ "1:\n"
+ "subs w0, w0, #1\n"
+ "mul v0.4s, v0.4s, v0.4s\n"
+ "mul v0.4s, v0.4s, v0.4s\n"
+ "mul v0.4s, v0.4s, v0.4s\n"
+ "mul v0.4s, v0.4s, v0.4s\n"
+ "mul v1.4s, v1.4s, v1.4s\n"
+ "mul v1.4s, v1.4s, v1.4s\n"
+ "mul v1.4s, v1.4s, v1.4s\n"
+ "mul v1.4s, v1.4s, v1.4s\n"
+ "mul v2.4s, v2.4s, v2.4s\n"
+ "mul v2.4s, v2.4s, v2.4s\n"
+ "mul v2.4s, v2.4s, v2.4s\n"
+ "mul v2.4s, v2.4s, v2.4s\n"
+ "mul v3.4s, v3.4s, v3.4s\n"
+ "mul v3.4s, v3.4s, v3.4s\n"
+ "mul v3.4s, v3.4s, v3.4s\n"
+ "mul v3.4s, v3.4s, v3.4s\n"
+ "bne 1b\n" ::[iters] "r"(iters)
+ : "cc", "x0", "v0", "v1", "v2", "v3");
+}
+
+void NicelyOrderedKernel(int iters) {
+ asm volatile(
+ "mov w0, %w[iters]\n"
+ "1:\n"
+ "subs w0, w0, #1\n"
+ "mul v0.4s, v0.4s, v0.4s\n"
+ "mul v1.4s, v1.4s, v1.4s\n"
+ "mul v2.4s, v2.4s, v2.4s\n"
+ "mul v3.4s, v3.4s, v3.4s\n"
+ "mul v0.4s, v0.4s, v0.4s\n"
+ "mul v1.4s, v1.4s, v1.4s\n"
+ "mul v2.4s, v2.4s, v2.4s\n"
+ "mul v3.4s, v3.4s, v3.4s\n"
+ "mul v0.4s, v0.4s, v0.4s\n"
+ "mul v1.4s, v1.4s, v1.4s\n"
+ "mul v2.4s, v2.4s, v2.4s\n"
+ "mul v3.4s, v3.4s, v3.4s\n"
+ "mul v0.4s, v0.4s, v0.4s\n"
+ "mul v1.4s, v1.4s, v1.4s\n"
+ "mul v2.4s, v2.4s, v2.4s\n"
+ "mul v3.4s, v3.4s, v3.4s\n"
+ "bne 1b\n" ::[iters] "r"(iters)
+ : "cc", "x0", "v0", "v1", "v2", "v3");
+}
+
+} // namespace
+
+float TuningResolver::EvalRatio() {
+ // With the current settings, 400 iterations and 4 repeats, this test has
+ // a latency of roughly 80 microseconds on a Cortex-A53 at 1.4 GHz.
+ static constexpr int kLoopIters = 400;
+ static constexpr int kRepeats = 4;
+
+ Duration timing_poorly_ordered = Duration::max();
+ Duration timing_nicely_ordered = Duration::max();
+
+ for (int r = 0; r < kRepeats; r++) {
+ TimePoint t0 = Now();
+ PoorlyOrderedKernel(kLoopIters);
+ TimePoint t1 = Now();
+ NicelyOrderedKernel(kLoopIters);
+ TimePoint t2 = Now();
+ timing_poorly_ordered = std::min(timing_poorly_ordered, t1 - t0);
+ timing_nicely_ordered = std::min(timing_nicely_ordered, t2 - t1);
+ }
+
+ return ToFloatSeconds(timing_nicely_ordered) /
+ ToFloatSeconds(timing_poorly_ordered);
+}
+
+float TuningResolver::ThresholdRatio() {
+ // Empirically (see :tune_tool) determined threshold to distinguish in-order
+ // Cortex-A53/A55 cores from out-of-order Cortex-A57/A73/A75/A76 cores. Based
+  // on these experimental results, which were obtained with different settings
+  // (kLoopIters=1000, kRepeats=1) so as to make them resilient to noise, we
+ // have:
+ //
+ // CPU core type | in/out of order | observed ratio
+ // --------------+-----------------+-----------------------------------------
+ // Cortex-A53 | in-order | 0.32 -- 0.329
+ // Cortex-A55 | in-order | 0.319 -- 0.325
+ // Cortex-A55r1 | in-order | 0.319 -- 0.325
+ // Cortex-A57 | out-of-order | 0.99 -- 1.01
+ // Cortex-A73 | out-of-order | 0.922 -- 0.927
+ // Cortex-A75 | out-of-order | 0.921 -- 0.93
+ // Cortex-A76 | out-of-order | 1
+ // Kryo (pixel1) | out-of-order | 0.73 -- 0.76
+ //
+ // Thus the allowable range for the threshold is [0.35 .. 0.70].
+ // We pick a value closer to the upper bound because really any out-of-order
+ // CPU should by definition produce a ratio close to 1.
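+  //
+  // For example (illustrative, using the observed values above): a measured
+  // ratio around 0.32 (Cortex-A53-like) falls below 0.65 and is classified
+  // as in-order, while a ratio around 0.92 (Cortex-A73-like) falls above it
+  // and is classified as out-of-order.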
+ return 0.65f;
+}
+
+Tuning TuningResolver::ResolveNow() {
+ const bool is_probably_inorder = EvalRatio() < ThresholdRatio();
+ return is_probably_inorder ? Tuning::kInOrder : Tuning::kOutOfOrder;
+}
+
+#else // not defined RUY_IMPLEMENT_TUNING
+
+float TuningResolver::EvalRatio() { return 0; }
+float TuningResolver::ThresholdRatio() { return 0; }
+
+Tuning TuningResolver::ResolveNow() { return Tuning::kOutOfOrder; }
+
+#endif
+
+TuningResolver::TuningResolver()
+ : expiry_duration_(DurationFromMilliseconds(250)) {}
+
+Tuning TuningResolver::Resolve() {
+#ifdef RUY_IMPLEMENT_TUNING
+ if (unresolved_tuning_ != Tuning::kAuto) {
+ return unresolved_tuning_;
+ }
+ TimePoint new_timepoint = CoarseNow();
+ if (last_resolved_tuning_ != Tuning::kAuto &&
+ (new_timepoint - last_resolved_timepoint_) < expiry_duration_) {
+ return last_resolved_tuning_;
+ }
+ last_resolved_timepoint_ = new_timepoint;
+ last_resolved_tuning_ = ResolveNow();
+ return last_resolved_tuning_;
+#else
+ return Tuning::kOutOfOrder;
+#endif
+}
+
+} // namespace ruy
diff --git a/ruy/tune.h b/ruy/tune.h
new file mode 100644
index 0000000..e6a0ee8
--- /dev/null
+++ b/ruy/tune.h
@@ -0,0 +1,163 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Library doing minimal CPU detection to decide what to tune asm code for.
+//
+// # Tuning vs Path
+//
+// Tunings are merely local variations of optimized code paths that are
+// drop-in replacements for each other --- the input and output data layouts
+// are identical. By contrast, what ruy calls a Path dictates its own
+// data layouts. For example, Path::kNeonDotprod will use different
+// layouts compared to Path::kNeon; but within each, different tunings
+// will share that same layout.
+//
+// # Tuning is for now only based on 1 bit: OutOfOrder / InOrder
+//
+// In practice, each of our asm code paths only needs one bit of information to
+// decide on tuning: whether the CPU is out-of-order or in-order.
+// That is because out-of-order CPUs are by definition relatively insensitive
+// to small-scale asm details (which is what "tuning" is about); and for each
+// asm code path, there tends to be one main in-order CPU architecture that
+// we focus our tuning effort on. Examples:
+// * For Path::kNeon, the main in-order CPU is Cortex-A53/A55 (pre-dotprod)
+// * For Path::kNeonDotprod, the main in-order CPU is Cortex-A55r1 (dotprod)
+//
+// Because having tuned code paths is a compromise of efficiency gains
+// versus implementation effort and code size, we are happy to stop at just this
+// single bit of information, OutOfOrder/InOrder, at least in the current CPU
+// landscape. This could change in the future.
+//
+// # Implementation notes and alternatives.
+//
+// The current implementation uses a nano-benchmark, see tune.cc.
+// That is why it's quite expensive, making caching /
+// statefulness necessary (see TuningResolver class comment).
+//
+// An interesting alternative, which was explained to us by Marat Dukhan
+// (maratek@) after this was implemented, would be to use the
+// getcpu(2) system call on Linux. This returns a
+// numeric CPU identifier that could be mapped to an OutOfOrder/InOrder
+// classification given additional information about the CPU. Such
+// additional information could be obtained by the cpuinfo library,
+// https://github.com/pytorch/cpuinfo
+// which obtains this information mainly from parsing /proc/cpuinfo.
+// Pros:
+// * Would remove the need for the relatively expensive nano-benchmark
+// (dozens of microseconds, which have to be reevaluated again several
+// times per second).
+// * Would conceivably be more reliable.
+// Cons:
+// * Linux-specific.
+// * Modest binary size increase (Marat mentioned the cpuinfo lib is 20k).
+// * Won't support exactly 100% of devices (nonstandard /proc/cpuinfo etc).
+//
+// We could also have both:
+// * Maybe by trying getcpu first if supported, then falling back to a
+// nano-benchmark.
+// * Maybe using getcpu in conjunction with the nano-benchmark to cache
+// per-CPU-id nano-benchmark results.
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TUNE_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TUNE_H_
+
+#include "ruy/opt_set.h"
+#include "ruy/platform.h"
+#include "ruy/time.h"
+
+// Tuning only implemented on NEON_64 at the moment (see assembly code
+// in the nano-benchmark) and not on Apple (some Apple CPUs produce incorrect
+// results on in-order-tuned kernels combining ARM and NEON load instructions
+// and NEON `ins` instructions).
+//
+// When tuning is not implemented, we simply always use Tuning::kOutOfOrder.
+#if RUY_OPT_ENABLED(RUY_OPT_TUNING) && RUY_PLATFORM(NEON_64) && \
+ !RUY_PLATFORM(APPLE)
+#define RUY_IMPLEMENT_TUNING
+#endif
+
+namespace ruy {
+
+enum class Tuning {
+ // kAuto means please use auto-detection. It's the default in the
+ // user-visible parts (see Context). It's meant to be resolved to an
+ // actual tuning at some point by means of TuningResolver.
+ kAuto,
+  // Target an out-of-order CPU. Example: ARM Cortex-A75.
+ kOutOfOrder,
+ // Target an in-order CPU. Example: ARM Cortex-A55.
+ kInOrder
+};
+
+// Why a TuningResolver class?
+//
+// Ideally, this library would offer a single function,
+// Tuning GetCurrentCPUTuning();
+//
+// However, determining information about the current CPU is not necessarily
+// cheap, so we currently cache that and only invalidate/reevaluate after
+// a fixed amount of time. This need to store state is why this library
+// has to expose a class, TuningResolver, not just a function.
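+//
+// Illustrative usage sketch (ruy drives this internally; the 250 ms figure is
+// the expiry duration set in tune.cc):
+//
+//   TuningResolver resolver;
+//   resolver.SetTuning(Tuning::kAuto);  // the default: auto-detect
+//   Tuning t1 = resolver.Resolve();     // may run the nano-benchmark, where
+//                                       // tuning is implemented
+//   Tuning t2 = resolver.Resolve();     // served from the ~250 ms cache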
+class TuningResolver {
+ public:
+ TuningResolver();
+
+ // Allows the user to specify an explicit Tuning value, bypassing auto
+ // detection; or to specify Tuning::kAuto, reverting to auto detection.
+ void SetTuning(Tuning tuning) { unresolved_tuning_ = tuning; }
+
+ // Get an actual tuning --- that is the function that this class wanted to be.
+ Tuning Resolve();
+
+ private:
+ TuningResolver(const TuningResolver&) = delete;
+
+  // TuneTool is a demo/tool used to tweak the tuning implementation to
+ // specific devices. It needs to access some finer granularity information
+ // than just the Tuning returned by Resolve. Nothing else should need
+ // access to that.
+ friend class TuneTool;
+ // Actually runs a nano-benchmark, producing a real number called 'ratio'
+ // whose meaning is generally opaque / implementation defined. Typically,
+ // this would be the ratio between the latencies of two different
+ // pieces of asm code differing only by the ordering of instructions,
+ // revealing whether the CPU cares about such ordering details.
+ // An implementation may just return a dummy value if it is not based on
+ // such nanobenchmarking / ratio evaluation.
+ float EvalRatio();
+ // Empirically determined threshold on ratio values delineating
+ // out-of-order (ratios closer to 1) from in-order (ratios farther from 1).
+ // An implementation may just return a dummy value if it is not based on
+ // such nanobenchmarking / ratio evaluation.
+ float ThresholdRatio();
+ // Perform the tuning resolution now. That may typically use EvalRatio and
+ // ThresholdRatio, but an implementation may use a different approach instead.
+ Tuning ResolveNow();
+
+ // The tuning as specified by the user, before actual resolution happens
+ // i.e. before querying any specifics of the current CPU.
+ // The default value kAuto means try to auto-detect. Other values mean
+ // bypass auto-detect, use explicit value instead. See SetTuning().
+ Tuning unresolved_tuning_ = Tuning::kAuto;
+ // Cached last resolved tuning.
+ Tuning last_resolved_tuning_ = Tuning::kAuto;
+ // Timepoint of cached last resolved tuning, for invalidation purposes.
+ TimePoint last_resolved_timepoint_;
+ // Cached last resolved tunings that are older than this age are invalid.
+ const Duration expiry_duration_;
+};
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TUNE_H_
diff --git a/ruy/tune_test.cc b/ruy/tune_test.cc
new file mode 100644
index 0000000..ebd86e0
--- /dev/null
+++ b/ruy/tune_test.cc
@@ -0,0 +1,53 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/tune.h"
+
+#include <chrono> // NOLINT(build/c++11)
+#include <thread> // NOLINT(build/c++11)
+
+#include "testing/base/public/gunit.h"
+
+namespace ruy {
+namespace {
+
+TEST(TuneTest, TuneTest) {
+ TuningResolver tuning_resolver;
+ ASSERT_FALSE(tuning_resolver.Resolve() == Tuning::kAuto);
+ // 1 second is likely higher than TuningResolver's internal cache expiry,
+ // exercising the logic invalidating earlier tuning resolutions.
+ std::this_thread::sleep_for(std::chrono::seconds(1));
+ ASSERT_FALSE(tuning_resolver.Resolve() == Tuning::kAuto);
+
+ tuning_resolver.SetTuning(Tuning::kAuto);
+
+#ifdef RUY_IMPLEMENT_TUNING
+ for (auto tuning : {Tuning::kOutOfOrder, Tuning::kInOrder}) {
+ tuning_resolver.SetTuning(tuning);
+ ASSERT_TRUE(tuning_resolver.Resolve() == tuning);
+ // See above comment about 1 second.
+ std::this_thread::sleep_for(std::chrono::seconds(1));
+ ASSERT_TRUE(tuning_resolver.Resolve() == tuning);
+ }
+#endif
+}
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/tune_tool.cc b/ruy/tune_tool.cc
new file mode 100644
index 0000000..0b6e4ab
--- /dev/null
+++ b/ruy/tune_tool.cc
@@ -0,0 +1,56 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Self-contained tool used to tune the tune code --- see the
+// threshold ratios used in tune.cc.
+
+#include <chrono> // NOLINT(build/c++11)
+#include <cstdio>
+#include <thread> // NOLINT(build/c++11)
+
+#include "ruy/tune.h"
+
+#ifdef _WIN32
+#define getpid() 0
+#else
+#include <unistd.h>
+#endif
+
+namespace ruy {
+
+class TuneTool {
+ public:
+ static void Query(float* eval, float* threshold) {
+ TuningResolver resolver;
+ *eval = resolver.EvalRatio();
+ *threshold = resolver.ThresholdRatio();
+ }
+};
+
+} // namespace ruy
+
+int main() {
+ // Infinite loop: the user can hit Ctrl-C
+ while (true) {
+ float eval;
+ float threshold;
+ ruy::TuneTool::Query(&eval, &threshold);
+ printf("[%d] eval=%.3f %c threshold=%.3f ==> probably %s...\n", getpid(),
+ eval, eval < threshold ? '<' : '>', threshold,
+ eval < threshold ? "in-order" : "out-of-order");
+ fflush(stdout);
+ std::this_thread::sleep_for(std::chrono::seconds(1));
+ }
+}
diff --git a/ruy/wait.cc b/ruy/wait.cc
new file mode 100644
index 0000000..d8156bc
--- /dev/null
+++ b/ruy/wait.cc
@@ -0,0 +1,69 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/wait.h"
+
+#include <chrono> // NOLINT(build/c++11)
+
+namespace ruy {
+
+void Wait(const std::function<bool()>& condition, const Duration& spin_duration,
+ std::condition_variable* condvar, std::mutex* mutex) {
+  // First, handle the trivial case where the `condition` is already true.
+ if (condition()) {
+ return;
+ }
+
+ // Then try busy-waiting.
+ const TimePoint wait_start = Now();
+ while (Now() - wait_start < spin_duration) {
+ if (condition()) {
+ return;
+ }
+ }
+
+ // Finally, do real passive waiting.
+ std::unique_lock<std::mutex> lock(*mutex);
+ condvar->wait(lock, condition);
+}
+
+void Wait(const std::function<bool()>& condition,
+ std::condition_variable* condvar, std::mutex* mutex) {
+  // This value was empirically derived with some microbenchmark; we don't
+  // have high confidence in it.
+ //
+ // TODO(b/135595069): make this value configurable at runtime.
+ // I almost wanted to file another bug to ask for experimenting in a more
+ // principled way to tune this value better, but this would have to be tuned
+ // on real end-to-end applications and we'd expect different applications to
+ // require different tunings. So the more important point is the need for
+ // this to be controllable by the application.
+ //
+ // That this value means that we may be sleeping substantially longer
+ // than a scheduler timeslice's duration is not necessarily surprising. The
+  // idea is to quickly pick up new work after having finished the previous
+  // workload. When it's new work within the same GEMM as the previous work,
+  // the time interval that we might be busy-waiting is very small, so for
+  // that purpose it would be more than enough to sleep for 1 ms.
+  // That is all we would observe on a GEMM benchmark. However, in a real
+ // application, after having finished a GEMM, we might do unrelated work for
+ // a little while, then start on a new GEMM. In that case the wait interval
+ // may be a little longer. There may also not be another GEMM for a long time,
+ // in which case we'll end up passively waiting below.
+ const Duration spin_duration = DurationFromMilliseconds(2);
+ Wait(condition, spin_duration, condvar, mutex);
+}
+
+} // namespace ruy
diff --git a/ruy/wait.h b/ruy/wait.h
new file mode 100644
index 0000000..900ec8d
--- /dev/null
+++ b/ruy/wait.h
@@ -0,0 +1,73 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_WAIT_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_WAIT_H_
+
+#include <condition_variable> // NOLINT(build/c++11)
+#include <functional>
+#include <mutex> // NOLINT(build/c++11)
+
+#include "ruy/time.h"
+
+namespace ruy {
+
+// Waits until some evaluation of `condition` has returned true.
+//
+// There is no guarantee that calling `condition` again after this function
+// has returned would still return true. The only
+// contract is that at some point during the execution of this function,
+// `condition` has returned true.
+//
+// First does some spin-waiting for the specified `spin_duration`,
+// then falls back to passive waiting for the given condvar, guarded
+// by the given mutex. At that point it will acquire the mutex lock around
+// the wait on the condition variable.
+// Therefore, this function expects that the calling thread hasn't already
+// locked the mutex before calling it.
+// This function will always release the mutex lock before returning.
+//
+// The idea of doing some initial spin-waiting is to help get
+// better and more consistent multithreading benefits for small GEMM sizes.
+// Spin-waiting helps ensure that if we need to wake up soon after having
+// started waiting, then we can wake up quickly (as opposed to, say,
+// having to wait to be scheduled again by the OS). On the other hand,
+// we must still eventually revert to passive waiting for longer waits
+// (e.g. worker threads having finished a GEMM and waiting until the next GEMM)
+// so as to avoid permanently spinning.
+//
+// In situations where other threads might have more useful things to do with
+// these CPU cores than our spin-waiting, it may be best to reduce the value
+// of `spin_duration`. Setting it to zero disables the spin-waiting entirely.
+//
+// There is a risk that the std::function used here might use a heap allocation
+// to store its context. The expected usage pattern is that these functions'
+// contexts will consist of a single pointer value (typically capturing only
+// [this]), and that in this case the std::function implementation will use
+// inline storage, avoiding a heap allocation. However, we can't effectively
+// guard that assumption, and that's not a big concern anyway because the
+// latency of a small heap allocation is probably low compared to the intrinsic
+// latency of what this Wait function does.
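+//
+// Illustrative usage sketch (the names below are hypothetical, not part of
+// ruy's API):
+//
+//   std::mutex mutex;
+//   std::condition_variable condvar;
+//   std::atomic<bool> done(false);
+//   // ... another thread sets `done` to true and notifies `condvar`,
+//   // holding `mutex` while notifying ...
+//   Wait([&done] { return done.load(); }, DurationFromMilliseconds(2),
+//        &condvar, &mutex);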
+void Wait(const std::function<bool()>& condition, const Duration& spin_duration,
+ std::condition_variable* condvar, std::mutex* mutex);
+
+// Convenience overload using a default `spin_duration`.
+// TODO(benoitjacob): let this be controlled from the ruy API.
+void Wait(const std::function<bool()>& condition,
+ std::condition_variable* condvar, std::mutex* mutex);
+
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_WAIT_H_
diff --git a/ruy/wait_test.cc b/ruy/wait_test.cc
new file mode 100644
index 0000000..f0548f9
--- /dev/null
+++ b/ruy/wait_test.cc
@@ -0,0 +1,117 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/wait.h"
+
+#include <atomic>
+#include <condition_variable> // NOLINT(build/c++11)
+#include <mutex> // NOLINT(build/c++11)
+#include <thread> // NOLINT(build/c++11)
+
+#include "testing/base/public/gunit.h"
+#include "ruy/platform.h"
+
+namespace ruy {
+namespace {
+
+// Thread taking a `value` atomic counter and incrementing it until it equals
+// `end_value`, then notifying the condition variable as long as
+// `value == end_value`. If `end_value` is increased, it will then resume
+// incrementing `value`, etc. Terminates if `end_value == -1`.
+class ThreadCountingUpToValue {
+ public:
+ ThreadCountingUpToValue(const std::atomic<int>& end_value,
+ std::atomic<int>* value,
+ std::condition_variable* condvar, std::mutex* mutex)
+ : end_value_(end_value),
+ value_(value),
+ condvar_(condvar),
+ mutex_(mutex) {}
+ void operator()() {
+ // end_value_==-1 is how the master thread will tell us it's OK to terminate
+ while (end_value_.load() != -1) {
+ // wait until end_value is set to a higher value
+ while (value_->load() == end_value_.load()) {
+ }
+ // increment value as long as it's lower than end_value
+ while (value_->fetch_add(1) < end_value_.load() - 1) {
+ }
+ // when value has reached end_value, notify the master thread.
+ while (value_->load() == end_value_.load()) {
+ std::lock_guard<std::mutex> lock(*mutex_);
+ condvar_->notify_all();
+ }
+ }
+ }
+
+ private:
+ const std::atomic<int>& end_value_;
+ std::atomic<int>* value_;
+ std::condition_variable* condvar_;
+ std::mutex* mutex_;
+};
+
+void WaitTest(const Duration& spin_duration, const Duration& delay) {
+#if RUY_PLATFORM(EMSCRIPTEN)
+ // b/139927184, std::thread constructor raises exception
+ return;
+#endif
+ std::condition_variable condvar;
+ std::mutex mutex;
+ std::atomic<int> value(0);
+ std::atomic<int> end_value(0);
+ ThreadCountingUpToValue thread_callable(end_value, &value, &condvar, &mutex);
+ std::thread thread(thread_callable);
+ std::this_thread::sleep_for(delay);
+ for (int i = 1; i < 10; i++) {
+ end_value.store(1000 * i);
+ const auto& condition = [&value, &end_value]() {
+ return value.load() == end_value.load();
+ };
+ ruy::Wait(condition, spin_duration, &condvar, &mutex);
+ EXPECT_EQ(value.load(), end_value.load());
+ }
+ end_value.store(-1);
+ thread.join();
+}
+
+TEST(WaitTest, WaitTestNoSpin) {
+ WaitTest(DurationFromSeconds(0), DurationFromSeconds(0));
+}
+
+TEST(WaitTest, WaitTestSpinOneMicrosecond) {
+ WaitTest(DurationFromSeconds(1e-6), DurationFromSeconds(0));
+}
+
+TEST(WaitTest, WaitTestSpinOneMillisecond) {
+ WaitTest(DurationFromSeconds(1e-3), DurationFromSeconds(0));
+}
+
+TEST(WaitTest, WaitTestSpinOneSecond) {
+ WaitTest(DurationFromSeconds(1), DurationFromSeconds(0));
+}
+
+// Testcase to consistently reproduce the hang in b/139062384.
+TEST(WaitTest, WaitTestNoSpinWithDelayBug139062384) {
+ WaitTest(DurationFromSeconds(0), DurationFromSeconds(1));
+}
+
+} // namespace
+} // namespace ruy
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}